from centralized_verification.envs.continuous_grid_world import ContinuousGridWorld
from centralized_verification.envs.continuous_rod_moving import NoisyRodMoving
from centralized_verification.envs.fast_grid_world import FastGridWorld
from centralized_verification.envs.fast_grid_world_2d_obs import FastGridWorld2DObs
from centralized_verification.envs.fast_grid_world_left_obs import FastGridWorldLeftObs
from centralized_verification.envs.fast_grid_world_nearby_obs import FastGridWorldNearbyObs
from centralized_verification.envs.fast_grid_world_nearby_obs_simple import FastGridWorldNearbyObsSimple
from centralized_verification.envs.fast_grid_world_partial_obs import FastGridWorldPartialObs
from centralized_verification.envs.flashlight_grid_world import FlashlightGridWorld
from centralized_verification.envs.particle_momentum import ParticleMomentum
from centralized_verification.envs.utils import map_parser
from experiments.utils.configuration.utils import is_true


def continuous_gridworld_extra_params(params):
    """Return the ContinuousGridWorld-specific constructor kwargs.

    Reads (all optional):
        grid_world_no_idle               -> "no_idle" (via is_true, default False)
        grid_world_reward_shaping        -> "reward_shaping" (via is_true, default False)
        grid_world_reward_shaping_factor -> "reward_shaping_factor" (float, default 1.0)
    """
    extra = {}
    extra["no_idle"] = is_true(params.get("grid_world_no_idle", False))
    extra["reward_shaping"] = is_true(params.get("grid_world_reward_shaping", False))
    extra["reward_shaping_factor"] = float(params.get("grid_world_reward_shaping_factor", 1.0))
    return extra


def gridworld_2d_obs_extra_params(params):
    """Return extra kwargs for the "2DObs" grid world.

    Only when ``grid_world_obs_radius`` is present is it forwarded (as an int)
    under ``other_agent_obs_radius``; otherwise no extra kwargs are produced.
    """
    if "grid_world_obs_radius" not in params:
        return {}
    radius = int(params["grid_world_obs_radius"])
    return {"other_agent_obs_radius": radius}


def gridworld_nearby_obs_extra_params(params):
    """Return extra kwargs for the nearby-observation grid world.

    ``grid_world_nearby_obs_radius`` (default 2) is converted to int and passed
    as ``obs_radius``.
    """
    radius = params.get("grid_world_nearby_obs_radius", 2)
    return {"obs_radius": int(radius)}


def gridworld_flashlight_obs_extra_params(params):
    """Return extra kwargs for the flashlight grid world.

    Reads (all optional, converted to int):
        grid_world_flashlight_obs_off_radius   -> "obs_radius" (default 1)
        grid_world_flashlight_obs_on_radius    -> "flashlight_obs_radius" (default 5)
        grid_world_flashlight_obs_recharge_time -> "flashlight_recharge_time" (default 4)
    """
    off_radius = int(params.get("grid_world_flashlight_obs_off_radius", 1))
    on_radius = int(params.get("grid_world_flashlight_obs_on_radius", 5))
    recharge = int(params.get("grid_world_flashlight_obs_recharge_time", 4))
    return {
        "obs_radius": off_radius,
        "flashlight_obs_radius": on_radius,
        "flashlight_recharge_time": recharge,
    }


# Registry used by construct_gridworld: maps each "grid_world_obs_type" value to
# (environment class, extra-kwargs builder or None when the class needs none).
gridworld_types = {
    obs_type: (env_cls, extra_params_builder)
    for obs_type, env_cls, extra_params_builder in (
        ("PartialObsDiscrete", FastGridWorldPartialObs, None),
        ("FullObsDiscrete", FastGridWorld, None),
        ("NearbyObsDiscrete", FastGridWorldNearbyObs, None),
        ("2DObs", FastGridWorld2DObs, gridworld_2d_obs_extra_params),
        ("LeftObsDiscrete", FastGridWorldLeftObs, None),
        ("NearbyObsSimpleDiscrete", FastGridWorldNearbyObsSimple, gridworld_nearby_obs_extra_params),
        ("FlashlightObsDiscrete", FlashlightGridWorld, gridworld_flashlight_obs_extra_params),
    )
}


def get_basic_gridworld_configuration_vars(params):
    """Parse the grid-world map and collect the kwargs shared by all grid worlds.

    The map is loaded from ``maps/<grid_world_map_name>.txt``. This is a relative
    path; presumably it is resolved against the current working directory
    (confirm in map_parser).

    Returns:
        (env_spec, env_params): ``env_spec`` is whatever ``map_parser`` returns.
        Callers in this module unpack it positionally into the environment
        constructor. ``env_params`` is a dict of the common keyword arguments.
    """
    env_spec = map_parser(f"maps/{params['grid_world_map_name']}.txt")

    # Evaluated in the same order as before so the first bad value raises first.
    penalty = float(params.get("grid_world_collision_penalty", -10))
    bounce = is_true(params.get("grid_world_agents_bounce", False))
    stop_on_collision = is_true(params.get("grid_world_terminate_on_collision", False))
    random_starts = is_true(params.get("randomize_starts", False))
    unsafe_idling = is_true(params.get("grid_world_idling_is_unsafe", False))

    env_params = dict(
        randomize_starts=random_starts,
        collision_reward=penalty,
        agents_bounce=bounce,
        terminate_on_collision=stop_on_collision,
        idling_is_unsafe=unsafe_idling,
    )
    return env_spec, env_params


