import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import gym
import numpy as np
import os

import sys
sys.path.append('..')
from lunar_lander_risk import LunarLander
# Hard cap on episode length in environment steps: the rollout loops force
# `done = True` once this many steps have elapsed.
LunarLander_LEN = 500

import argparse

# Command-line options: RNG seed and discount factor gamma.
parser = argparse.ArgumentParser(description='seed gamma')
parser.add_argument('--seed', default=1, type=int, help='seed')
parser.add_argument('--gamma', default=0.999, type=float, help='gamma')
args = parser.parse_args()

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

class ActorNet(nn.Module):
    """Policy network: two ReLU hidden layers followed by a softmax over actions.

    forward(x) maps inputs whose last dimension is ``state_size`` to action
    probabilities whose last dimension is ``action_size`` (softmax over dim -1).
    """

    def __init__(self, state_size, action_size, hidden_size):
        super().__init__()
        # The attribute names (l1, l2, output) are the state_dict keys of saved
        # checkpoints. Keeping the creation order keeps seeded initialisation unchanged.
        self.l1 = nn.Linear(state_size, hidden_size)
        self.l2 = nn.Linear(hidden_size, hidden_size)
        self.output = nn.Linear(hidden_size, action_size)

    def forward(self, x):
        h = x
        for hidden_layer in (self.l1, self.l2):
            h = F.relu(hidden_layer(h))
        logits = self.output(h)
        # Normalise over the last (action) dimension so batched inputs also work.
        return F.softmax(logits, dim=-1)
    
class ValueFunctionNet(nn.Module):
    """State-value baseline V(s): two ReLU hidden layers and a single linear output unit.

    forward(x) maps inputs whose last dimension is ``state_size`` to values whose
    last dimension is 1.
    """

    def __init__(self, state_size, hidden_size):
        super().__init__()
        # The layer attribute names double as checkpoint state_dict keys.
        # Keep the names and their creation order.
        self.l1 = nn.Linear(state_size, hidden_size)
        self.l2 = nn.Linear(hidden_size, hidden_size)
        self.output = nn.Linear(hidden_size, 1)

    def forward(self, x):
        features = F.relu(self.l2(F.relu(self.l1(x))))
        return self.output(features)

class PGAgent():
    """REINFORCE-with-baseline agent.

    It holds a softmax policy (ActorNet) and a state-value baseline
    (ValueFunctionNet), each with its own Adam optimizer. Episodes are
    collected with ``put_data``. ``train`` consumes every buffered episode
    and then clears the buffers.
    """

    def __init__(self, state_size, action_size, hidden_size, actor_lr, vf_lr, discount, n_episodes):
        self.action_size = action_size
        self.actor_net = ActorNet(state_size, action_size, hidden_size).to(device)
        self.vf_net = ValueFunctionNet(state_size, hidden_size).to(device)
        self.actor_optimizer = optim.Adam(self.actor_net.parameters(), lr=actor_lr)
        self.vf_optimizer = optim.Adam(self.vf_net.parameters(), lr=vf_lr)
        self.discount = discount
        # Intended number of episodes per update (batch size). train() itself uses
        # however many episodes are actually buffered.
        self.n_episodes = n_episodes

        # Entry i of each buffer holds the state/action/reward list of episode i.
        self.state_buf, self.action_buf, self.reward_buf = [], [], []

    def save_best(self, save_path):
        """Save actor/value weights as ``actor_best.th`` / ``value_best.th``.

        ``save_path`` is prepended verbatim, so it should end with a path separator.
        """
        torch.save(self.actor_net.state_dict(), save_path + 'actor_best.th')
        torch.save(self.vf_net.state_dict(), save_path + 'value_best.th')

    def save_model(self, save_path, ep):
        """Save actor/value weights tagged with epoch ``ep``.

        ``save_path`` is prepended verbatim, so it should end with a path separator.
        """
        torch.save(self.actor_net.state_dict(), save_path + 'actor_ep_' + str(ep) + '.th')
        torch.save(self.vf_net.state_dict(), save_path + 'value_ep_' + str(ep) + '.th')

    def select_action(self, state):
        """Sample an action index from the policy for one (unbatched) state tensor.

        Sampling uses numpy's global RNG, so it follows ``np.random.seed``.
        """
        with torch.no_grad():
            probs = self.actor_net(state).cpu().numpy()
        # np.random.choice needs a 1-D probability vector, i.e. a single state.
        return np.random.choice(np.arange(self.action_size), p=probs)

    def put_data(self, state_lst, action_lst, reward_lst):
        """Buffer one episode's trajectory for the next ``train`` call."""
        self.state_buf.append(state_lst)
        self.action_buf.append(action_lst)
        self.reward_buf.append(reward_lst)

    def train(self):
        """Do one policy-gradient update from all buffered episodes, then clear the buffers.

        For each episode:
          * compute the discounted returns-to-go G_t;
          * form the REINFORCE-with-baseline loss
            sum_t -log pi(a_t|s_t) * (G_t - V(s_t)), using V before this episode's
            critic step;
          * take one Adam step on the critic with MSE(V(s_t), G_t).
        The actor then takes a single Adam step on the mean of the per-episode losses.
        Calling this with an empty buffer does nothing.
        """
        if not self.reward_buf:
            # Nothing was collected; torch.stack([]) would raise.
            return

        loss_fn = nn.MSELoss()
        reinforce_loss_lst = []  # sum_t -log pi(a_t|s_t) A_t of each trajectory

        for state_lst, action_lst, reward_lst in zip(self.state_buf, self.action_buf, self.reward_buf):
            # Discounted return-to-go G_t, accumulated backwards from the last step.
            return_array = np.zeros(len(reward_lst))
            g_return = 0.
            for i in reversed(range(len(reward_lst))):
                g_return = reward_lst[i] + self.discount * g_return
                return_array[i] = g_return

            state_t = torch.FloatTensor(np.array(state_lst)).to(device)        # [traj_len, state_size]
            action_t = torch.LongTensor(action_lst).to(device).view(-1, 1)     # [traj_len, 1]
            return_t = torch.FloatTensor(return_array).to(device).view(-1, 1)  # [traj_len, 1]

            vf_t = self.vf_net(state_t)  # [traj_len, 1]
            # Detach the advantage so the actor loss does not backprop into the critic.
            advantage_t = (return_t - vf_t).detach()

            selected_action_prob = self.actor_net(state_t).gather(1, action_t)  # [traj_len, 1]
            # Clamp before the log: a softmax probability that underflows to 0 would
            # give log(0) = -inf and NaN gradients.
            log_prob_t = torch.log(selected_action_prob.clamp(min=1e-8))
            reinforce_loss_lst.append(torch.sum(-log_prob_t * advantage_t))

            # The critic is updated once per episode, inside the loop.
            vf_loss = loss_fn(vf_t, return_t)
            self.vf_optimizer.zero_grad()
            vf_loss.backward()
            self.vf_optimizer.step()

        policy_loss = torch.stack(reinforce_loss_lst).mean()  # mean over episodes
        self.actor_optimizer.zero_grad()
        policy_loss.backward()
        self.actor_optimizer.step()

        # On-policy: the collected data is stale after the update.
        self.state_buf, self.action_buf, self.reward_buf = [], [], []


