/*
SAT Variables:

For each agent a:
  For each label l:
    For each individual action x:
      Variable shield_safe_a_l_x // determines whether the action is enabled in the decentralized shield

    Constraint shield_safe_a_l_0 \/ shield_safe_a_l_1 \/ ...

For each state s:
  For each possible decentralization pd:

    Variable poss_decent_s_pd // determines whether a given decentralization of state s is used

    For each agent a:
      Let l = label for state s as seen by agent a
      For each individual action x:
        If action x is disabled for agent a in possible decentralization pd for state s:
            Constraint poss_decent_s_pd -> ! shield_safe_a_l_x  // ! poss_decent_s_pd \/ ! shield_safe_a_l_x
        // If action is enabled, don't add a constraint.
        // IMPORTANT: the only constraint that can force a shield_safe_a_l_x variable to be true is the constraint that at least one of them is true.
        // However, there might be additional shield_safe_a_l_x variables which are false even though no constraint forces them to be false.
        // Use poss_decent_s_pd to find the decentralizations for a given label outside of the SAT solver
        // TODO check don't care variables

  Constraint poss_decent_s_0 \/ poss_decent_s_1 \/ ...
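  // Nothing forces exactly one of these to be true; every true poss_decent_s_pd names a decentralization of s consistent with the shield_safe values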

 */


use std::iter;
use std::ops::Not;

use fnv::{FnvHashMap, FnvHashSet};
use itertools::Itertools;
use kissat::{Solution, Solver, Var};

use crate::decentralization::max_perm_actions::{Decentralization, get_max_perm_action_combinations};
use crate::shields::{AgentLabelHistory, AgentLabelHistorySet, PartialObsHistCentralizedShield};

/// Per-agent map from an observed label history to a value: `v[agent][label_history] = T`.
pub type AgentLabelTo<T> = Vec<FnvHashMap<AgentLabelHistory, T>>;

pub fn agent_label_map<T, U, F>(input: &AgentLabelTo<T>, mut mapper: F) -> AgentLabelTo<U>
where F : FnMut(usize, &AgentLabelHistory, &T) -> U {
    input.iter().enumerate().map(|(index, agent_info)| {
        agent_info.iter().map(|(label_history, val)| {
            (label_history.clone(), mapper(index, label_history, val))
        }).collect()
    }).collect()
}

/// Per-agent, per-label-history vector indexed by individual action number:
/// `v[agent][label_history][action] = T`.
pub type AgentLabelActionTo<T> = AgentLabelTo<Vec<T>>;
/// Apply `mapper` to every `(agent, label_history, action, value)` entry of `input`.
///
/// Like [`agent_label_map`], but descends one level further into the per-action vectors.
/// `mapper` receives the agent index, the label history, the action index and the value.
pub fn agent_label_action_map<T, U, F>(input: &AgentLabelActionTo<T>, mut mapper: F) -> AgentLabelActionTo<U>
    where F : FnMut(usize, &AgentLabelHistory, usize, &T) -> U {
    agent_label_map(input, move |agent_num, history, per_action| {
        let mut mapped = Vec::with_capacity(per_action.len());
        for (action_num, value) in per_action.iter().enumerate() {
            mapped.push(mapper(agent_num, history, action_num, value));
        }
        mapped
    })
}

/// Map from a centralized-shield state number to a value.
pub type StateTo<T> = FnvHashMap<u32, T>;

/// Index the states of `shield` by the label histories each agent observes in them.
///
/// Result shape: `labels_to_state_nums[agent][label] = [state_num0, ...]`. A state is listed under
/// every label in an agent's observation set. Within a list, states appear in `shield_states`
/// iteration order.
pub fn get_labels_to_state_nums(shield: &PartialObsHistCentralizedShield) -> AgentLabelTo<Vec<u32>> {
    let agent_count = shield.action_space.len();
    let mut states_by_label: AgentLabelTo<Vec<u32>> =
        (0..agent_count).map(|_| FnvHashMap::default()).collect();

    shield.shield_states.iter().for_each(|(&state_num, state)| {
        state.observations.iter().enumerate().for_each(|(agent_num, label_set)| {
            for label in label_set {
                states_by_label[agent_num]
                    .entry(label.clone())
                    .or_default()
                    .push(state_num);
            }
        });
    });

    states_by_label
}

/// Allocate `n` fresh SAT variables from `solver`, returned in creation order.
pub fn make_n_variables(solver: &mut Solver, n: usize) -> Vec<Var> {
    iter::repeat_with(|| solver.var()).take(n).collect()
}

/// For each state of the centralized shield, list every safe decentralization of that state's
/// joint actions, as produced by `get_max_perm_action_combinations`.
pub fn get_possible_decentralizations(shield: &PartialObsHistCentralizedShield) -> StateTo<Vec<Decentralization>> {
    let per_state = shield.shield_states.iter().map(|(&state_num, state)| {
        let state_decentralizations: Vec<Decentralization> = get_max_perm_action_combinations(state.actions.clone().into_iter().collect(), &shield.action_space)
            .into_iter()
            .collect();
        (state_num, state_decentralizations)
    });
    per_state.collect()
}

/// Entry point to SAT-based decentralized shielding.
///
/// This function:
/// 1. Builds the SAT instance from the `shield_safe` and `poss_decent` variables and constraints.
/// 2. Solves it.
/// 3. Reads the enabled actions from the model.
/// 4. Recomputes every label's actions with `expand_available_actions`.
///
/// Returns a boolean map `result[agent][label_history][action]` of allowed individual actions.
/// Returns `None` if `Solver::sat` yields no solution.
pub fn find_decentralized_shield_sat(centralized_shield: &PartialObsHistCentralizedShield) -> Option<AgentLabelActionTo<bool>> {
    let labels_to_state_nums = get_labels_to_state_nums(centralized_shield);
    let mut solver = Solver::new();

    // shield_safe_vars[agent][label] = [shield_safe_agent_label_0, shield_safe_agent_label_1, ...]
    let shield_safe_vars = create_agent_label_action_vars_and_constraints(centralized_shield, &labels_to_state_nums, &mut solver);

    let possible_decentralizations = get_possible_decentralizations(centralized_shield);

    // poss_decent_vars[state][n] = poss_decent_state_n. Only the clauses this call adds to
    // `solver` matter here; the variables themselves are not read back.
    let _poss_decent_vars = create_possible_decentralization_vars_and_constraints(centralized_shield, &mut solver, &shield_safe_vars, &possible_decentralizations);

    let solution = solver.sat()?;
    let mut dec_shield = sat_solution_to_shield_naive(&solution, &shield_safe_vars);
    expand_available_actions(&mut dec_shield, centralized_shield, &labels_to_state_nums);
    Some(dec_shield)
}

/// Read the value of every `shield_safe` variable straight from the SAT model.
///
/// A variable for which `solution.get` returns `None` is treated as enabled (`true`). The result
/// may be overly restrictive: it can deny actions that could safely be enabled.
pub fn sat_solution_to_shield_naive(solution: &Solution, labels_to_sat_var_shield_safe: &AgentLabelActionTo<Var>) -> AgentLabelActionTo<bool> {
    agent_label_action_map(labels_to_sat_var_shield_safe, |_agent, _history, _action, var| {
        solution.get(*var).unwrap_or(true)
    })
}

/// Recompute, in place, the allowed actions of every (agent, label history) pair of `dec_shield`.
///
/// The SAT solution may be safe but not maximally permissive. Using it as a starting point, each
/// label's action vector is replaced by the result of `safe_actions_other_agents_constant`. That
/// result is the set of actions that stay safe in all states carrying the label, with the other
/// agents' current actions held fixed.
///
/// Labels are processed one after another, so each update is visible to the ones that follow.
pub fn expand_available_actions(dec_shield: &mut AgentLabelActionTo<bool>, centralized_shield: &PartialObsHistCentralizedShield, labels_to_state_nums: &AgentLabelTo<Vec<u32>>){
    let agent_count = centralized_shield.action_space.len();
    for agent in 0..agent_count {
        let label_states = &labels_to_state_nums[agent];
        for (label, states_with_label) in label_states {
            let recomputed = safe_actions_other_agents_constant(agent, dec_shield, centralized_shield, states_with_label);
            dec_shield[agent].insert(label.clone(), recomputed);
        }
    }
}

/// Individual actions of `agent_idx` that are safe in every state of `relevant_state_nums`,
/// holding the other agents' `dec_shield` actions constant.
///
/// The computation proceeds as follows:
/// - It starts with every action of `agent_idx` enabled.
/// - In each state, every combination of the other agents' currently allowed actions and the
///   still-enabled actions of `agent_idx` is formed.
/// - Any joint action not listed as safe in that centralized-shield state disables the
///   corresponding action of `agent_idx`.
///
/// Actions disabled in an earlier state are no longer considered in later ones.
///
/// # Panics
/// Panics if a state number is missing from `centralized_shield.shield_states`.
fn safe_actions_other_agents_constant(agent_idx: usize, dec_shield: &AgentLabelActionTo<bool>, centralized_shield: &PartialObsHistCentralizedShield, relevant_state_nums: &[u32]) -> Vec<bool>{
    let num_own_actions = centralized_shield.action_space[agent_idx] as usize;
    let mut still_safe = vec![true; num_own_actions];

    for state_num in relevant_state_nums.iter() {
        let state = centralized_shield.shield_states.get(state_num).unwrap();

        // Actions each agent could take here; this agent is tentatively given `still_safe`.
        let mut candidate_sets = label_sets_to_indiv_actions(&state.observations, dec_shield);
        candidate_sets[agent_idx] = still_safe.clone();

        let reachable: FnvHashSet<Vec<u32>> = indiv_action_bitmap_to_joint_indices(&candidate_sets).into_iter().collect();
        let allowed_here: FnvHashSet<Vec<u32>> = state.actions.clone().into_iter().collect();

        let offending: Vec<usize> = reachable
            .difference(&allowed_here)
            .map(|joint| joint[agent_idx] as usize)
            .collect();
        for own_action in offending {
            still_safe[own_action] = false;
        }
    }

    still_safe
}

/// Enumerate every joint action whose individual actions are all enabled.
///
/// `indiv_actions[agent][action]` is `true` when `agent` may take individual action `action`.
/// The result is the Cartesian product of each agent's enabled action indices. It is ordered
/// lexicographically, with the last agent varying fastest.
/// For example, `[[true, false, true], [false, true]]` yields `[[0, 1], [2, 1]]`.
/// If some agent has no enabled action, the result is empty.
///
/// Accepts a slice, so both `&Vec<Vec<bool>>` and array references can be passed.
fn indiv_action_bitmap_to_joint_indices(indiv_actions: &[Vec<bool>]) -> Vec<Vec<u32>> {
    let indiv_action_indices_list: Vec<Vec<u32>> = indiv_actions
        .iter()
        .map(|single_agent_indiv_actions| {
            single_agent_indiv_actions
                .iter()
                .enumerate()
                .filter(|&(_, &keep)| keep)
                .map(|(idx, _)| idx as u32)
                .collect()
        })
        .collect();

    indiv_action_indices_list.into_iter().multi_cartesian_product().collect()
}

/// Owned copy of the individual actions that `dec_shield` allows `agent_num` under `label`.
///
/// # Panics
/// Panics if `agent_num` is out of range or `label` has no entry for that agent.
fn single_agent_label_to_indiv_action(agent_num: usize, label: &AgentLabelHistory, dec_shield: &AgentLabelActionTo<bool>) -> Vec<bool> {
    let per_label = &dec_shield[agent_num];
    per_label[label].to_vec()
}

/// Individual actions of `agent_num` that `dec_shield` allows under *every* label in `label_set`
/// (element-wise AND of the per-label action vectors).
///
/// Only the first label's vector is copied; the remaining labels are AND-ed in by reference. The
/// result has the length of the first label's vector.
///
/// # Panics
/// Panics in any of these cases:
/// - `label_set` is empty. The number of actions is then unknown, and silently returning an
///   empty vector would make every joint action look unreachable.
/// - A label has no entry in `dec_shield[agent_num]`.
/// - A later label's action vector is shorter than the first one's.
fn agent_label_set_to_safe_indiv_actions(agent_num: usize, label_set: &AgentLabelHistorySet, dec_shield: &AgentLabelActionTo<bool>) -> Vec<bool> {
    let mut labels = label_set.into_iter();
    let first_label = labels
        .next()
        .expect("agent_label_set_to_safe_indiv_actions: agent observes an empty label set");

    // Start from the first label's actions and keep only those every other label also allows.
    let mut allowed = single_agent_label_to_indiv_action(agent_num, first_label, dec_shield);
    for label in labels {
        let label_actions = &dec_shield[agent_num][label];
        for (action_num, allowed_here) in allowed.iter_mut().enumerate() {
            *allowed_here &= label_actions[action_num];
        }
    }
    allowed
}

/// For each agent, the individual actions that `dec_shield` allows under all labels of that
/// agent's label set. The agent index is the position in `label_sets`.
///
/// Takes a slice rather than `&Vec`, so callers holding a `Vec` still pass `&vec` unchanged.
fn label_sets_to_indiv_actions(label_sets: &[AgentLabelHistorySet], dec_shield: &AgentLabelActionTo<bool>) -> Vec<Vec<bool>> {
    label_sets
        .iter()
        .enumerate()
        .map(|(agent_num, label_set)| agent_label_set_to_safe_indiv_actions(agent_num, label_set, dec_shield))
        .collect()
}

/// Create the `poss_decent_s_pd` SAT variables and their constraints.
///
/// For every state `s` of `shield` and every decentralization `pd` listed for it in
/// `decentralizations`, a fresh variable `poss_decent_s_pd` is created. Suppose `pd` does not allow
/// individual action `x` for agent `a`. Then the clause `!poss_decent_s_pd \/ !shield_safe_a_l_x`
/// is added for every label history `l` that agent `a` observes in `s`. Selecting `pd` therefore
/// forces those actions to be disabled.
///
/// Each state with at least one decentralization also gets the clause
/// `poss_decent_s_0 \/ poss_decent_s_1 \/ ...`, so some decentralization must be selected. This is a
/// plain disjunction rather than a chain of `Solver::xor`:
/// - An XOR chain constrains only the parity of the selected decentralizations.
/// - It also introduces a helper variable per link.
/// - Because `poss_decent` variables otherwise occur only negatively, both encodings admit exactly
///   the same `shield_safe` assignments.
///
/// Variable and constraint counts are printed to stdout.
///
/// Returns the `poss_decent` variables of each state, in the order of `decentralizations[s]`.
///
/// # Panics
/// Panics if a state of `shield` has no entry in `decentralizations`, or an observed label history
/// has no entry in `labels_to_sat_var_shield_safe`.
fn create_possible_decentralization_vars_and_constraints(shield: &PartialObsHistCentralizedShield, solver: &mut Solver, labels_to_sat_var_shield_safe: &AgentLabelActionTo<Var>, decentralizations: &StateTo<Vec<Decentralization>>) -> StateTo<Vec<Var>> {
    let mut state_to_sat_var_poss_decent: StateTo<Vec<Var>> = FnvHashMap::default();
    let mut num_decomp_selected_vars = 0u64;
    let mut num_obs_to_state_constraints = 0u64;

    for (state_num, state_info) in shield.shield_states.iter() {
        let state_decentralizations = &decentralizations[state_num];
        let mut state_poss_decentralization_vars = Vec::with_capacity(state_decentralizations.len());

        for decentralization in state_decentralizations.iter() {
            // Variable poss_decent_s_pd
            let this_decentralization_used = solver.var();
            num_decomp_selected_vars += 1;

            for (agent_num, (agent_allowed_actions, num_actions_for_agent)) in decentralization.iter().zip(&shield.action_space).enumerate() {
                for potential_action in 0..(*num_actions_for_agent) {
                    if agent_allowed_actions.contains(&potential_action) {
                        // Enabled actions add no constraint.
                        continue;
                    }
                    // Constraint poss_decent_s_pd -> ! shield_safe_a_l_x
                    // Same as ! poss_decent_s_pd \/ ! shield_safe_a_l_x
                    let this_label_set = &state_info.observations[agent_num];
                    for this_label in this_label_set {
                        solver.add2(this_decentralization_used.not(), labels_to_sat_var_shield_safe[agent_num][this_label][potential_action as usize].not());
                        num_obs_to_state_constraints += 1;
                    }
                }
            }

            state_poss_decentralization_vars.push(this_decentralization_used);
        }

        // Constraint poss_decent_s_0 \/ poss_decent_s_1 \/ ...
        // An empty clause is unsatisfiable, so a state without any decentralization adds no clause.
        if !state_poss_decentralization_vars.is_empty() {
            solver.add(&state_poss_decentralization_vars);
        }
        state_to_sat_var_poss_decent.insert(*state_num, state_poss_decentralization_vars);
    }

    println!("Decomp vars: {}, Total states: {}", num_decomp_selected_vars, shield.shield_states.len());
    println!("Obs-to-state constraints: {}", num_obs_to_state_constraints);
    state_to_sat_var_poss_decent
}

/// Create the `shield_safe_a_l_x` SAT variables and their constraints.
///
/// For every agent `a` and every label history `l` in `labels_to_state_nums[a]`, one variable is
/// allocated per individual action of `a`. The clause `shield_safe_a_l_0 \/ shield_safe_a_l_1 \/ ...`
/// is then added, so each label enables at least one action.
///
/// The total variable count is printed to stdout.
///
/// Returns `vars[agent][label] = [shield_safe_agent_label_0, ...]`.
fn create_agent_label_action_vars_and_constraints(shield: &PartialObsHistCentralizedShield, labels_to_state_nums: &AgentLabelTo<Vec<u32>>, solver: &mut Solver) -> AgentLabelActionTo<Var> {
    let agent_count = shield.action_space.len();
    let mut shield_safe_vars: AgentLabelActionTo<Var> = vec![FnvHashMap::default(); agent_count];
    let mut total_vars = 0u64;

    let per_agent = labels_to_state_nums.iter().zip(&shield.action_space);
    for (agent_num, (label_map, &action_count)) in per_agent.enumerate() {
        for label in label_map.keys() {
            // Variable shield_safe_agent_num_label_[0..action_count]
            let action_vars = make_n_variables(solver, action_count as usize);

            // Constraint shield_safe_agent_num_label_0 \/ shield_safe_agent_num_label_1 \/ ...
            solver.add(&action_vars);
            total_vars += action_count as u64;

            shield_safe_vars[agent_num].insert(label.clone(), action_vars);
        }
    }

    println!("Num action_selected vars: {}", total_vars);
    shield_safe_vars
}

/// Unit tests for the solver-independent helpers of the SAT-based decentralized shield.
#[cfg(test)]
pub mod sat_shield_test {
    use fnv::FnvHashMap;

    use crate::decentralization::sat_shield::{AgentLabelActionTo, indiv_action_bitmap_to_joint_indices, safe_actions_other_agents_constant};
    use crate::shields::{PartialObsHistCentralizedShield, PartialObsHistCentralizedShieldState};
    use crate::shields::ShieldVarValue::Int;

    /// Agent 0 may take actions {0, 2}, agent 1 only {2}, agent 2 {0, 1}. The joint actions are the
    /// Cartesian product of these sets, with the last agent varying fastest.
    #[test]
    pub fn test_indiv_action_to_joint_idx(){
        assert_eq!(
            indiv_action_bitmap_to_joint_indices(
                &vec![
                    vec![true, false, true],
                    vec![false, false, true, false],
                    vec![true, true]
                ]
            ),
            vec![
                vec![0, 2, 0],
                vec![0, 2, 1],
                vec![2, 2, 0],
                vec![2, 2, 1]
            ]
        )
    }

    /// Two agents with three actions each, in two states.
    /// - In both states agent 0 observes the label `[[Int(0)]]`.
    /// - Agent 1 observes `[[Int(0)]]` in state 0 and `[[Int(1)]]` in state 1.
    ///
    /// Agent 0's own entry in the decentralized shields is not used by
    /// `safe_actions_other_agents_constant`, which recomputes it starting from all-true.
    #[test]
    pub fn test_get_safe_actions_other_constant(){
        let mut shield_states = FnvHashMap::default();
        shield_states.insert(0, PartialObsHistCentralizedShieldState{
            actions: vec![
                vec![0, 0], vec![1, 0], vec![0, 2], vec![2, 0], vec![2, 2]
            ],
            observations: vec![vec![vec![Int(0)]], vec![vec![Int(0)]]],
            initial: false,
            hidden: vec![],
        });

        shield_states.insert(1, PartialObsHistCentralizedShieldState{
            actions: vec![
                vec![0, 0], vec![1, 0], vec![2, 0]
            ],
            observations: vec![vec![vec![Int(0)]], vec![vec![Int(1)]]],
            initial: false,
            hidden: vec![]
        });

        let cshield = PartialObsHistCentralizedShield {
            action_space: vec![3, 3],
            obs_names: vec![],
            shield_states,
            hidden_obs_names: vec![],
            history_len: 0,
        };

        // Agent 1 allows {0, 2} in state 0. The joint action (1, 2) is not safe there, so agent 0
        // must lose action 1.
        let dec_shield1: AgentLabelActionTo<bool> = vec![
            {
                let mut hmap = FnvHashMap::default();
                hmap.insert(vec![vec![Int(0)]], vec![true, false, false]);
                hmap
            },
            {
                let mut hmap = FnvHashMap::default();
                hmap.insert(vec![vec![Int(0)]], vec![true, false, true]);
                hmap.insert(vec![vec![Int(1)]], vec![true, false, false]);
                hmap
            }
        ];

        assert_eq!(safe_actions_other_agents_constant(0, &dec_shield1, &cshield, &vec![0, 1]), vec![true, false, true]);

        // Agent 1 only allows action 0. Every joint action (x, 0) is safe in both states, so agent 0
        // may keep all of its actions.
        let dec_shield2: AgentLabelActionTo<bool> = vec![
            {
                let mut hmap = FnvHashMap::default();
                hmap.insert(vec![vec![Int(0)]], vec![true, false, false]);
                hmap
            },
            {
                let mut hmap = FnvHashMap::default();
                hmap.insert(vec![vec![Int(0)]], vec![true, false, false]);
                hmap.insert(vec![vec![Int(1)]], vec![true, false, false]);
                hmap
            }
        ];

        assert_eq!(safe_actions_other_agents_constant(0, &dec_shield2, &cshield, &vec![0, 1]), vec![true, true, true])
    }
}