from __future__ import annotations

import enum
from dataclasses import dataclass

import numpy as np
import torch

from conf_tmaze import Config as conf


class Action(enum.Enum):
    """Agent actions; ``value`` is the index used for the one-hot encoding."""

    UP = 0
    RIGHT = 1
    DOWN = 2
    LEFT = 3

    def to_tensor(self) -> torch.Tensor:
        """Return the one-hot encoding of this action, shape ``(1, 1, len(Action))``."""
        num_actions = len(type(self))
        return torch.tensor([[[float(self.value == i) for i in range(num_actions)]]])

    @classmethod
    def initial_tensor(cls) -> torch.Tensor:
        """Return an all ``-1`` tensor with the same shape as :meth:`to_tensor`.

        Judging by the name, this is presumably fed as the action input before any
        action has been taken (confirm with callers).
        """
        # Sized from the enum rather than a literal 4 so it always matches ``to_tensor``.
        return -torch.ones((1, 1, len(cls)))


class Observation(enum.Enum):
    """Observations emitted by the T-maze, each encoded as a distinct one-hot bit pattern.

    ``START_UP`` / ``START_DOWN`` carry the cue for which arm of the T is the goal.
    """

    START_UP = 0b1000  # indicates the goal is UP
    START_DOWN = 0b0100  # indicates the goal is DOWN
    CORRIDOR = 0b0010
    T_JUNCTION = 0b0001

    def to_tensor(self) -> torch.Tensor:
        """Return the bits of ``value`` (most significant first) as a float tensor.

        The result has shape ``(1, 1, num_bit())``.
        """
        # Width comes from ``num_bit`` so the encoding and its declared size cannot drift apart.
        bits = f"{self.value:0{self.num_bit()}b}"
        return torch.tensor([[[float(b) for b in bits]]])

    def get_action(self) -> Action:
        """Return the goal action cued by this start observation.

        Raises:
            ValueError: If called on an observation other than ``START_UP``/``START_DOWN``.
        """
        if self == Observation.START_UP:
            return Action.UP
        elif self == Observation.START_DOWN:
            return Action.DOWN
        else:
            raise ValueError(f"Action is not defined for this observation: {self}")

    @classmethod
    def num_bit(cls) -> int:
        """Return the number of bits in the encoding produced by :meth:`to_tensor`."""
        return 4


class Reward(enum.Enum):
    """Scalar rewards handed out by the T-maze; the number is stored in ``value``."""

    # Neutral outcome: a successful move along the corridor or onto the T-junction.
    ZERO = 0.0
    # Penalty: standing still (wall / UP / DOWN in the corridor), timeout, or the wrong arm.
    NEGATIVE = -0.1
    # Chosen arm at the T-junction matches the cued goal.
    GOAL = 4.0


@dataclass(frozen=True, order=True)
class TMazeState:
    """Immutable snapshot of the T-maze after a step.

    NOTE(review): ``order=True`` compares instances field by field as a tuple. Any
    comparison that reaches the Enum fields (``reward``, ``obs``) raises
    ``TypeError``, because plain ``enum.Enum`` members do not define ``<``; a
    ``None`` ``position`` has the same problem. Kept for interface compatibility.
    """

    # Agent position along the corridor; equal to the corridor length on the T-junction.
    position: int | None = None
    # Reward earned by the transition that produced this state.
    reward: Reward = Reward.ZERO
    # True while standing on the T-junction (the next step is the final choice).
    on_tjunc: bool = False
    # True once the episode has finished.
    episode_end: bool = False
    # Observation emitted for this state.
    obs: Observation | None = None
    # Number of steps taken so far.
    timestep: int = 0