############# functional #############
def eval_model(env, agent, n_episodes=5):
    """Roll out ``agent`` in ``env`` for ``n_episodes`` episodes without storing data.

    An episode ends when the env reports ``done`` or after ``LunarLander_LEN``
    steps. It counts as a left landing when, on its final step, both hold:
      * the lander's world x (recovered from ``s[0]``) is at or left of the
        viewport centre;
      * the step reward is exactly 100 — presumably the env's landing bonus,
        TODO confirm in lunar_lander_risk.

    Returns:
        (np.ndarray of per-episode undiscounted returns, int count of left landings)
    """
    # Viewport geometry taken from the LunarLander source. It is used only to map
    # the normalised observation s[0] back to world x-coordinates.
    VIEWPORT_W = 600
    SCALE = 30.0
    half_width = VIEWPORT_W / SCALE / 2  # world x of the viewport centre

    ep_return = []
    land_left = 0

    for _ in range(n_episodes):
        total_reward = 0
        episode_length = 0
        s = env.reset()
        while True:
            a = agent.select_action(torch.from_numpy(s).float().to(device))
            s, r, done, _ = env.step(a)
            total_reward += r
            episode_length += 1

            # Truncate overly long episodes.
            if episode_length == LunarLander_LEN:
                done = True
            if done:
                x = s[0] * half_width + half_width
                if x <= half_width and r == 100:
                    land_left += 1
                break

        ep_return.append(total_reward)

    return np.array(ep_return), land_left

def play_episode(env, agent):
    """Run one training episode with ``agent`` and buffer its trajectory in the agent.

    The episode ends when the env reports ``done`` or after ``LunarLander_LEN``
    steps. The stored (state, action, reward) lists are handed to
    ``agent.put_data``.

    Returns:
        (total_reward, land_left):
          * total_reward is the undiscounted episode return.
          * land_left is True when, on the final step, the lander's world x is at
            or left of the viewport centre and the step reward is exactly 100
            (presumably the env's landing bonus — TODO confirm in lunar_lander_risk).
    """
    # Viewport geometry taken from the LunarLander source. It is used only to map
    # the normalised observation next_state[0] back to world x-coordinates.
    VIEWPORT_W = 600
    SCALE = 30.0
    half_width = VIEWPORT_W / SCALE / 2  # world x of the viewport centre

    state = env.reset()
    state_list, action_list, reward_list = [], [], []
    episode_length = 0

    total_reward = 0.
    land_left = False
    while True:
        action = agent.select_action(torch.from_numpy(state).float().to(device))
        next_state, reward, done, _ = env.step(action)
        episode_length += 1
        total_reward += reward

        # Store the transition that was just taken (s_t, a_t, r_t).
        state_list.append(state)
        action_list.append(action)
        reward_list.append(reward)

        # Truncate overly long episodes.
        if episode_length == LunarLander_LEN:
            done = True

        if done:
            x = next_state[0] * half_width + half_width
            if x <= half_width and reward == 100:
                land_left = True
            break

        state = next_state

    agent.put_data(state_list, action_list, reward_list)

    return total_reward, land_left

