import json
from typing import NamedTuple, Type, Callable, List, Dict, Any, Sequence

from centralized_verification.shields.no_shield import NoShield
from centralized_verification.shields.partial_obs.asym_obs_shields import partial_obs_centralized_shield_from_json
from centralized_verification.shields.partial_obs.pobs_label_extractor import \
    fast_gridworld_partialobs_extractor_2agent, \
    fast_gridworld_fullobs_extractor_2agent, particle_momentum_extractor, flashlight_gridworld_extractor
from centralized_verification.shields.partial_obs.pobs_label_shield import PobsLabelShield
from centralized_verification.shields.partial_obs.pobs_label_shield_centralized import PobsLabelShieldCentralized
from centralized_verification.shields.partial_obs.pobs_label_shield_spec import pobs_label_shield_spec_from_json
from centralized_verification.shields.wrappers.action_transformer_wrapper import ActionTransformerShieldWrapper, \
    FlashlightActionTransformer
from experiments.utils.configuration.utils import is_true, update_all_func


def construct_partial_obs_label_extractor(params, env):
    """Build a label extractor for ``env`` via ``fast_gridworld_partialobs_extractor_2agent``.

    ``params`` is not read; it is accepted so this constructor can be called
    with the same ``(params, env)`` arguments as the other entries of
    ``gridworld_label_extractors``.
    """
    extractor = fast_gridworld_partialobs_extractor_2agent(env)
    return extractor


def construct_full_obs_label_extractor(params, env):
    """Build a label extractor for ``env`` via ``fast_gridworld_fullobs_extractor_2agent``.

    ``params`` is not read; it is accepted so this constructor can be called
    with the same ``(params, env)`` arguments as the other entries of
    ``gridworld_label_extractors``.
    """
    extractor = fast_gridworld_fullobs_extractor_2agent(env)
    return extractor


def construct_flashlight_obs_label_extractor(params, env):
    """Build a label extractor for ``env`` via ``flashlight_gridworld_extractor``.

    ``params`` is not read; it is accepted so this constructor can be called
    with the same ``(params, env)`` arguments as the other entries of
    ``gridworld_label_extractors``.
    """
    extractor = flashlight_gridworld_extractor(env)
    return extractor


# Gridworld label-extractor constructors, keyed by the value of
# params["grid_world_obs_type"] (looked up in construct_gridworld_label_extractor).
# Every value is called as constructor(params, env).
# "LeftObsDiscrete" and "NearbyObsSimpleDiscrete" share the same constructor.
gridworld_label_extractors = {
    "LeftObsDiscrete": construct_partial_obs_label_extractor,
    "FullObsDiscrete": construct_full_obs_label_extractor,
    "NearbyObsSimpleDiscrete": construct_partial_obs_label_extractor,
    "FlashlightObsDiscrete": construct_flashlight_obs_label_extractor,
}


def construct_gridworld_label_extractor(params, env):
    """Build the gridworld label extractor selected by ``params``.

    Reads ``params["grid_world_obs_type"]``, looks the value up in
    ``gridworld_label_extractors`` and calls the constructor found there with
    ``(params, env)``. A value with no entry in that table raises ``KeyError``.
    """
    obs_type = params["grid_world_obs_type"]
    build_extractor = gridworld_label_extractors[obs_type]
    return build_extractor(params, env)


def construct_particle_momentum_label_extractor(params, env):
    """Build a label extractor for ``env`` via ``particle_momentum_extractor``.

    The second argument handed to ``particle_momentum_extractor`` is
    ``params["particle_agents_observe_momentum"]`` after conversion by
    ``is_true``.
    """
    observe_momentum = is_true(params["particle_agents_observe_momentum"])
    return particle_momentum_extractor(env, observe_momentum)


# Label-extractor constructors, keyed by the value of params["map_type"]
# (looked up in load_label_extractor). Every value is called as
# constructor(params, env).
label_extractor_constructors = {
    "GridWorld": construct_gridworld_label_extractor,
    "ParticleMomentum": construct_particle_momentum_label_extractor,
}


