import csv
import importlib
import json
import os.path
import sys
from collections import defaultdict
from csv import DictReader
from typing import Dict, Tuple, List, Any, FrozenSet, Union, Sequence

import tqdm

from centralized_verification.shields.combine_identical_states import OutgoingActionShieldState, \
    iterate_shield_cleanup
from centralized_verification.shields.partial_obs.asym_obs_shields import parse_shield_labels, \
    to_partial_observations_shield, partial_obs_centralized_shield_to_json


def to_nice_type(s: str):
    """
    Convert a raw CSV cell into a more useful Python value.

    Integer parsing is attempted first; if ``int(s)`` succeeds its result is
    returned. Otherwise the exact strings ``"TRUE"`` and ``"FALSE"`` become
    the corresponding booleans, and anything else is returned unchanged.
    """
    try:
        return int(s)
    except ValueError:
        # Not an integer literal; fall through to the boolean/string handling.
        pass

    if s == "TRUE":
        return True
    return False if s == "FALSE" else s


def _observation_with_update(row, update, raw_obs_labels):
    # Overlay one processor result on a copy of the row and read the observation columns from it.
    merged = row.copy()
    merged.update(update)
    return tuple(merged[label] for label in raw_obs_labels)


def load_outgoing_action_shield_from_csv(reader: DictReader, processor, raw_obs_labels, raw_action_labels,
                                         hidden_labels, initial_cond) -> Dict[
    int, OutgoingActionShieldState]:
    """
    Read transition rows from ``reader`` and build one OutgoingActionShieldState per safe state.

    Every row is first passed through ``to_nice_type`` cell by cell. A row must provide the
    columns "CurrID", "NextID" and "violation", plus every name in ``raw_action_labels``;
    the first non-violating row of a state must also provide ``raw_obs_labels`` (possibly via
    the processor's updates) and ``hidden_labels``.

    * Rows whose "violation" value is truthy only mark their "CurrID" as unsafe; they
      contribute neither a state nor a transition.
    * The first non-violating row seen for a state fixes that state's information:
      ``processor(row)`` returns a dict, or a sequence of dicts, of updates; each update
      overlaid on the row yields one observation tuple, and the set of those tuples is the
      state's label. The hidden values are read from the row itself and
      ``initial_cond(row)`` gives the initial flag.
    * Every non-violating row records the transition CurrID --action--> NextID.

    An action is kept for a state only if none of its recorded successors is unsafe.

    :return: mapping from state id to its OutgoingActionShieldState
    """
    # state id -> joint action -> successor state ids, in the order the rows were read
    transitions: Dict[int, Dict[Tuple[int, ...], List[int]]] = defaultdict(lambda: defaultdict(list))

    # state id -> (observation label set, hidden values, initial flag)
    known_states: Dict[
        int, Tuple[FrozenSet[Tuple[Union[int, bool, str], ...]], Tuple[Union[int, bool, str], ...], bool]] = {}

    violating_ids = set()

    for raw_row in tqdm.tqdm(reader, desc="Loading transitions"):
        row = {column: to_nice_type(cell) for column, cell in raw_row.items()}
        current_id = row["CurrID"]

        if row["violation"]:
            violating_ids.add(current_id)
            continue

        if current_id not in known_states:
            updates: Union[Dict[str, Any], Sequence[Dict[str, Any]]] = processor(row)
            if isinstance(updates, dict):
                # A single update is treated as a one-element sequence.
                updates = (updates,)

            observations = frozenset(
                _observation_with_update(row, update, raw_obs_labels) for update in updates
            )

            known_states[current_id] = (
                observations,
                tuple(row[label] for label in hidden_labels),
                initial_cond(row),
            )

        joint_action = tuple(int(row[label]) for label in raw_action_labels)
        transitions[current_id][joint_action].append(int(row["NextID"]))

    shield_states: Dict[int, OutgoingActionShieldState] = {}
    for state_num, (observations, hidden, is_initial) in tqdm.tqdm(known_states.items(), desc="Loading states"):
        # Keep only the actions that cannot lead into an unsafe state.
        safe_actions = {
            action: frozenset(successors)
            for action, successors in transitions[state_num].items()
            if violating_ids.isdisjoint(successors)
        }

        shield_states[state_num] = OutgoingActionShieldState(
            label=observations,
            initial_state=is_initial,
            actions=safe_actions,
            # Each hidden value is wrapped in its own singleton set.
            hidden_values=tuple({value} for value in hidden),
            state_num=state_num,
        )

    return shield_states


def load_smv_shield_dec_obs(name):
    """
    Build a partial-observation shield from the tab-separated file ``<name>.csv``.

    The CSV header supplies the label names (all columns except "CurrID" and "NextID").
    If ``<name>.py`` exists it is imported as a post-processor module, using ``name`` with
    every "/" replaced by "." as the module name; its ``process``, ``extra_labels`` and
    ``initial_cond`` attributes are used. Without such a file, a processor that adds
    nothing, no extra labels, and an initial condition that is always true are used.

    :param name: path of the shield files without extension
    :return: the result of ``to_partial_observations_shield`` for the cleaned-up shield
    :raises FileNotFoundError: if ``<name>.csv`` does not exist
    :raises KeyError: if the CSV header lacks "CurrID" or "NextID"
    """
    # newline="" is what the csv module asks for when handing it a file object. The
    # ``with`` block guarantees the handle is closed, even if loading fails part-way.
    with open(name + ".csv", "r", newline="") as file:
        file_reader = csv.DictReader(file, dialect="excel-tab")
        label_names = set(file_reader.fieldnames)
        label_names.remove("CurrID")
        label_names.remove("NextID")

        # Check if post processor exists
        if os.path.isfile(name + ".py"):
            print("Loading post processor")
            proc_module = importlib.import_module(name.replace("/", "."))
            processor = proc_module.process
            extra_labels = proc_module.extra_labels
            initial_cond = proc_module.initial_cond
        else:
            print("No post processor found")
            processor = lambda x: {}
            extra_labels = []
            initial_cond = lambda x: True

        label_names.update(extra_labels)
        raw_obs_labels, obs_labels, raw_action_labels, hidden_labels = parse_shield_labels(label_names)
        # Labels of the form "next(...)" are not treated as hidden state.
        hidden_labels = [hl for hl in hidden_labels if not hl.startswith("next(")]

        # The reader is lazy, so the rows must be consumed while the file is still open.
        outgoing_action_states = load_outgoing_action_shield_from_csv(file_reader, processor, raw_obs_labels,
                                                                      raw_action_labels, hidden_labels, initial_cond)

    cleaned_shield = iterate_shield_cleanup(outgoing_action_states)
    partial_obs_shield = to_partial_observations_shield(cleaned_shield, len(raw_action_labels), obs_labels,
                                                        hidden_labels)

    return partial_obs_shield


if __name__ == '__main__':
    # Script entry point: reads <name>.csv (and <name>.py if present) and writes the
    # resulting shield as JSON to <name>.shield_pobs_cent.
    if len(sys.argv) < 2:
        # Fail with a readable message instead of an IndexError traceback.
        sys.exit(f"usage: {sys.argv[0]} <name>  (path of the shield CSV without the .csv extension)")

    name = sys.argv[1]

    shield = load_smv_shield_dec_obs(name)
    output_json = partial_obs_centralized_shield_to_json(shield)

    print("Dumping shield as JSON")
    with open(name + ".shield_pobs_cent", "w") as file:
        json.dump(output_json, file, indent=1)
