from utils import ReplayBuffer, convert_to_tensor

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions.normal import Normal

class Actor(nn.Module):
    """Gaussian policy network producing the mean and std of a Normal distribution.

    Args:
        state_dim: Size of the input state vector.
        action_dim: Size of the output action vector.
        hidden_dim: Width of the two hidden layers.
        activation_function: Callable applied after each hidden layer.
        last_activation: Optional callable applied to the mean; ``None`` skips it.
        trainable_std: If true, a state-independent log-std parameter of shape
            ``(1, action_dim)`` is learned; otherwise the std is fixed at 1.
    """
    def __init__(self, state_dim, action_dim, hidden_dim, activation_function, last_activation, trainable_std):
        super(Actor, self).__init__()
        self.trainable_std = trainable_std

        self.l1 = nn.Linear(state_dim, hidden_dim)
        self.l2 = nn.Linear(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, action_dim)
        self.activation = activation_function
        self.last_activation = last_activation

        if self.trainable_std:
            # Learned log standard deviation, shared across all states.
            self.logstd = nn.Parameter(torch.zeros(1, action_dim))

    def forward(self, x):
        """Return ``(mu, std)`` for state(s) ``x``; ``std`` always has ``mu``'s shape."""
        x = self.activation(self.l1(x))
        x = self.activation(self.l2(x))
        mu = self.out(x)
        if self.last_activation is not None:
            mu = self.last_activation(mu)

        if self.trainable_std:
            # logstd is stored as (1, action_dim). Flatten it and expand to mu's
            # shape so that both batched and single unbatched states get an std
            # matching mu. The raw (1, action_dim) tensor would add a spurious
            # leading dimension to actions sampled for an unbatched state.
            std = torch.exp(self.logstd.view(-1)).expand_as(mu)
        else:
            # exp(0) == 1: fixed unit standard deviation.
            std = torch.ones_like(mu)
        return mu, std

class Critic(nn.Module):
    """Two-hidden-layer MLP that scores the concatenation of all its inputs.

    Args:
        input_dim: Total width of the inputs after concatenation on the last axis.
        output_dim: Number of output values.
        hidden_dim: Width of both hidden layers.
        activation_function: Callable applied after each hidden layer.
    """
    def __init__(self, input_dim, output_dim, hidden_dim, activation_function):
        super().__init__()
        # Creation order (l1, l2, out) is kept so parameter init and state_dict
        # ordering stay stable.
        self.l1 = nn.Linear(input_dim, hidden_dim)
        self.l2 = nn.Linear(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, output_dim)
        self.activation = activation_function

    def forward(self, *x):
        """Concatenate the positional inputs on the last axis and run the MLP."""
        hidden = torch.cat(x, -1)
        for layer in (self.l1, self.l2):
            hidden = self.activation(layer(hidden))
        return self.out(hidden)

