import warnings
from typing import Tuple, Dict, Callable

import torch

from centralized_verification.configuration import Configuration, TestConfiguration
from centralized_verification.train import train_loop, test_loop
from centralized_verification.training_state import maybe_load_from_checkpoint
from experiments.utils.configuration.utils import is_true


class ConfigRunner:
    """Callable that builds configurations from a parameter dict and runs them.

    Each call:
      1. configures torch's CPU thread pools,
      2. builds ``(config, test_config)`` via ``config_getter(params)``,
      3. if ``maybe_load_from_checkpoint(config.run_name)`` returns a truthy
         checkpoint, loads ``checkpoint.learner_state_dict`` into ``config.learner``,
      4. runs ``train_loop(config, checkpoint)`` unless ``params["skip_training"]``
         is true (as judged by ``is_true``),
      5. runs ``test_loop(test_config)`` unless ``params["skip_evaluation"]`` is true.
    """

    def __init__(
        self,
        config_getter: Callable[[Dict[str, str]], Tuple[Configuration, TestConfiguration]],
        *,
        num_threads: int = 4,
        num_interop_threads: int = 4,
    ):
        """
        Args:
            config_getter: Maps the run parameters to a
                ``(Configuration, TestConfiguration)`` pair.
            num_threads: Value passed to ``torch.set_num_threads`` on every call.
            num_interop_threads: Value requested via
                ``torch.set_num_interop_threads``. PyTorch accepts this only once
                per process (and only before inter-op work starts), so later
                calls keep whatever size was established first.
        """
        self.config_getter = config_getter
        self.num_threads = num_threads
        self.num_interop_threads = num_interop_threads

    def _configure_threads(self) -> None:
        """Apply the configured intra-op / inter-op thread counts to torch."""
        torch.set_num_threads(self.num_threads)
        try:
            torch.set_num_interop_threads(self.num_interop_threads)
        except RuntimeError:
            # torch raises if the inter-op pool size was already set in this
            # process or inter-op parallel work has started. Without this guard,
            # calling the runner a second time in the same process would crash.
            current = torch.get_num_interop_threads()
            if current != self.num_interop_threads:
                warnings.warn(
                    f"Could not set torch inter-op threads to {self.num_interop_threads}; "
                    f"keeping existing value {current}.",
                    RuntimeWarning,
                )

    def __call__(self, params: Dict[str, str]) -> None:
        """Build the configurations for ``params`` and run training/evaluation.

        Args:
            params: Run parameters. Forwarded to ``config_getter``. The optional
                keys ``"skip_training"`` and ``"skip_evaluation"`` are interpreted
                with ``is_true``; both default to ``False``.
        """
        self._configure_threads()
        config, test_config = self.config_getter(params)

        checkpoint = maybe_load_from_checkpoint(config.run_name)

        if checkpoint:
            # Restore learner weights before training; train_loop also receives
            # the checkpoint (presumably to restore the rest of the training state).
            config.learner.load_state_dict(checkpoint.learner_state_dict)

        skip_training = is_true(params.get("skip_training", False))
        if not skip_training:
            train_loop(config, checkpoint)

        skip_evaluation = is_true(params.get("skip_evaluation", False))
        if not skip_evaluation:
            test_loop(test_config)