def get_save_dir(n_episodes, seed, discount, lr_policy, lr_value):
    """Return the run-specific output directory for these settings.

    The layout is
    ``./save/bs_<n_episodes>/gamma_<discount>/lr_p=<lr_policy>/lr_v=<lr_value>/seed_<seed>/``.
    The trailing '/' is kept so callers can append file names directly. Each
    value is rendered with ``str``.
    """
    parts = (
        '.',
        'save',
        'bs_' + str(n_episodes),
        'gamma_' + str(discount),
        'lr_p=' + str(lr_policy),
        'lr_v=' + str(lr_value),
        'seed_' + str(seed),
    )
    # Join with a single separator; the previous version emitted 'gamma_x//lr_p=...'.
    return '/'.join(parts) + '/'

################# setting ##################
# Passed straight to LunarLander(); its exact meaning is defined in lunar_lander_risk
# (presumably the magnitude of injected noise — TODO confirm).
noise_scale = 100
env = LunarLander(noise_scale)       # environment used for training rollouts
eval_env = LunarLander(noise_scale)  # separate instance reserved for evaluation
state_size = env.observation_space.shape[0]
action_size = env.action_space.n

# set seed
# The statement order below is kept deliberately: env construction and the
# RNG seeding calls are interleaved.
seed = args.seed
env.seed(seed)
# Use a different seed (2**31-1-seed never equals seed) so evaluation episodes do
# not replay the training env's random stream.
eval_env.seed(2**31-1-seed)
np.random.seed(seed)     # also drives action sampling via np.random.choice
torch.manual_seed(seed)  # network initialisation
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

# hyperparameters
epochs = 4000              # run agent for this many epochs (one agent.train() update per epoch)
hidden_size = 128          # number of units in NN hidden layers
actor_lr = 0.0007          # learning rate for actor
value_function_lr = 0.007  # learning rate for value function
discount = args.gamma      # discount factor gamma value
n_episodes = 30            # episodes collected per epoch, i.e. per policy update
save_intvl = 50            # epochs between writes of the ret/land training curves

# NOTE(review): the three settings below are not referenced by the active training
# loop; presumably reserved for periodic evaluation / checkpointing.
eval_intvl = 20
save_model_intvl = 50
eval_episodes = 10

print('seed:', seed, 'discount:', discount)
print('lr_p:', actor_lr, 'lr_v:', value_function_lr, 'n_episodes:', n_episodes)

################ main ###############

# create agent
agent = PGAgent(state_size, action_size, hidden_size, actor_lr, value_function_lr, discount, n_episodes)

# Per-epoch training curves:
#   * train_rewards_lst: mean undiscounted episode return;
#   * train_land_left_lst: number of episodes that play_episode flagged as a left landing.
train_rewards_lst, train_land_left_lst = [], []

# eval_rewards_list = [] # store evaluation rewards
# eval_land_left_list = []
# best_eval_return = -10000

# create save dir
save_dir = get_save_dir(n_episodes, seed, discount, actor_lr, value_function_lr)
os.makedirs(save_dir, exist_ok=True)

for ep in range(epochs):
    # Collect a batch of n_episodes on-policy episodes (buffered inside the agent).
    epoch_ret_lst = []
    epoch_land_left = 0.

    for _ in range(n_episodes):
        ret, land = play_episode(env, agent)
        epoch_ret_lst.append(ret)
        if land:
            epoch_land_left += 1

    train_rewards_lst.append(np.mean(epoch_ret_lst))
    train_land_left_lst.append(epoch_land_left)

    # One actor/critic update on the batch just collected; this also clears the agent's buffers.
    agent.train()

    # Save the curves every save_intvl epochs and always after the last epoch. Otherwise a
    # run whose length is not a multiple of save_intvl would lose its final results.
    is_last_epoch = ep + 1 == epochs
    if (ep + 1) % save_intvl == 0 or is_last_epoch:
        with open(save_dir + 'ret.npy', 'wb') as f:
            np.save(f, train_rewards_lst)
        with open(save_dir + 'land.npy', 'wb') as f:
            np.save(f, train_land_left_lst)

        