class SAC(nn.Module):
    """Soft Actor-Critic agent with twin Q critics and a learned entropy temperature.

    Actions are sampled from an (unsquashed) ``Normal(mu, std)`` produced by
    ``Actor``. Two critics ``q_1``/``q_2`` and their soft-updated target copies
    estimate Q(s, a). ``alpha`` is a learnable temperature trained toward
    ``target_entropy = -action_dim``.

    Args:
        writer: Optional logger exposing ``add_scalar(tag, value, step)``;
            ``None`` disables logging.
        device: Device passed to ``convert_to_tensor`` for sampled batches.
        state_dim: Dimension of the state vector.
        action_dim: Dimension of the action vector.
        args: Hyper-parameter namespace. Reads ``hidden_dim``,
            ``activation_function``, ``last_activation``, ``trainable_std``,
            ``alpha_init``, ``memory_size``, ``q_lr``, ``actor_lr``,
            ``alpha_lr``, ``gamma`` and ``soft_update_rate``.
    """
    def __init__(self, writer, device, state_dim, action_dim, args):
        super(SAC, self).__init__()
        self.args = args
        self.actor = Actor(state_dim, action_dim, self.args.hidden_dim,
                           self.args.activation_function, self.args.last_activation, self.args.trainable_std)

        self.q_1 = Critic(state_dim+action_dim, 1, self.args.hidden_dim, self.args.activation_function)
        self.q_2 = Critic(state_dim+action_dim, 1, self.args.hidden_dim, self.args.activation_function)

        self.target_q_1 = Critic(state_dim+action_dim, 1, self.args.hidden_dim, self.args.activation_function)
        self.target_q_2 = Critic(state_dim+action_dim, 1, self.args.hidden_dim, self.args.activation_function)

        # rate = 1 makes the targets exact copies of the online critics.
        self.soft_update(self.q_1, self.target_q_1, 1.)
        self.soft_update(self.q_2, self.target_q_2, 1.)

        # float() so an integer alpha_init still yields a floating-point
        # Parameter (integer tensors cannot require gradients).
        self.alpha = nn.Parameter(torch.tensor(float(self.args.alpha_init)))

        self.data = ReplayBuffer(action_prob_exist = False, max_size = int(self.args.memory_size), state_dim = state_dim, num_action = action_dim)
        # Float target so it adds to log-probabilities without an int64 tensor.
        self.target_entropy = - torch.tensor(float(action_dim))

        self.q_1_optimizer = optim.Adam(self.q_1.parameters(), lr=self.args.q_lr)
        self.q_2_optimizer = optim.Adam(self.q_2.parameters(), lr=self.args.q_lr)

        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=self.args.actor_lr)
        self.alpha_optimizer = optim.Adam([self.alpha], lr=self.args.alpha_lr)

        self.device = device
        self.writer = writer

    def put_data(self, transition):
        """Store one transition in the replay buffer."""
        self.data.put_data(transition)

    def soft_update(self, network, target_network, rate):
        """Blend ``network`` into ``target_network`` in place.

        Each target parameter becomes ``target * (1 - rate) + online * rate``.
        ``rate = 1`` copies the weights outright.
        """
        for network_params, target_network_params in zip(network.parameters(), target_network.parameters()):
            target_network_params.data.copy_(target_network_params.data * (1.0 - rate) + network_params.data * rate)

    def get_action(self, state):
        """Sample a reparameterized action and its log-probability.

        Returns:
            ``(action, log_prob)``. ``log_prob`` is summed over the last
            (action) dimension with ``keepdim=True``.
        """
        # Plain (unsquashed) Gaussian policy; rsample keeps the sample
        # differentiable w.r.t. the actor parameters.
        mu, std = self.actor(state)
        dist = Normal(mu, std)
        a = dist.rsample()
        a_log_prob = dist.log_prob(a)
        return a, a_log_prob.sum(-1, keepdim=True)

    def q_update(self, Q, q_optimizer, states, actions, rewards, next_states, dones):
        """Take one optimizer step on critic ``Q`` toward the soft Bellman target.

        Target: ``r + gamma * (1 - done) * (min(target_q_1, target_q_2)(s', a')
        - alpha * log pi(a'|s'))``, with ``a'`` sampled from the current policy.
        The target is computed without gradients. Returns the smooth-L1 loss.
        """
        with torch.no_grad():
            next_actions, next_action_log_prob = self.get_action(next_states)
            q_1 = self.target_q_1(next_states, next_actions)
            q_2 = self.target_q_2(next_states, next_actions)
            q = torch.min(q_1, q_2)
            v = (1 - dones) * (q - self.alpha * next_action_log_prob)
            targets = rewards + self.args.gamma * v

        q = Q(states, actions)
        loss = F.smooth_l1_loss(q, targets)
        q_optimizer.zero_grad()
        loss.backward()
        q_optimizer.step()
        return loss

    def actor_update(self, states):
        """Take one actor step minimizing ``alpha * log pi(a|s) - min(Q1, Q2)(s, a)``.

        Returns:
            ``(loss, log_prob)``. ``log_prob`` is reused by ``alpha_update``.
        """
        now_actions, now_action_log_prob = self.get_action(states)
        q_1 = self.q_1(states, now_actions)
        q_2 = self.q_2(states, now_actions)
        q = torch.min(q_1, q_2)

        # alpha is detached so only the actor is trained here. backward() also
        # fills critic .grad buffers; those are cleared by zero_grad() in the
        # next q_update before the critics' own backward pass.
        loss = (self.alpha.detach() * now_action_log_prob - q).mean()
        self.actor_optimizer.zero_grad()
        loss.backward()
        self.actor_optimizer.step()
        return loss, now_action_log_prob

    def alpha_update(self, now_action_log_prob):
        """Take one temperature step pushing policy entropy toward ``target_entropy``.

        ``alpha`` is kept non-negative after the step. Returns the alpha loss.
        """
        loss = (- self.alpha * (now_action_log_prob + self.target_entropy).detach()).mean()
        self.alpha_optimizer.zero_grad()
        loss.backward()
        self.alpha_optimizer.step()
        # alpha is optimized directly (not through log_alpha), so a step can
        # push it below zero. A negative temperature would reward the actor for
        # *lowering* entropy, so project back onto alpha >= 0. The gradient
        # w.r.t. alpha does not depend on alpha, so it can still grow again.
        with torch.no_grad():
            self.alpha.clamp_(min=0.0)
        return loss

    def train_net(self, batch_size, n_epi):
        """Run one SAC iteration on a sampled batch.

        Updates both critics, then the actor, then alpha, then soft-updates the
        targets. Logs the losses at step ``n_epi`` when a writer is set.
        """
        data = self.data.sample(shuffle = True, batch_size = batch_size)
        states, actions, rewards, next_states, dones = convert_to_tensor(self.device, data['state'], data['action'], data['reward'], data['next_state'], data['done'])

        # Critic updates.
        q_1_loss = self.q_update(self.q_1, self.q_1_optimizer, states, actions, rewards, next_states, dones)
        q_2_loss = self.q_update(self.q_2, self.q_2_optimizer, states, actions, rewards, next_states, dones)

        # Actor update.
        actor_loss, prob = self.actor_update(states)

        # Temperature update (reuses the actor's log-probabilities).
        alpha_loss = self.alpha_update(prob)

        self.soft_update(self.q_1, self.target_q_1, self.args.soft_update_rate)
        self.soft_update(self.q_2, self.target_q_2, self.args.soft_update_rate)
        if self.writer is not None:
            self.writer.add_scalar("loss/q_1", q_1_loss, n_epi)
            self.writer.add_scalar("loss/q_2", q_2_loss, n_epi)
            self.writer.add_scalar("loss/actor", actor_loss, n_epi)
            self.writer.add_scalar("loss/alpha", alpha_loss, n_epi)
            
