import math
import random
from typing import Sequence, Any, Tuple

from gym import spaces as spaces

from centralized_verification.MultiAgentAPEnv import MultiAgentSafetyEnv
from centralized_verification.envs.continuous_grid_world import posn_dist


def points_sep_by_dist(grid_size: float, dist: float) -> Tuple[Tuple[float, float], Tuple[float, float]]:
    """
    Find two random points separated by exactly given distance, both within the grid

    The first point is drawn uniformly from the square [0, grid_size] x [0, grid_size];
    the second is placed ``dist`` away in a uniformly random direction. The pair is
    redrawn until the second point also lies inside the square (rejection sampling).

    :param grid_size: side length of the square grid
    :param dist: required distance between the two points
    :return: ((x1, y1), (x2, y2))
    :raises ValueError: if no pair of points on the grid can be ``dist`` apart, i.e.
        ``abs(dist)`` exceeds the grid diagonal, ``grid_size`` is negative, or either
        argument is NaN. Without this check the sampling loop would never terminate.
    """
    # The farthest apart two points of the square can be is its diagonal.
    # Written as "not <=" so that NaN inputs are rejected as well.
    max_dist = grid_size * math.sqrt(2)
    if not abs(dist) <= max_dist:
        raise ValueError(
            "cannot place two points %r apart on a grid of size %r" % (dist, grid_size)
        )

    # NOTE: distances very close to the diagonal are accepted only rarely, so the
    # loop may need many iterations in that regime.
    while True:
        x1 = random.uniform(0, grid_size)
        y1 = random.uniform(0, grid_size)

        # Randomly choose an angle over the full circle and find the second point
        angle = random.uniform(0, 2 * math.pi)
        x2 = x1 + dist * math.cos(angle)
        y2 = y1 + dist * math.sin(angle)

        # Verify that the second point is on the grid
        if 0 <= x2 <= grid_size and 0 <= y2 <= grid_size:
            return (x1, y1), (x2, y2)


