from typing import Sequence

import gym.spaces as spaces

from centralized_verification.envs.fast_grid_world import FastGridWorld


class FastGridWorldLeftObs(FastGridWorld):
    """FastGridWorld variant in which agents only partially observe each other.

    Each agent's observation is a tuple. Its first entry is the agent's own
    place index (an index into ``self.grid_posns``). It is followed by one
    entry per *other* agent, in agent order with the observer skipped. Each
    such entry is either:

    * that agent's place index, when its x-coordinate is >= the observer's; or
    * the sentinel ``len(self.grid_posns)``, meaning "not visible", when its
      x-coordinate is strictly smaller than the observer's.

    NOTE(review): whether "strictly smaller x" means "to the left" depends on
    the axis orientation of ``grid_posns`` in the base class — confirm.
    """

    def agent_obs_spaces(self) -> Sequence[spaces.Space]:
        """Return one ``MultiDiscrete`` observation space per agent.

        The first dimension spans the ``len(self.grid_posns)`` place indices.
        Every remaining dimension (one per other agent) has one extra value,
        which is reserved for the "not visible" sentinel. A single space
        object is built and the same instance is repeated for every agent.
        """
        num_places = len(self.grid_posns)
        dims = [num_places]
        # One slot per other agent; the +1 leaves room for the hidden sentinel.
        dims.extend([num_places + 1] * (self.num_agents - 1))
        shared_space = spaces.MultiDiscrete(dims)
        return [shared_space] * self.num_agents

    def project_obs_agent(self, state, agent_num):
        """Build the observation tuple for agent ``agent_num``.

        Args:
            state: Sequence of place indices, indexed by agent number (each
                value is used as an index into ``self.grid_posns``).
            agent_num: Index of the observing agent within ``state``.

        Returns:
            ``(own_place, *others)``. Each entry of ``others`` is the other
            agent's place index, or ``len(self.grid_posns)`` when the
            observer's x-coordinate is greater than that agent's.
        """
        own_place = state[agent_num]
        own_x, _ = self.grid_posns[own_place]
        hidden_marker = len(self.grid_posns)

        def observe(place):
            other_x, _ = self.grid_posns[place]
            # Agents with a strictly smaller x than the observer are hidden.
            return hidden_marker if own_x > other_x else place

        others = (
            observe(place)
            for idx, place in enumerate(state)
            if idx != agent_num
        )
        return (own_place, *others)

    def project_obs(self, state):
        """Return the observation of every agent for ``state``, ordered by agent index."""
        observe_agent = self.project_obs_agent
        return [observe_agent(state, idx) for idx in range(self.num_agents)]
