import torch
import gym
from gym.envs.registration import register
import numpy as np
import os
import random
import sys
sys.path.append('..')
from policy import Policy

########## register env ###########
# Make the custom inverted-pendulum environment available to gym.make under
# the id "IvPos-v0". Pendulum_LEN is the step cap given to gym here and is
# also the rollout horizon used by play_episodes.
Pendulum_LEN = 300
register(
    id="IvPos-v0",
    entry_point="inverted_pendulum_pos:InvertedPendulumPosEnv",
    nondeterministic=False,
    reward_threshold=None,
    max_episode_steps=Pendulum_LEN,
)
#################################
from argparse import ArgumentParser
# Command-line options. Help texts use argparse's ``%(default)s`` placeholder
# so the advertised default always matches the real one.
parser = ArgumentParser('parameters')
parser.add_argument("--env_name", type=str, default='IvPos-v0',
                    help="gym environment id (default: %(default)s)")
parser.add_argument('--epochs', type=int, default=6000,
                    help='number of epochs (default: %(default)s)')
parser.add_argument('--alpha', type=float, default=0.2,
                    help='CVaR alpha (default: %(default)s)')
parser.add_argument('--lr_p', type=float, default=3e-4,
                    help='policy learning rate (default: %(default)s)')
parser.add_argument("--seed", type=int, default=1,
                    help='seed (default: %(default)s)')
args = parser.parse_args()

######## setting #########
# Hyper-parameters: the tunable ones come from the command line, the rest
# are fixed for this experiment.
epochs, alpha, actor_lr, seed = args.epochs, args.alpha, args.lr_p, args.seed
hidden_dim = 128  # hidden width passed to Policy
discount = 0.999  # discount factor passed to Policy
n_episodes = 30   # episodes rolled out before each agent.update_policy()
save_intvl = 30   # results are saved every save_intvl epochs