def construct_gridworld(params):
    """Build a discrete grid-world environment from a flat parameter mapping.

    ``params["grid_world_obs_type"]`` selects the environment class and an
    optional extra-kwargs builder from ``gridworld_types``. The shared kwargs and
    map come from ``get_basic_gridworld_configuration_vars``.

    Raises:
        KeyError: if ``grid_world_obs_type`` is missing or is not a known
            observation type, or if a required key read by the helpers is missing.
    """
    obs_type = params["grid_world_obs_type"]
    # Resolve the observation type before touching the map file so a typo in the
    # config fails fast with a message listing the valid choices.
    try:
        gridworld_type, gridworld_extra_params = gridworld_types[obs_type]
    except KeyError:
        raise KeyError(
            f"Unknown grid_world_obs_type {obs_type!r}; "
            f"expected one of {list(gridworld_types)}"
        ) from None

    env_spec, env_params = get_basic_gridworld_configuration_vars(params)
    if gridworld_extra_params is not None:
        env_params.update(gridworld_extra_params(params))

    return gridworld_type(*env_spec, **env_params)


def construct_continuous_gridworld(params):
    """Build a ContinuousGridWorld from a flat parameter mapping.

    The common grid-world kwargs come from ``get_basic_gridworld_configuration_vars``.
    They are extended with the continuous-specific options (``no_idle``,
    ``reward_shaping``, ``reward_shaping_factor``) from
    ``continuous_gridworld_extra_params``.
    """
    env_spec, env_params = get_basic_gridworld_configuration_vars(params)
    # Use the shared helper rather than re-reading the same keys inline, so the
    # continuous-world options and their defaults are defined in one place.
    env_params.update(continuous_gridworld_extra_params(params))
    return ContinuousGridWorld(*env_spec, **env_params)


def construct_particle_momentum(params):
    """Build a ParticleMomentum environment from a flat parameter mapping.

    Required keys: particle_world_size (int), particle_collision_penalty (float),
    particle_terminate_on_collision and particle_agents_observe_momentum (via
    is_true). Optional: randomize_starts (via is_true, default False).
    """
    # Values are read in the original order so the same missing key raises first.
    size = int(params["particle_world_size"])
    penalty = float(params["particle_collision_penalty"])
    stop_on_collision = is_true(params["particle_terminate_on_collision"])
    observe_momentum = is_true(params["particle_agents_observe_momentum"])
    random_starts = is_true(params.get("randomize_starts", False))

    return ParticleMomentum(
        world_size=size,
        agents_observe_momentum=observe_momentum,
        randomize_starts=random_starts,
        collision_reward=penalty,
        terminate_on_collision=stop_on_collision,
    )


def construct_rod_moving(params):
    """Build a NoisyRodMoving environment from a flat parameter mapping.

    Every option is required and is converted to float. Each constructor kwarg
    is read from the corresponding ``rod_moving_*`` key listed below.
    """
    kwarg_to_param = (
        ("grid_size", "rod_moving_grid_size"),
        ("separating_distance", "rod_moving_separating_distance"),
        ("distance_tolerance", "rod_moving_distance_tolerance"),
        ("goal_tolerance", "rod_moving_goal_tolerance"),
        ("completion_reward", "rod_moving_completion_reward"),
        ("drop_reward", "rod_moving_drop_reward"),
        ("step_reward", "rod_moving_step_reward"),
    )
    env_params = {kwarg: float(params[key]) for kwarg, key in kwarg_to_param}
    return NoisyRodMoving(**env_params)


# Registry used by construct_env: maps each "map_type" value to the factory
# function that builds that environment from the params mapping.
map_types = dict(
    GridWorld=construct_gridworld,
    ContinuousGridWorld=construct_continuous_gridworld,
    ParticleMomentum=construct_particle_momentum,
    RodMoving=construct_rod_moving,
)


def construct_env(params):
    """Construct the environment selected by ``params["map_type"]``.

    Dispatches to the factory registered in ``map_types`` and passes it the full
    ``params`` mapping.

    Raises:
        KeyError: if ``map_type`` is missing or is not a registered map type.
            ``KeyError``s raised inside the chosen factory propagate unchanged.
    """
    map_type = params["map_type"]
    # Keep the try around the lookup only, so errors from inside the factory are
    # not misreported as an unknown map type.
    try:
        factory = map_types[map_type]
    except KeyError:
        raise KeyError(
            f"Unknown map_type {map_type!r}; expected one of {list(map_types)}"
        ) from None
    return factory(params)