class NoisyRodMoving(MultiAgentSafetyEnv):
    """
    Two-agent environment on a square grid with two goal points.

    The environment state is the flat 8-tuple
    (agent 1 x, agent 1 y, agent 2 x, agent 2 y, goal 1 x, goal 1 y, goal 2 x, goal 2 y).

    An episode ends unsafely (with ``drop_reward``) as soon as the distance between the
    agents is no longer within ``separating_distance_tolerance`` of
    ``separating_distance``. It ends safely (with ``completion_reward``) once each agent
    is within ``goal_tolerance`` of a different goal. Every other step yields
    ``step_reward``. Both agents always receive the same reward.
    """

    def __init__(self, grid_size: float,
                 separating_distance: float,
                 separating_distance_tolerance: float,
                 observation_noise: float,
                 goal_tolerance: float,
                 completion_reward: float,
                 drop_reward: float,
                 step_reward: float):
        # Geometry of the task
        self.grid_size = grid_size
        self.separating_distance = separating_distance
        self.distance_tolerance = separating_distance_tolerance
        self.goal_tolerance = goal_tolerance

        # Half-width of the uniform noise added to the observed position of the other agent
        self.observation_noise = observation_noise

        # Reward scheme
        self.completion_reward = completion_reward
        self.drop_reward = drop_reward
        self.step_reward = step_reward

    def agent_obs_spaces(self) -> Sequence[spaces.Space]:
        """
        Each agent observes 8 values, in this order:
        - its own x, y position
        - the other agent's x, y position, with noise added
        - goal 1 x, y position (goal order is identical for both agents)
        - goal 2 x, y position

        NOTE(review): the noisy coordinates may fall slightly outside the declared
        [0, grid_size] bounds of the Box.
        """
        per_agent_space = spaces.Box(low=0, high=self.grid_size, shape=(8,))
        # The same Box instance is shared by both agents
        return [per_agent_space, per_agent_space]

    def agent_actions_spaces(self) -> Sequence[spaces.Space]:
        """
        Five discrete actions per agent: 0 leaves the agent in place, 1-4 move it
        along one axis (see move_agent).
        """
        per_agent_actions = spaces.Discrete(5)
        # The same Discrete instance is shared by both agents
        return [per_agent_actions, per_agent_actions]

    def state_space(self) -> spaces.Space:
        """
        Space of the full 8-value environment state, every entry within [0, grid_size].
        """
        return spaces.Box(low=0, high=self.grid_size, shape=(8,))

    def initial_state(self) -> Tuple[Any, Sequence[Any]]:
        """
        Sample a fresh state: the two agents start exactly separating_distance apart,
        and the two goals are likewise exactly separating_distance apart.

        Returns the state together with the per-agent observations of it.
        """
        starts = points_sep_by_dist(self.grid_size, self.separating_distance)
        goals = points_sep_by_dist(self.grid_size, self.separating_distance)

        # Flatten ((x, y), (x, y)) pairs into a single 8-tuple
        state = starts[0] + starts[1] + goals[0] + goals[1]
        observations = self.get_all_obs(state)
        return state, observations

    def move_agent(self, x, y, action):
        """
        Apply one action to a position and return the new (x, y).

        Action 0 keeps the position unchanged. Actions 1..4 move by a random amount
        drawn uniformly from [0, 1] in the direction (0, -1), (1, 0), (0, 1), (-1, 0)
        respectively; the result is clamped to the grid.
        """
        if action == 0:
            return x, y

        def clamp(coordinate):
            # Bound to grid size
            # noinspection PyTypeChecker
            return max(0, min(self.grid_size, coordinate))

        step_length = random.uniform(0, 1)
        unit_x, unit_y = [(0, -1), (1, 0), (0, 1), (-1, 0)][action - 1]

        moved_x = x + unit_x * step_length
        moved_y = y + unit_y * step_length
        return clamp(moved_x), clamp(moved_y)

    def step(self, environment_state, joint_action: Sequence[Any]) -> Tuple[
        Any, Sequence[Any], Sequence[float], bool, bool]:
        """
        Advance the environment by one joint action.

        Returns (next state, per-agent observations, per-agent rewards, done, safe).
        """
        a1_x, a1_y, a2_x, a2_y, g1_x, g1_y, g2_x, g2_y = environment_state

        # Move agents (agent 1 first, then agent 2)
        a1_x, a1_y = self.move_agent(a1_x, a1_y, joint_action[0])
        a2_x, a2_y = self.move_agent(a2_x, a2_y, joint_action[1])

        # The agents must stay (almost) exactly separating_distance apart
        gap = posn_dist((a1_x, a1_y), (a2_x, a2_y))
        gap_ok = abs(gap - self.separating_distance) < self.distance_tolerance

        # Goal proximity for every agent/goal pairing
        a1_at_g1 = self.agent_near_goal(a1_x, a1_y, g1_x, g1_y)
        a1_at_g2 = self.agent_near_goal(a1_x, a1_y, g2_x, g2_y)
        a2_at_g1 = self.agent_near_goal(a2_x, a2_y, g1_x, g1_y)
        a2_at_g2 = self.agent_near_goal(a2_x, a2_y, g2_x, g2_y)

        # Either assignment of agents to goals counts as success
        goal_satisfied = (a1_at_g1 and a2_at_g2) or (a1_at_g2 and a2_at_g1)

        # Violating the distance constraint takes precedence over reaching the goals
        if not gap_ok:
            reward, done, safe = self.drop_reward, True, False
        elif goal_satisfied:
            reward, done, safe = self.completion_reward, True, True
        else:
            reward, done, safe = self.step_reward, False, True

        state = (a1_x, a1_y, a2_x, a2_y, g1_x, g1_y, g2_x, g2_y)
        observations = self.get_all_obs(state)

        return state, observations, [reward, reward], done, safe

    def agent_near_goal(self, agent_x, agent_y, goal_x, goal_y) -> bool:
        """
        Whether the agent position is strictly closer than goal_tolerance to the goal.
        """
        distance_to_goal = posn_dist((agent_x, agent_y), (goal_x, goal_y))
        return distance_to_goal < self.goal_tolerance

    def get_all_obs(self, state) -> Sequence[Any]:
        """
        Observations of the given state for agent 0 and agent 1, in that order.
        """
        return tuple(self.get_observation(state, agent_index) for agent_index in (0, 1))

    def get_observation(self, state, agent_num):
        """
        Build the 8-value observation of one agent (agent_num 1 selects the second
        agent, any other value the first).

        The observing agent sees its own position exactly and the other agent's
        position perturbed by independent uniform noise in
        [-observation_noise, observation_noise] per coordinate. Goals are reported
        unchanged and in the same order for both agents.
        """
        first_x, first_y, second_x, second_y, g1_x, g1_y, g2_x, g2_y = state

        if agent_num == 1:
            own_x, own_y, seen_x, seen_y = second_x, second_y, first_x, first_y
        else:
            own_x, own_y, seen_x, seen_y = first_x, first_y, second_x, second_y

        # Noise is drawn for x first, then y
        noisy_x = seen_x + random.uniform(-self.observation_noise, self.observation_noise)
        noisy_y = seen_y + random.uniform(-self.observation_noise, self.observation_noise)
        return own_x, own_y, noisy_x, noisy_y, g1_x, g1_y, g2_x, g2_y
