use bit_set::BitSet;
use bit_vec::BitVec;
use fnv::FnvHashSet;
use itertools::Itertools;

/// Index of one individual action of a single agent.
///
/// The same integer type is also used for per-agent action-space sizes and for the flattened
/// bit offsets produced by `action_sizes_to_offsets`.
pub type ActionNumber = u32;

/// One list of permitted individual actions per agent (outer index = agent).
pub type Decentralization = Vec<Vec<ActionNumber>>;

/// A product set of joint actions, stored as one flat bit set over every agent's individual
/// actions.
///
/// Agent `i`'s individual action `a` occupies bit `action_offsets[i] + a`, where
/// `action_offsets` comes from `action_sizes_to_offsets`. The represented joint actions are
/// all combinations that pick, for every agent, one of that agent's set bits.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct JointActionSet(BitSet);

/// Converts per-agent action-space sizes into flattened bit offsets.
///
/// Returns `(offsets, total)`:
/// - `offsets[i]` is the sum of the sizes of agents `0..i`, i.e. the first bit used by agent `i`.
/// - `total` is the sum of all sizes, i.e. the number of bits needed to hold every agent's actions.
pub fn action_sizes_to_offsets(action_sizes: &[ActionNumber]) -> (Vec<ActionNumber>, ActionNumber) {
    let mut offsets = Vec::with_capacity(action_sizes.len());
    let total = action_sizes.iter().fold(0, |running_offset, &size| {
        // Each agent starts where the previous agents' actions end.
        offsets.push(running_offset);
        running_offset + size
    });
    (offsets, total)
}


/// Operations on a [`JointActionSet`].
///
/// Every method addresses individual actions through flattened indices: agent `i`'s action `a`
/// lives at bit `action_offsets[i] + a`, with `action_offsets` produced by
/// [`action_sizes_to_offsets`].
impl JointActionSet {
    /// Returns the set in which every agent may take every one of its individual actions
    /// (all `total_action_size` bits set).
    pub fn all_joint_actions(total_action_size: ActionNumber) -> Self {
        let bs = BitVec::from_elem(total_action_size as usize, true);
        Self(BitSet::from_bit_vec(bs))
    }

    /// Returns `true` if the joint action lies in this product set, i.e. for every agent `i`
    /// the individual action `joint_action[i]` is present.
    ///
    /// # Panics
    /// Panics if `joint_action` has more entries than `action_offsets`.
    pub fn includes_action(&self, joint_action: &[ActionNumber], action_offsets: &[ActionNumber]) -> bool {
        joint_action
            .iter()
            .enumerate()
            .all(|(idx, indiv_action)| self.0.contains((action_offsets[idx] + *indiv_action) as usize))
    }

    /// Returns `true` if every joint action yielded by `joint_actions` lies in this set.
    // Not called from within this module.
    #[allow(dead_code)]
    pub fn includes_all_joint_actions<T: AsRef<[ActionNumber]>>(
        &self,
        action_offsets: &[ActionNumber],
        mut joint_actions: impl Iterator<Item = T>,
    ) -> bool {
        joint_actions.all(|ja| self.includes_action(ja.as_ref(), action_offsets))
    }

    /// Returns `true` if every individual action enabled in `self` is also enabled in
    /// `other_joint_action_set`.
    pub fn is_subset_of(&self, other_joint_action_set: &JointActionSet) -> bool {
        self.0.is_subset(&other_joint_action_set.0)
    }

    /// Returns a copy of this set with individual action `action_num` of `agent` enabled, or
    /// `None` if that action is already enabled.
    ///
    /// # Panics
    /// Panics if `agent` is not a valid index into `action_offsets`.
    // Not called from within this module.
    #[allow(dead_code)]
    pub fn add_individual_action(&self, action_offsets: &[ActionNumber], agent: ActionNumber, action_num: ActionNumber) -> Option<JointActionSet> {
        let idx = (action_offsets[agent as usize] + action_num) as usize;
        if self.0.contains(idx) {
            None
        } else {
            let mut new_ja_bit_set = self.0.clone();
            new_ja_bit_set.insert(idx);
            Some(JointActionSet(new_ja_bit_set))
        }
    }

    /// Splits this product set so that none of the results contains `joint_action`.
    ///
    /// If `joint_action` is not in the set, returns a one-element vec holding a clone of `self`.
    /// Otherwise, for each agent `i`, it yields a copy of `self` with that agent's action
    /// `joint_action[i]` removed. Any product subset of `self` that excludes `joint_action` must
    /// miss `joint_action[i]` for some `i`, so it is contained in one of the returned sets.
    /// Copies in which agent `i` would be left with no actions are dropped.
    ///
    /// # Panics
    /// Panics if `joint_action[i] >= action_sizes[i]` for a contained joint action, or if the
    /// slices are shorter than `joint_action`.
    pub fn remove_joint_action(
        &self,
        joint_action: &[ActionNumber],
        action_offsets: &[ActionNumber],
        action_sizes: &[ActionNumber],
    ) -> Vec<JointActionSet> {
        if !self.includes_action(joint_action, action_offsets) {
            return vec![self.clone()];
        }

        joint_action
            .iter()
            .enumerate()
            .filter_map(|(agent, &action)| {
                let offset = action_offsets[agent] as usize;
                let size = action_sizes[agent] as usize;
                // An action outside the agent's block would clear a bit belonging to a
                // different agent, so reject it outright.
                assert!(
                    (action as usize) < size,
                    "action {} out of range for agent {} with {} actions",
                    action,
                    agent,
                    size
                );

                let mut reduced = self.0.clone();
                reduced.remove(offset + action as usize);

                // A product with an agent that has no actions left is empty; drop it.
                if (offset..offset + size).any(|bit| reduced.contains(bit)) {
                    Some(JointActionSet(reduced))
                } else {
                    None
                }
            })
            .collect()
    }

