import gym
import numpy as np

from centralized_verification.utils import TrainingProgress, convert_gym_space_to_q_shape


def default_epsilon_schedule(training_progress):
    """Return a fixed exploration rate of 0.1.

    Args:
        training_progress: Ignored. Accepted so this function has the same
            call signature as the schedules returned by the linear anneal
            factories in this module.
    """
    constant_epsilon = 0.1
    return constant_epsilon


def linear_epsilon_anneal_steps(start: float, end: float, num_steps: int):
    """Build a schedule that linearly anneals epsilon over global steps.

    Args:
        start: Epsilon returned at step 0 and for any negative step count.
        end: Epsilon returned once ``num_steps`` steps have elapsed, and from
            then on.
        num_steps: Length of the annealing window in steps. A value ``<= 0``
            means there is no window, so the schedule returns ``end``
            immediately instead of dividing by zero.

    Returns:
        A callable that takes a training-progress object, reads its
        ``global_step_count`` attribute, and returns the current epsilon.
    """

    def eps_anneal_func(training_progress: "TrainingProgress"):
        if num_steps <= 0:
            # Empty annealing window: annealing is already complete.
            return end
        lin_pos = float(training_progress.global_step_count) / num_steps
        # Clamp the interpolation fraction to [0, 1] so the result always
        # lies between start and end.
        lin_pos = min(max(lin_pos, 0.0), 1.0)
        return (1 - lin_pos) * start + lin_pos * end

    return eps_anneal_func


def linear_epsilon_anneal_episodes(start: float, end: float, num_episodes: int):
    """Build a schedule that linearly anneals epsilon over global episodes.

    Args:
        start: Epsilon returned at episode 0 and for any negative episode count.
        end: Epsilon returned once ``num_episodes`` episodes have elapsed, and
            from then on.
        num_episodes: Length of the annealing window in episodes. A value
            ``<= 0`` means there is no window, so the schedule returns ``end``
            immediately instead of dividing by zero.

    Returns:
        A callable that takes a training-progress object, reads its
        ``global_episode_count`` attribute, and returns the current epsilon.
    """

    def eps_anneal_func(training_progress: "TrainingProgress"):
        if num_episodes <= 0:
            # Empty annealing window: annealing is already complete.
            return end
        lin_pos = float(training_progress.global_episode_count) / num_episodes
        # Clamp the interpolation fraction to [0, 1] so the result always
        # lies between start and end.
        lin_pos = min(max(lin_pos, 0.0), 1.0)
        return (1 - lin_pos) * start + lin_pos * end

    return eps_anneal_func


def transform_potentially_multidiscrete_obs(obs, obs_space, make_multidiscrete_one_hot):
    """Convert an observation into the form a learner consumes.

    Observations from a ``gym.spaces.MultiDiscrete`` space are either one-hot
    encoded or converted with ``np.asarray``. In the one-hot layout, each
    component ``i`` gets a segment of length ``nvec[i]`` and the segments are
    concatenated. Observations from any other space are returned unchanged.

    Args:
        obs: The raw observation.
        obs_space: The gym space the observation belongs to.
        make_multidiscrete_one_hot: If true, one-hot encode MultiDiscrete
            observations instead of returning them as an array.

    Returns:
        For a MultiDiscrete space with one-hot encoding, a float array of
        length ``sum(obs_space.nvec)`` with exactly one 1 per component's
        segment. For a MultiDiscrete space without it, ``np.asarray(obs)``.
        For any other space, ``obs`` itself.

    Raises:
        ValueError: When one-hot encoding, if ``obs`` does not have exactly
            one value per ``nvec`` entry, or a value lies outside
            ``[0, nvec[i])``.
    """
    if not isinstance(obs_space, gym.spaces.MultiDiscrete):
        return obs
    if not make_multidiscrete_one_hot:
        return np.asarray(obs)

    nvec = obs_space.nvec
    if len(obs) != len(nvec):
        # zip() would otherwise silently drop the surplus components.
        raise ValueError(
            f"MultiDiscrete observation has {len(obs)} values, expected {len(nvec)}"
        )
    arr = np.zeros(sum(nvec))
    offset = 0  # start index of the current component's one-hot segment
    for i, (val, max_val) in enumerate(zip(obs, nvec)):
        # Without this check, an out-of-range value would silently set a bit
        # in the next component's segment, or index from the end if negative.
        if not 0 <= val < max_val:
            raise ValueError(
                f"MultiDiscrete observation component {i} = {val} "
                f"is outside [0, {max_val})"
            )
        arr[offset + val] = 1
        offset += max_val
    return arr


def get_space_params(obs_space, action_space, make_multidiscrete_one_hot):
    """Derive the observation and action parameters a learner needs.

    Args:
        obs_space: Observation space; must be a ``gym.spaces.Box`` or a
            ``gym.spaces.MultiDiscrete``.
        action_space: Action space. The action count is the first entry of
            ``convert_gym_space_to_q_shape(action_space)``.
        make_multidiscrete_one_hot: For a MultiDiscrete ``obs_space``,
            describe the one-hot layout of length ``sum(nvec)`` instead of
            the raw vector of ``len(nvec)`` integers.

    Returns:
        A tuple ``(obs_shape, num_actions, obs_dtype, low, high)``.
        For a Box space, these are the space's own ``shape``, ``dtype``,
        ``low`` and ``high``. For a MultiDiscrete space, ``obs_shape`` is
        ``(num_values,)``, ``obs_dtype`` is ``float``, and ``low``/``high``
        are tuples of length ``num_values``.

    Raises:
        TypeError: If ``obs_space`` is neither a Box nor a MultiDiscrete space.
    """
    num_actions = convert_gym_space_to_q_shape(action_space)[0]

    if isinstance(obs_space, gym.spaces.Box):
        low = obs_space.low
        high = obs_space.high
        obs_shape = obs_space.shape
        obs_dtype = obs_space.dtype
    elif isinstance(obs_space, gym.spaces.MultiDiscrete):
        if make_multidiscrete_one_hot:
            # One slot per possible value of every component; each slot is 0 or 1.
            # int() so the shape holds a plain Python int, not a numpy scalar.
            num_values = int(sum(obs_space.nvec))
            high = (1,) * num_values
        else:
            num_values = len(obs_space.nvec)
            # NOTE(review): gym MultiDiscrete values presumably lie in
            # [0, nvec - 1], so nvec is an exclusive bound here. Box.high and
            # the one-hot branch's 1 are inclusive — confirm callers expect this.
            high = tuple(obs_space.nvec)
        low = (0,) * num_values
        obs_shape = (num_values,)
        # NOTE(review): float is reported for both encodings, although the
        # non-one-hot path returns np.asarray(obs), whose dtype follows obs.
        obs_dtype = float
    else:
        # TypeError subclasses Exception, so `except Exception` callers still work.
        raise TypeError(
            "Learner must have a box or multidiscrete observation space, "
            f"got {type(obs_space).__name__}"
        )

    return obs_shape, num_actions, obs_dtype, low, high
