import numpy as np
import copy
np.set_printoptions(precision=3)


class Policy:
    """Tabular softmax policy trained with REINFORCE with a state-value baseline.

    Every cell ``(x, y)`` of a 2-D grid has its own vector of action
    preferences ``theta[x, y]`` (turned into probabilities with a softmax) and
    a tabular baseline ``v[x, y]`` that tracks the observed returns from that
    cell.

    Args:
        n_traj: Number of trajectories per update batch; the policy step size
            in ``reinforce_baseline`` is divided by this value.
        alpha: Stored as ``self.alpha``; not read by any method of this class.
        grid_shape: ``(rows, cols)`` of the state grid. Defaults to ``(8, 8)``.
        n_actions: Number of discrete actions. Defaults to ``4``.
    """

    def __init__(self, n_traj, alpha, grid_shape=(8, 8), n_actions=4):
        self.n_traj = n_traj
        self.alpha = alpha  # NOTE(review): unused in this class; kept for external readers
        self.actions = np.arange(n_actions)

        rows, cols = grid_shape
        # Action preferences per grid cell; all zeros => uniform policy.
        self.theta = np.zeros((rows, cols, n_actions))
        # Tabular baseline, moved toward the Monte Carlo return of each cell.
        self.v = np.zeros((rows, cols))

        # Trajectories collected since the last update.
        self.traj_buf = []

    def put_data(self, traj):
        """Append one trajectory to the buffer consumed by ``reinforce_baseline``.

        ``traj`` is a sequence of items in forward time order whose first three
        fields are ``(state, action, reward)``, with ``state`` an ``(x, y)``
        grid index.
        """
        self.traj_buf.append(traj)

    def get_theta_prob(self, s):
        """Return the softmax action probabilities at grid state ``s = (x, y)``."""
        x, y = s
        prefs = self.theta[x, y]
        # Softmax is invariant to subtracting a constant; shifting by the max
        # keeps np.exp from overflowing to inf (which would yield NaN
        # probabilities) when preferences become large.
        exp_prefs = np.exp(prefs - np.max(prefs))
        return exp_prefs / np.sum(exp_prefs)

    def get_action(self, s):
        """Sample an action index from the policy's distribution at state ``s``."""
        prob = self.get_theta_prob(s)
        return np.random.choice(self.actions, size=1, p=prob)[0]

    def eval_action(self, s):
        """Return the most probable (greedy) action at state ``s``."""
        prob = self.get_theta_prob(s)
        return np.argmax(prob)

    def der_theta_logpi(self, s, a):
        """Return the gradient of ``log pi(a | s)`` with respect to ``self.theta``.

        For a tabular softmax policy the gradient is zero everywhere except at
        ``theta[x, y]``, where it equals ``onehot(a) - pi(. | s)``. The result
        has the same shape as ``self.theta``.
        """
        x, y = s
        grad = np.zeros_like(self.theta)
        grad[x, y] = self._local_grad_logpi(s, a)
        return grad

    def _local_grad_logpi(self, s, a):
        """Return the only nonzero slice of the log-policy gradient: onehot(a) - pi(.|s)."""
        grad = -self.get_theta_prob(s)
        grad[a] += 1.0
        return grad

    def reinforce_baseline(self, lrp, lrv, gamma):
        """Apply REINFORCE-with-baseline updates for every buffered trajectory.

        The updates are applied per time step (the value table and preferences
        change as each step is processed); the buffer is emptied afterwards.

        Args:
            lrp: Policy step size; divided by ``self.n_traj``.
            lrv: Step size for the baseline table ``self.v``.
            gamma: Discount factor.
        """
        policy_step = lrp / self.n_traj
        # Iterate over what was actually collected: indexing range(self.n_traj)
        # raised IndexError on a short buffer and silently dropped any extra
        # trajectories, since the buffer is cleared below.
        for traj in self.traj_buf:
            traj_len = len(traj)
            ret = 0.0
            # Walk backwards so the discounted return G_t = r_t + gamma * G_{t+1}
            # accumulates in one pass. Items are only read, so reversed() is
            # enough; no copy of the trajectory is needed.
            for t, item in enumerate(reversed(traj)):
                s, a, reward = item[0], item[1], item[2]
                ret = reward + gamma * ret

                x, y = s
                step = traj_len - 1 - t  # time index in forward order
                delta = ret - self.v[x, y]
                # Only theta[x, y] has a nonzero gradient, so update just that row.
                grad = self._local_grad_logpi(s, a)

                self.v[x, y] += lrv * delta
                self.theta[x, y] += policy_step * gamma ** step * grad * delta

        # clean buffer
        self.traj_buf = []


    