class TMazeEnv:
    """T-maze: walk right along a corridor, then pick UP or DOWN at the T-junction.

    The goal direction is cued by ``init_obs``. :meth:`step` builds and returns a new
    :class:`TMazeState`; it never mutates the state passed in.
    """

    def __init__(
        self,
        init_obs: Observation,
        corridor_length: int | None = None,
        max_timestep: int | None = None,
    ):
        """Create an environment.

        Args:
            init_obs: Start observation. It must be ``START_UP`` or ``START_DOWN``
                for :attr:`last_action`, and therefore the T-junction reward, to work.
            corridor_length: Number of corridor cells before the T-junction.
                Defaults to ``conf.CORRIDOR_LENGTH``.
            max_timestep: The episode times out once the step counter would exceed
                this value. Defaults to ``conf.MAX_TIMESTEP``.
        """
        self.corridor_length = (
            conf.CORRIDOR_LENGTH if corridor_length is None else corridor_length
        )
        self.max_timestep = conf.MAX_TIMESTEP if max_timestep is None else max_timestep
        self.init_obs = init_obs

    @property
    def last_action(self) -> Action:
        """Goal action cued by ``init_obs``; taking it at the T-junction earns ``Reward.GOAL``.

        Raises:
            ValueError: If ``init_obs`` is not a start observation.
        """
        return self.init_obs.get_action()

    def step(self, state: TMazeState, action: Action) -> TMazeState:
        """Advance the environment by one step.

        Args:
            state: Current state. Unless it is on the T-junction, ``state.position``
                must lie in ``[0, corridor_length)``.
            action: Action taken by the agent.

        Returns:
            The successor state.

        Raises:
            ValueError: If ``state.position`` is outside the corridor while ``state``
                is not on the T-junction, or if the move lands in an unknown cell.
        """
        if conf.STATE_ACTION:
            # With state-action nodes, resample the action from the per-action
            # probabilities. The distribution is currently one-hot on ``action``, so
            # this always returns ``action``, but it still draws from numpy's RNG.
            prob = np.array([1.0 if action.value == i else 0.0 for i in range(len(Action))])
            action = np.random.choice(list(Action), p=prob)

        new_timestep = state.timestep + 1

        if state.on_tjunc:
            # Final decision at the T-junction: rewarded only if it matches the cue.
            reward = Reward.GOAL if action == self.last_action else Reward.NEGATIVE
            return TMazeState(
                conf.INITIAL_POSITION, reward, True, True, Observation.T_JUNCTION, new_timestep
            )
        if self._timeout(new_timestep):
            return TMazeState(
                conf.INITIAL_POSITION, Reward.NEGATIVE, False, True, Observation.CORRIDOR, new_timestep
            )

        # Explicit check rather than ``assert``: asserts are stripped under ``python -O``.
        if state.position is None or not self._on_corridor(state.position):
            raise ValueError(f"State position is outside the corridor: {state.position}")

        if action == Action.RIGHT:
            next_position = state.position + 1
        elif action == Action.LEFT and state.position > 0:
            next_position = state.position - 1
        else:
            # UP/DOWN inside the corridor, or LEFT against the left wall:
            # stand still and receive a penalty.
            return TMazeState(
                state.position, Reward.NEGATIVE, False, False, Observation.CORRIDOR, new_timestep
            )

        # Classify the cell reached by the move.
        if self._on_corridor(next_position):
            return TMazeState(
                next_position, Reward.ZERO, False, False, Observation.CORRIDOR, new_timestep
            )
        if self._just_arrived_tjunc(state, next_position):
            return TMazeState(
                next_position, Reward.ZERO, True, False, Observation.T_JUNCTION, new_timestep
            )
        raise ValueError("Reached to unknown state")

    def _on_corridor(self, pos: int) -> bool:
        """Return True if ``pos`` is a corridor cell (``0 <= pos < corridor_length``)."""
        return 0 <= pos < self.corridor_length

    def _just_arrived_tjunc(self, state: TMazeState, pos: int) -> bool:
        """Return True if moving to ``pos`` from ``state`` lands on the T-junction."""
        return pos == self.corridor_length and not state.on_tjunc

    def _timeout(self, timestep: int) -> bool:
        """Return True once ``timestep`` exceeds ``max_timestep``."""
        return self.max_timestep < timestep