def load_label_extractor(params, env):
    """Return the ``label_extractor`` keyword argument for a shield.

    The constructor is chosen from ``label_extractor_constructors`` by
    ``params["map_type"]`` and called with ``(params, env)``; its result is
    returned under the single key ``"label_extractor"``.
    """
    build_extractor = label_extractor_constructors[params["map_type"]]
    return {"label_extractor": build_extractor(params, env)}


def load_shield_punish_params(params, env):
    """Collect the shield keyword arguments that control punishing unsafe actions.

    Args:
        params: Configuration mapping. ``punish_unsafe_orig_action`` is read
            with a default of ``False`` and converted by ``is_true``. Only when
            that flag is enabled is ``punish_unsafe_orig_action_modifier`` read.
        env: Not used by this loader.

    Returns:
        ``{"punish_unsafe_orig_action": False}`` when the flag is disabled;
        otherwise a dict with the flag set to ``True`` and
        ``"unsafe_action_punishment"`` set to the modifier converted to ``float``.

    Raises:
        TypeError: If the flag is enabled but the modifier key is absent,
            because ``params.get`` then yields ``None`` and ``float(None)`` fails.
    """
    if not is_true(params.get("punish_unsafe_orig_action", False)):
        return {"punish_unsafe_orig_action": False}

    punishment = float(params.get("punish_unsafe_orig_action_modifier"))
    return {
        "punish_unsafe_orig_action": True,
        "unsafe_action_punishment": punishment,
    }


def load_shield_maker(load_from_json_func, extension):
    """Create a factory of shield-spec loaders for one on-disk shield format.

    Args:
        load_from_json_func: Callable that receives the decoded JSON document
            and returns the shield specification object.
        extension: File extension (without the leading dot) appended to the
            location found in ``params``.

    Returns:
        A function ``maker(param_name)``. Calling it yields a loader
        ``(params, env)`` that opens ``f"{params[param_name]}.{extension}"``,
        decodes it as JSON and returns
        ``{"shield_spec": load_from_json_func(document)}``.
    """
    def maker(param_name):
        def ret(params, env):
            # ``env`` is not needed to read the file; it is accepted so the
            # loader can be called with the same (params, env) arguments as
            # the other loaders stored in ShieldInfo.
            shield_location = params[param_name]
            # JSON text is UTF-8 by specification, so decode it explicitly
            # instead of depending on the platform's locale default encoding.
            with open(f"{shield_location}.{extension}", "r", encoding="utf-8") as file:
                shield_dict = json.load(file)
            return {
                "shield_spec": load_from_json_func(shield_dict)
            }

        return ret

    return maker


# Spec-loader factories, one per on-disk shield format. Calling a factory with
# a params key returns a (params, env) loader that reads
# "<params[key]>.shield_pobs_dec" as JSON, parses it with
# pobs_label_shield_spec_from_json and returns {"shield_spec": <parsed spec>}.
load_pobs_label_shield_spec = load_shield_maker(pobs_label_shield_spec_from_json, "shield_pobs_dec")
# Same shape of factory for "<params[key]>.shield_pobs_cent" files, parsed with
# partial_obs_centralized_shield_from_json.
load_pobs_label_cent_shield_spec = load_shield_maker(partial_obs_centralized_shield_from_json, "shield_pobs_cent")


def grid_world_shield_wrapper(shield, params):
    """Apply the gridworld-specific wrapper to ``shield`` when one is needed.

    When ``params["grid_world_obs_type"]`` equals ``"FlashlightObsDiscrete"``
    the shield is wrapped in an ``ActionTransformerShieldWrapper`` built with a
    new ``FlashlightActionTransformer``; for any other value ``shield`` is
    returned unchanged.
    """
    uses_flashlight = params["grid_world_obs_type"] == "FlashlightObsDiscrete"
    if not uses_flashlight:
        return shield
    return ActionTransformerShieldWrapper(shield, FlashlightActionTransformer())