# seeding
# Create the training and evaluation environments, print their dimensions,
# then seed both environments and the Python / NumPy / torch RNGs.
env, eval_env = gym.make(args.env_name), gym.make(args.env_name)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])
print('state_dim', state_dim, 'action_dim', action_dim, 'max_action', max_action)
env.seed(seed)
# The evaluation env gets a different seed derived from the training seed.
eval_env.seed(2**31 - 1 - seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

########## functional #########
def play_episodes(env, agent, n_episodes, max_episode_length=None, xpos_limit=0.04):
    """Roll out ``n_episodes`` episodes with ``agent``'s stochastic policy.

    Every episode's transitions are handed to ``agent.put_data``. Per-episode
    statistics are collected and summarised.

    Args:
        env: environment with the 4-tuple ``step`` API
            ``(s_prime, r, done, info)``; ``info`` must contain
            ``'x_position'``.
        agent: object exposing ``get_action(state_tensor) -> (mu, sigma)`` and
            ``put_data(states, actions, rewards, dones)``.
        n_episodes: number of episodes to roll out (must be >= 1).
        max_episode_length: step cap per episode. ``None`` (default) means
            the module-level ``Pendulum_LEN``. Must be >= 1.
        xpos_limit: an x position strictly greater than this counts as a
            violation step (default 0.04).

    Returns:
        Tuple ``(mean_return, mean_violation_rate, traj_violation_fraction,
        final_pos_lst, max_pos_lst, min_pos_lst)``.
        ``mean_violation_rate`` is the per-episode fraction of violating steps,
        averaged over episodes. ``traj_violation_fraction`` is the fraction of
        episodes with at least one violating step. Each of the three lists
        holds exactly one entry per episode.
    """
    if max_episode_length is None:
        max_episode_length = Pendulum_LEN

    return_lst, ep_length_lst, xpos_vio_lst = [], [], []
    final_pos_lst, max_pos_lst, min_pos_lst = [], [], []
    traj_visit = 0

    for _ in range(n_episodes):
        state_lst, action_lst, reward_lst, done_lst = [], [], [], []
        s = env.reset()
        ep_r, xpos_vio = 0, 0
        pos_lst = []
        for _t in range(max_episode_length):
            with torch.no_grad():
                mu, sigma = agent.get_action(torch.from_numpy(s).float())
                # NOTE(review): only sigma[0] is used as the std; presumably
                # fine for a 1-D action space -- confirm for larger ones.
                dist = torch.distributions.Normal(mu, sigma[0])
                a = dist.sample().cpu().numpy()
            s_prime, r, done, info = env.step(a)

            ep_r += r
            xpos = info['x_position']
            pos_lst.append(xpos)
            if xpos > xpos_limit:
                xpos_vio += 1

            state_lst.append(s)
            action_lst.append(a)
            reward_lst.append(r)
            done_lst.append(done)

            if done:
                break
            s = s_prime

        # Record position statistics whether the episode ended through
        # ``done`` or by exhausting ``max_episode_length``. Recording only on
        # ``done`` would silently drop truncated episodes and misalign the
        # lists with ``n_episodes``.
        final_pos_lst.append(pos_lst[-1])
        max_pos_lst.append(np.max(pos_lst))
        min_pos_lst.append(np.min(pos_lst))
        if xpos_vio > 0:
            traj_visit += 1

        # store to buffer
        agent.put_data(state_lst, action_lst, reward_lst, done_lst)

        return_lst.append(ep_r)
        ep_length_lst.append(len(reward_lst))
        xpos_vio_lst.append(xpos_vio)

    rate = (np.array(xpos_vio_lst) / np.array(ep_length_lst)).mean()
    return (np.mean(return_lst), rate, traj_visit / n_episodes,
            final_pos_lst, max_pos_lst, min_pos_lst)


#################################################

# Agent from the local ``policy`` module, configured with the settings above.
agent = Policy(state_dim, action_dim, hidden_dim, alpha, actor_lr, discount, n_episodes)

# record
# Per-epoch training records: the main loop appends to them and save()
# writes them to disk.
train_return = []
train_vio_rate = []
train_traj_vio_rate = []
train_finalpos_lst = []
train_maxpos_lst = []
train_minpos_lst = []

def save(alpha, lr_policy, hidden_dim, seed, root_dir='./save'):
    """Write the accumulated training records to ``.npy`` files.

    The output directory is
    ``<root_dir>/alpha_<alpha>/lr_p_<lr_policy>/h_dim_<hidden_dim>/seed_<seed>``
    and is created if missing. Existing files are overwritten, so repeated
    calls leave the latest snapshot on disk.

    The data comes from the module-level lists ``train_return``,
    ``train_vio_rate``, ``train_traj_vio_rate``, ``train_finalpos_lst``,
    ``train_maxpos_lst`` and ``train_minpos_lst``.

    Args:
        alpha: CVaR alpha, used only to name the directory.
        lr_policy: policy learning rate, used only to name the directory.
        hidden_dim: hidden width, used only to name the directory.
        seed: run seed, used only to name the directory.
        root_dir: base output directory (default ``'./save'``).

    Returns:
        The directory the files were written to.
    """
    root = os.path.join(root_dir, f'alpha_{alpha}', f'lr_p_{lr_policy}',
                        f'h_dim_{hidden_dim}', f'seed_{seed}')
    os.makedirs(root, exist_ok=True)

    records = {
        'ret.npy': train_return,
        'rate.npy': train_vio_rate,
        'traj_rate.npy': train_traj_vio_rate,
        'final_pos.npy': train_finalpos_lst,
        'max_pos.npy': train_maxpos_lst,
        'min_pos.npy': train_minpos_lst,
    }
    for fname, data in records.items():
        np.save(os.path.join(root, fname), data)
    return root

# main
# Each epoch rolls out n_episodes episodes, records their statistics and
# performs one policy update.
for ep_i in range(epochs):
    train_ret, train_vio, traj_vio, fpos, maxpos, minpos = play_episodes(env, agent, n_episodes)
    train_return.append(train_ret)
    train_vio_rate.append(train_vio)
    train_traj_vio_rate.append(traj_vio)
    train_finalpos_lst.append(fpos)
    train_maxpos_lst.append(maxpos)
    train_minpos_lst.append(minpos)

    agent.update_policy()

    # Save every save_intvl epochs and always after the final epoch. A run
    # whose length is not a multiple of save_intvl would otherwise lose the
    # results of its last epochs.
    is_last_epoch = ep_i + 1 == epochs
    if (ep_i + 1) % save_intvl == 0 or is_last_epoch:
        save(alpha, actor_lr, hidden_dim, seed)