    /// Expands the flat bit set into one sorted list of enabled individual actions per agent.
    ///
    /// `action_offsets` and `action_sizes` are zipped, so the result has one entry per agent
    /// described by the shorter of the two slices.
    pub fn to_indiv_action_vecs(&self, action_offsets: &[ActionNumber], action_sizes: &[ActionNumber]) -> Vec<Vec<ActionNumber>> {
        action_offsets
            .iter()
            .zip(action_sizes)
            .map(|(&offset, &size)| {
                (0..size)
                    .filter(|&action| self.0.contains((offset + action) as usize))
                    .collect::<Vec<ActionNumber>>()
            })
            .collect()
    }
}

pub fn get_max_perm_action_combinations(allowed_actions: FnvHashSet<Vec<ActionNumber>>, action_space: impl AsRef<[ActionNumber]>) -> FnvHashSet<Decentralization>{
    let (action_offsets, total_indiv_actions) = action_sizes_to_offsets(action_space.as_ref());
    let mut all_max_perm_action_sets = FnvHashSet::default();
    all_max_perm_action_sets.insert(JointActionSet::all_joint_actions(total_indiv_actions));

    for joint_action in action_space.as_ref().iter().map(|n| (0 as ActionNumber)..*n).multi_cartesian_product(){
        if !allowed_actions.contains(&joint_action) {
            let mut new_max_perm_action_sets = FnvHashSet::default();
            for old_action_set in all_max_perm_action_sets.into_iter() {
                // If the old action set contains a forbidden action, add its derivatives to the new action set
                // Otherwise, add itself to the set
                for new_action_set in old_action_set.remove_joint_action(&joint_action, &action_offsets, action_space.as_ref()){
                    new_max_perm_action_sets.insert(new_action_set);
                }
            }
            all_max_perm_action_sets = new_max_perm_action_sets;
        }
    }

    // Remove subsumed sets
    let mut max_perm_action_sets_remove_subsumed = FnvHashSet::default();
    'outer: for candidate_set in all_max_perm_action_sets.into_iter() {
        let mut sets_we_subsume = vec![];

        'inner: for possible_conflict in max_perm_action_sets_remove_subsumed.iter(){
            if candidate_set.is_subset_of(possible_conflict){
                continue 'outer; // Subsumed by other set already here
            } else if possible_conflict.is_subset_of(&candidate_set) {
                // We subsume the other set
                sets_we_subsume.push(possible_conflict.clone());
                continue 'inner;
            }
        }

        for set_to_remove in sets_we_subsume.iter() {
            max_perm_action_sets_remove_subsumed.remove(set_to_remove);
        }

        max_perm_action_sets_remove_subsumed.insert(candidate_set);
    }

    max_perm_action_sets_remove_subsumed.iter().map(|set| set.to_indiv_action_vecs(&action_offsets, action_space.as_ref())).collect()
}


#[cfg(test)]
pub mod max_perm_actions_test {
    use fnv::FnvHashSet;

    use crate::decentralization::max_perm_actions::get_max_perm_action_combinations;

    /// Runs the solver on a two-agent problem and checks the exact set of decentralizations.
    ///
    /// `allowed_pairs` lists permitted joint actions as `(agent0, agent1)` tuples. `expected`
    /// lists every decentralization the solver must return; their order is irrelevant.
    fn check_two_agent_case(
        allowed_pairs: &[(u32, u32)],
        action_sizes: &[u32],
        expected: &[Vec<Vec<u32>>],
    ) {
        let allowed: FnvHashSet<Vec<u32>> = allowed_pairs
            .iter()
            .map(|&(first, second)| vec![first, second])
            .collect();
        let expected: FnvHashSet<Vec<Vec<u32>>> = expected.iter().cloned().collect();

        let result = get_max_perm_action_combinations(allowed, action_sizes);
        assert_eq!(result, expected);
    }

    /// 3x3 game in which joint actions (1, 2) and (2, 1) are forbidden.
    #[test]
    pub fn test_max_perm_actions_1() {
        check_two_agent_case(
            &[(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (2, 0), (2, 2)],
            &[3, 3],
            &[
                vec![vec![0, 1, 2], vec![0]],
                vec![vec![0, 1], vec![0, 1]],
                vec![vec![0], vec![0, 1, 2]],
                vec![vec![0, 2], vec![0, 2]],
            ],
        );
    }

    /// 3x4 game with only four allowed joint actions.
    #[test]
    pub fn test_max_perm_actions_2() {
        check_two_agent_case(
            &[(1, 0), (0, 2), (1, 3), (2, 1)],
            &[3, 4],
            &[
                vec![vec![1], vec![0, 3]],
                vec![vec![0], vec![2]],
                vec![vec![2], vec![1]],
            ],
        );
    }

    /// 3x3 game in which agent 0's action 2 and joint action (1, 0) are forbidden.
    #[test]
    pub fn test_max_perm_actions_3() {
        check_two_agent_case(
            &[(0, 0), (0, 1), (0, 2), (1, 1), (1, 2)],
            &[3, 3],
            &[
                vec![vec![0, 1], vec![1, 2]],
                vec![vec![0], vec![0, 1, 2]],
            ],
        );
    }
}