# Shield post-processing functions, keyed by the value of params["map_type"]
# (looked up in wrap_shield). Each value is called as wrapper(shield, params);
# map types with no entry here are left unwrapped by wrap_shield.
wrappers = {
    "GridWorld": grid_world_shield_wrapper
}


def wrap_shield(shield, params):
    """Apply the map-type-specific wrapper registered in ``wrappers``, if any.

    The wrapper is selected by ``params["map_type"]`` and called as
    ``wrapper(shield, params)``. When no wrapper is registered for that map
    type, ``shield`` is returned unchanged.
    """
    wrapper = wrappers.get(params["map_type"])
    if wrapper is None:
        return shield
    return wrapper(shield, params)


# Shape shared by every parameter loader: it is called as loader(params, env).
_ParamLoader = Callable[[Dict[str, Any], Any], Any]


class ShieldInfo(NamedTuple):
    """Recipe describing how one kind of shield is constructed.

    ``construct_shield_v2`` merges the results of the loaders listed here into
    keyword arguments and instantiates ``shield_class(env=env, **kwargs)``.

    The loader fields default to empty tuples rather than lists: a NamedTuple
    default is a single object shared by every instance that omits the field,
    so a mutable list default could be modified through one registry entry and
    silently change all the others.
    """

    # Class instantiated with ``env=env`` plus the merged loader results.
    shield_class: Type
    # Loaders run for both the execution shield and the evaluation shield.
    always_load: Sequence[_ParamLoader] = ()
    # Extra loaders run only when building the execution shield.
    load_execution: Sequence[_ParamLoader] = ()
    # Extra loaders run only when building the evaluation shield.
    load_evaluation: Sequence[_ParamLoader] = ()


# Registry of available shields, keyed by the value of params["shield"] and
# params["evaluation_shield"] (both looked up in construct_shield_v2).
# The spec loaders read the file location from params["shield_specification"]
# for the execution shield and params["evaluation_shield_specification"] for
# the evaluation shield.
shields = {
    "none": ShieldInfo(NoShield, always_load=[load_shield_punish_params]),
    "pobs_label": ShieldInfo(PobsLabelShield, always_load=[load_shield_punish_params, load_label_extractor],
                             load_execution=[load_pobs_label_shield_spec("shield_specification")],
                             load_evaluation=[load_pobs_label_shield_spec("evaluation_shield_specification")]),
    "pobs_label_cent": ShieldInfo(PobsLabelShieldCentralized,
                                  always_load=[load_shield_punish_params, load_label_extractor],
                                  load_execution=[load_pobs_label_cent_shield_spec("shield_specification")],
                                  load_evaluation=[load_pobs_label_cent_shield_spec("evaluation_shield_specification")])
}


def construct_shield_v2(params, env):
    """Build the execution shield and the evaluation shield described by ``params``.

    Args:
        params: Configuration mapping. ``params["shield"]`` selects the
            execution shield from ``shields``; the optional
            ``params["evaluation_shield"]`` selects the evaluation shield.
        env: Environment passed to every loader and to the shield class as
            the ``env`` keyword argument.

    Returns:
        A tuple ``(shield, eval_shield)``. When ``"evaluation_shield"`` is not
        in ``params`` both elements are the same object.
    """

    def build(selector_key, for_evaluation):
        # Shared construction path: common loaders first, then the
        # phase-specific loaders, then instantiate and wrap.
        info: ShieldInfo = shields[params[selector_key]]
        kwargs = update_all_func(info.always_load, [params, env])
        phase_loaders = info.load_evaluation if for_evaluation else info.load_execution
        kwargs.update(update_all_func(phase_loaders, [params, env]))
        return wrap_shield(info.shield_class(env=env, **kwargs), params)

    shield = build("shield", False)

    if "evaluation_shield" not in params:
        # No separate evaluation shield configured: reuse the execution shield.
        return shield, shield

    eval_shield = build("evaluation_shield", True)
    return shield, eval_shield
