from typing import Sequence

import gym.spaces as spaces

from centralized_verification.envs.fast_grid_world import FastGridWorld


def posn_dist_manhattan(posn1, posn2):
    """Return the Manhattan (L1) distance between two 2-D positions.

    Only the first two coordinates (indices 0 and 1) of each position are
    compared; any further components are ignored.
    """
    dx = posn1[0] - posn2[0]
    dy = posn1[1] - posn2[1]
    return abs(dx) + abs(dy)


class FastGridWorldNearbyObsSimple(FastGridWorld):
    """FastGridWorld variant in which each agent only sees nearby agents.

    An agent's observation is its own grid-position index followed by one
    entry per other agent (in agent-number order, skipping itself). Each
    entry is that agent's grid-position index, or the sentinel value
    ``len(self.grid_posns)`` when it is farther than ``obs_radius`` away in
    Manhattan distance.
    """

    def __init__(self, *args, obs_radius=2, **kwargs):
        """Store the visibility radius, then defer to ``FastGridWorld``.

        Args:
            obs_radius: Maximum Manhattan distance at which another agent is
                visible (inclusive).
        """
        # Set before the base constructor runs, presumably because it may call
        # agent_obs_spaces()/project_obs() -- NOTE(review): confirm in FastGridWorld.
        self.obs_radius = obs_radius
        super().__init__(*args, **kwargs)

    def agent_obs_spaces(self) -> Sequence[spaces.Space]:
        """Return one ``MultiDiscrete`` observation space per agent.

        The first component (own position) takes ``len(self.grid_posns)``
        values. Each of the ``num_agents - 1`` other-agent components takes
        one extra value, which encodes "not visible".
        """
        num_places = len(self.grid_posns)
        single_obs = [num_places] + [num_places + 1] * (self.num_agents - 1)
        # Build a distinct space per agent. ``[space] * n`` would alias one
        # object n times, so seeding/sampling one agent's space would silently
        # change every other agent's space as well.
        return [spaces.MultiDiscrete(single_obs) for _ in range(self.num_agents)]

    def project_obs_agent(self, state, agent_num):
        """Return agent ``agent_num``'s partial observation of ``state``.

        Args:
            state: Sequence of grid-position indices, one per agent; entry
                ``i`` indexes ``self.grid_posns`` for agent ``i``.
            agent_num: Index of the observing agent.

        Returns:
            Tuple ``(own_place, other_place, ...)``. Each other agent's entry
            is its place index, or ``len(self.grid_posns)`` if its Manhattan
            distance exceeds ``self.obs_radius``.
        """
        hidden = len(self.grid_posns)  # sentinel meaning "other agent not visible"
        agent_place_num = state[agent_num]
        agent_posn = self.grid_posns[agent_place_num]
        obs = [agent_place_num]
        for i, other_agent_place in enumerate(state):
            if i == agent_num:
                continue
            other_agent_posn = self.grid_posns[other_agent_place]
            if posn_dist_manhattan(agent_posn, other_agent_posn) > self.obs_radius:
                obs.append(hidden)
            else:
                obs.append(other_agent_place)
        return tuple(obs)

    def project_obs(self, state):
        """Return the list of per-agent observations, indexed by agent number."""
        return [self.project_obs_agent(state, i) for i in range(self.num_agents)]
