import BAIalgos as algos
import pickle

import multiprocessing as mp

import argparse
from argparse import RawTextHelpFormatter

from problems import Problems
from modular import get_name

# Candidate BAI algorithms. The position of each entry is the index that the
# `-a/--algos` command-line option selects.
Algorithms = [
    algos.TrackAndStop,
    algos.CTrackAndStop,
    algos.ChernoffOverlap,
    algos.ChernoffRacing,
]

# Parse commandline arguments
parser = argparse.ArgumentParser(
    description='Driver for comparing cost weighted BAI algorithms.',
    formatter_class=RawTextHelpFormatter)

# RawTextHelpFormatter preserves newlines, so each option's help lists one
# "index = name" line per selectable entry. The indices are what the user
# passes on the command line.
problem_help = '\n'.join(
    ['Pick problem(s):'] +
    ['{} = {}'.format(i, problem.get_name())
     for i, problem in enumerate(Problems)])

parser.add_argument('-p', '--problems', help=problem_help, nargs='*',
                    type=int, default=list(range(len(Problems))))

algo_help = '\n'.join(
    ['Pick algorithm(s):'] +
    ['{} = {}'.format(i, get_name(algo))
     for i, algo in enumerate(Algorithms)])

parser.add_argument('-a', '--algos', help=algo_help, nargs='*',
                    type=int, default=list(range(len(Algorithms))))


def per_core(lock, algo, problem):
    """Run ``algo`` on ``problem`` and merge the result into the problem's data file.

    ``problem.get_data(algo)`` is called outside the lock, so workers for the
    same problem can compute concurrently. Only the read-modify-write of
    ``./data/<problem.get_name()>`` is serialised.

    Args:
        lock: Lock shared by every worker that updates the same data file.
        algo: Algorithm handed to ``problem.get_data``.
        problem: Object exposing ``get_data(algo)`` and ``get_name()``.

    The file must already contain a pickled dict. The object returned by
    ``get_data`` is stored in it under that object's ``.name`` attribute.
    """
    data_obj = problem.get_data(algo)
    path = "./data/{}".format(problem.get_name())
    # `with` releases the lock even if loading/pickling raises; otherwise every
    # other worker waiting on this lock would block forever.
    with lock:
        with open(path, 'rb') as data_file:
            data = pickle.load(data_file)
        data[data_obj.name] = data_obj
        # Serialise before opening for writing: 'wb' truncates the file, so a
        # pickling error must not wipe the results already stored there.
        payload = pickle.dumps(data)
        with open(path, 'wb') as data_file:
            data_file.write(payload)


if __name__ == "__main__":
    import os  # only the driver itself needs filesystem setup

    args = parser.parse_args()

    # A repeated problem index would re-initialise (truncate) its data file
    # while workers started for the first occurrence may already have stored
    # results there, so keep only the first occurrence of each index.
    problem_ids = list(dict.fromkeys(args.problems))

    # Result files are written under ./data; create it on a fresh checkout.
    os.makedirs("./data", exist_ok=True)

    # One lock per data file. Locks are keyed by problem name, like the files
    # are, and shared by every worker that updates that file.
    Locks = {}
    procs = []
    for i in problem_ids:
        problem = Problems[i]
        problem_name = problem.get_name()
        if problem_name not in Locks:
            Locks[problem_name] = mp.Lock()
        lock = Locks[problem_name]

        # Seed the file before any worker for this problem starts: per_core
        # loads the existing dict, so the file must already hold one.
        path = "./data/{}".format(problem_name)
        data = {"lower_bound": problem.lower_bound()}
        with lock, open(path, 'wb') as data_file:
            pickle.dump(data, data_file)

        for j in args.algos:
            algo = Algorithms[j]
            p = mp.Process(target=per_core,
                           args=(lock, algo, problem),
                           name="{}/{}".format(problem_name, get_name(algo)))
            procs.append(p)
            p.start()

    for p in procs:
        p.join()

    # A worker that crashed never merged its result into the data file;
    # report it rather than implying every run succeeded.
    failed = [p.name for p in procs if p.exitcode != 0]
    if failed:
        print("Failed: {}".format(", ".join(failed)))

    print("Done")
