from typing import Callable, Any, List, Dict

from centralized_verification.envs.fast_grid_world import FastGridWorld
from centralized_verification.envs.fast_grid_world_left_obs import FastGridWorldLeftObs
from centralized_verification.envs.flashlight_grid_world import FlashlightGridWorld
from centralized_verification.envs.particle_momentum import ParticleMomentum
from centralized_verification.shields.partial_obs.pobs_label_shield_spec import label_value_type

# Type of a label extractor: called as ``extractor(state, observations)`` and
# returns a list with one ``{label_name: label_value}`` dict per agent.
PobsLabelExtractor = Callable[[Any, List[Any]], List[Dict[str, label_value_type]]]


def flashlight_gridworld_extractor(env: FlashlightGridWorld) -> PobsLabelExtractor:
    """Build a label extractor for the two-agent flashlight grid world.

    Each agent's observation is unpacked as ``(own_cell, flashlight, other_cell)``,
    where the cell values index ``env.grid_posns`` to obtain ``(x, y)``
    coordinates. An ``other_cell`` equal to ``len(env.grid_posns)`` means the
    other agent's position is unknown; its coordinates are then labelled
    ``"UNK"``.

    The extractor ignores ``state`` and returns one label dict per agent
    (agents 0 and 1).
    """

    def extractor_one_agent(observation, agent_num) -> Dict[str, label_value_type]:
        """Return the labels of agent ``agent_num`` (0 or 1) from its own observation.

        Values are ints for known coordinates, a bool for ``"light"`` and the
        string ``"UNK"`` for an unknown other-agent position.
        """
        this_agent_info, flashlight, other_agent_info = observation
        agent_x, agent_y = env.grid_posns[this_agent_info]

        # One past the last valid cell index is the "position unknown" sentinel.
        if other_agent_info == len(env.grid_posns):
            other_x, other_y = "UNK", "UNK"
        else:
            other_x, other_y = env.grid_posns[other_agent_info]

        # Exactly two agents: the other one is 1 when this is 0, and vice versa.
        other_agent_num = 1 - agent_num

        return {
            "pos_x": agent_x,
            "pos_y": agent_y,
            "light": bool(flashlight),
            f"a{other_agent_num}_pos_x": other_x,
            f"a{other_agent_num}_pos_y": other_y
        }

    def extractor(state, observations):
        # ``state`` is unused: labels come only from each agent's observation.
        return [extractor_one_agent(observations[n], n) for n in range(2)]

    return extractor


def fast_gridworld_partialobs_extractor_2agent(env: FastGridWorldLeftObs) -> PobsLabelExtractor:
    """Build a label extractor for the partially observable two-agent grid world.

    Each agent's observation is unpacked as ``(own_cell, other_cell)``, where
    the cell values index ``env.grid_posns`` to obtain ``(x, y)`` coordinates.
    An ``other_cell`` equal to ``len(env.grid_posns)`` means the other agent's
    position is unknown; its coordinates are then labelled ``"UNK"``.

    The extractor ignores ``state`` and returns one label dict per agent
    (agents 0 and 1).
    """

    def extractor_one_agent(observation, agent_num) -> Dict[str, label_value_type]:
        """Return the labels of agent ``agent_num`` (0 or 1) from its own observation.

        Values are ints for known coordinates and the string ``"UNK"`` for an
        unknown other-agent position.
        """
        this_agent_info, other_agent_info = observation
        agent_x, agent_y = env.grid_posns[this_agent_info]

        # One past the last valid cell index is the "position unknown" sentinel.
        if other_agent_info == len(env.grid_posns):
            other_x, other_y = "UNK", "UNK"
        else:
            other_x, other_y = env.grid_posns[other_agent_info]

        # Exactly two agents: the other one is 1 when this is 0, and vice versa.
        other_agent_num = 1 - agent_num

        return {
            "pos_x": agent_x,
            "pos_y": agent_y,
            f"a{other_agent_num}_pos_x": other_x,
            f"a{other_agent_num}_pos_y": other_y
        }

    def extractor(state, observations):
        # ``state`` is unused: labels come only from each agent's observation.
        return [extractor_one_agent(observations[n], n) for n in range(2)]

    return extractor


def fast_gridworld_fullobs_extractor_2agent(env: FastGridWorld) -> PobsLabelExtractor:
    """Build a label extractor for the fully observable two-agent grid world.

    The returned extractor unpacks both agents' cell indices from ``state``
    (``observations`` is ignored) and maps them to ``(x, y)`` through
    ``env.grid_posns``. Each agent gets its own ``pos_x``/``pos_y`` plus the
    second agent's position minus the first agent's as ``rel_pos_x``/``rel_pos_y``.
    Both agents receive identical ``rel_pos_*`` values.
    """

    def extractor(state, observations):
        first_cell, second_cell = state
        lookup = env.grid_posns
        first_pos = lookup[first_cell]
        second_pos = lookup[second_cell]
        x_first, y_first = first_pos
        x_second, y_second = second_pos

        # Shared by both agents: second agent's position minus the first's.
        relative = {
            "rel_pos_x": x_second - x_first,
            "rel_pos_y": y_second - y_first,
        }

        return [
            {"pos_x": x_first, "pos_y": y_first, **relative},
            {"pos_x": x_second, "pos_y": y_second, **relative},
        ]

    return extractor

def particle_momentum_extractor(env: ParticleMomentum, include_rel_momentum: bool) -> PobsLabelExtractor:
    """Build a label extractor for the two-agent particle momentum environment.

    The returned extractor unpacks ``(rel_x, rel_y, rel_vx, rel_vy)`` from
    ``state`` (``observations`` is ignored) and gives both agents equal labels:
    ``rel_pos_x``/``rel_pos_y`` and, when ``include_rel_momentum`` is true,
    ``rel_momentum_x``/``rel_momentum_y``. Each agent receives its own dict, so
    mutating one agent's labels does not affect the other's.
    """

    def extractor(state, observations):
        rel_x, rel_y, rel_vx, rel_vy = state

        # NOTE(review): the state presumably stores signed relative offsets
        # shifted to be non-negative (positions by ``world_size - 1``,
        # velocities by 2); subtracting the shift recentres them on 0.
        # Confirm against ParticleMomentum.
        info = {
            "rel_pos_x": rel_x - (env.world_size - 1),
            "rel_pos_y": rel_y - (env.world_size - 1),
        }

        if include_rel_momentum:
            info.update({
                "rel_momentum_x": rel_vx - 2,
                "rel_momentum_y": rel_vy - 2,
            })

        # Copy for the second agent so the two label dicts are not aliased.
        return [info, dict(info)]

    return extractor
