
from fitted_algo import FittedAlgo
import numpy as np
from tqdm import tqdm
from env_nn import *
from thread_safe import threadsafe_generator
from keras import backend as K
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau

class LakeFittedQIteration(FittedAlgo):
    def __init__(self, num_inputs, grid_shape, dim_of_actions, max_epochs, gamma, model_type='mlp', position_of_goals=None, position_of_holes=None, num_frame_stack=None):
        '''
        An implementation of (cost-minimising) fitted Q iteration for the lake environment.

        num_inputs: number of inputs, forwarded to LakeNN
        grid_shape: shape of the lake grid, forwarded to LakeNN
        dim_of_actions: dimension of action space
        max_epochs: positive int, specifies how many iterations to run the algorithm
        gamma: discount factor
        model_type: model architecture name forwarded to LakeNN (default 'mlp')
        position_of_goals: forwarded to LakeNN when the Q model is built
        position_of_holes: forwarded to LakeNN when the Q model is built
        num_frame_stack: forwarded to LakeNN when the Q model is built
        '''
        self.model_type = model_type
        self.num_inputs = num_inputs
        self.grid_shape = grid_shape
        self.dim_of_actions = dim_of_actions
        self.max_epochs = max_epochs
        self.gamma = gamma
        self.position_of_goals = position_of_goals
        self.position_of_holes = position_of_holes
        self.num_frame_stack = num_frame_stack

        super(LakeFittedQIteration, self).__init__()

    def run(self, dataset, epochs=3000, epsilon=1e-8, desc='FQI', **kw):
        '''
        Fit an approximately optimal (cost-minimising) Q to a fixed batch of transitions.

        dataset: transitions generated by pi_old. It must provide
            get_state_action_pairs() and the keys 'x_prime', 'cost',
            'terminated' and 'truncated'.
        epochs: training epochs per regression step, forwarded to self.fit
        epsilon: convergence tolerance forwarded to self.fit
        desc: progress-bar label
        **kw: extra keyword arguments forwarded to init_Q (and on to LakeNN)

        Returns (Q, []): the final fitted Q model and an empty list of
        evaluation values, which keeps the same return shape as CarFittedQIteration.run.
        '''
        self.Q_k = self.init_Q(model_type=self.model_type, position_of_holes=self.position_of_holes, position_of_goals=self.position_of_goals, num_frame_stack=self.num_frame_stack, **kw)

        X_a = np.hstack(dataset.get_state_action_pairs())
        x_prime = dataset['x_prime']

        # Keep only the rows selected by FittedAlgo.skim.
        index_of_skim = self.skim(X_a, x_prime)
        X_a = X_a[index_of_skim]
        x_prime = x_prime[index_of_skim]
        dataset_costs = dataset['cost'][index_of_skim]
        terminateds = dataset['terminated'][index_of_skim]
        truncateds = dataset['truncated'][index_of_skim]

        # The mask is 1 where the target bootstraps from Q(x', .) and 0 where the
        # transition ended the episode. It does not depend on the iteration,
        # so it is computed once here rather than inside the loop.
        # NOTE(review): truncated transitions are treated as terminal, so they do
        # not bootstrap. Gymnasium's terminated/truncated semantics usually call
        # for bootstrapping through truncation; confirm this is intended.
        bootstrap_mask = 1 - (truncateds | terminateds).astype(int)

        for _ in tqdm(range(self.max_epochs), desc=desc):
            # Regression targets: {((x,a), c + gamma * min_a Q_k(x',a))}
            costs = dataset_costs + self.gamma*self.Q_k.min_over_a(x_prime)[0]*bootstrap_mask

            # Full-batch regression of Q_k onto the targets.
            self.fit(X_a, costs, epochs=epochs, batch_size=X_a.shape[0], epsilon=epsilon, evaluate=False, verbose=0)

        return self.Q_k, []

    def init_Q(self, epsilon=1e-10, **kw):
        '''
        Build a fresh LakeNN Q-function model.

        epsilon: forwarded to LakeNN as convergence_of_model_epsilon
        **kw: extra keyword arguments forwarded to LakeNN
        '''
        return LakeNN(self.num_inputs, 1, self.grid_shape, self.dim_of_actions, self.gamma, convergence_of_model_epsilon=epsilon, **kw)


class CarFittedQIteration(FittedAlgo):
    def __init__(self, state_space_dim,
                 dim_of_actions,
                 max_epochs,
                 gamma,
                 model_type='cnn',
                 num_frame_stack=None,
                 initialization=None,
                 freeze_cnn_layers=False,
                 br_gamma = 0.95):
        '''
        An implementation of fitted Q iteration for the car environment.

        state_space_dim: state-space dimension, forwarded to CarNN
        dim_of_actions: dimension of action space
        max_epochs: positive int, specifies how many iterations to run the algorithm
        gamma: discount factor, forwarded to CarNN
        model_type: model architecture name forwarded to CarNN (default 'cnn')
        num_frame_stack: forwarded to CarNN
        initialization: stored on the instance but not used by init_Q, so
            models are not warm-started from it
        freeze_cnn_layers: forwarded to CarNN
        br_gamma: discount applied to the bootstrapped term of the regression
            targets built in generator()
        '''
        self.initialization = initialization
        self.freeze_cnn_layers = freeze_cnn_layers
        self.model_type = model_type
        self.state_space_dim = state_space_dim
        self.dim_of_actions = dim_of_actions
        self.max_epochs = max_epochs
        self.gamma = gamma
        self.br_gamma = br_gamma
        self.num_frame_stack = num_frame_stack
        self.Q_k = None
        self.Q_k_minus_1 = None

        earlyStopping = EarlyStopping(monitor='val_loss', min_delta=1e-4,  patience=10, verbose=1, mode='min', restore_best_weights=True)
        mcp_save = ModelCheckpoint('fqi.keras', save_best_only=True, monitor='val_loss', mode='min')
        reduce_lr_loss = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=7, verbose=1, min_delta=1e-4, mode='min')

        # Passed to fit_generator on every FQI iteration.
        self.more_callbacks = [earlyStopping, mcp_save, reduce_lr_loss]

        super(CarFittedQIteration, self).__init__()

    def run(self, dataset, epochs=1, epsilon=1e-8, desc='FQI', exact=None, **kw):
        '''
        Fit an approximately optimal Q to the fixed batch `dataset` (generated by pi_old).

        dataset: transitions exposing 'frames', 'prev_states', 'next_states',
            'a', 'cost', 'terminated' and 'truncated'
        epochs: Keras epochs per FQI iteration, forwarded to fit_generator
        epsilon: convergence tolerance forwarded to fit_generator
        desc: progress-bar label
        exact: evaluator whose run(Q, to_monitor=...) returns (c, g, perf). It is
            called after each of the last 10 iterations, with to_monitor=True only
            on the final iteration. When None, evaluation is skipped.
        **kw: extra keyword arguments forwarded to init_Q (and on to CarNN)

        Returns (Q_k, values): the final Q model and a list of [c, perf]
        entries, one per evaluated iteration.
        '''
        self.Q_k = self.init_Q(model_type=self.model_type, num_frame_stack=self.num_frame_stack, **kw)
        self.Q_k_minus_1 = self.init_Q(model_type=self.model_type, num_frame_stack=self.num_frame_stack, **kw)

        # Push one next-state frame stack through both models and discard the output.
        # NOTE(review): presumably this builds the networks so that copy_over_to
        # can transfer weights; confirm.
        x_prime = np.rollaxis(dataset['frames'][dataset['next_states'][[0]]], 1, 4)
        self.Q_k.max_over_a([x_prime], x_preprocessed=True)
        self.Q_k_minus_1.max_over_a([x_prime], x_preprocessed=True)
        self.Q_k.copy_over_to(self.Q_k_minus_1)

        values = []
        batch_size = 64
        dataset_length = len(dataset)
        # The learning-rate drop and evaluation apply to the final 10 iterations.
        first_final_iteration = self.max_epochs - 10
        last_iteration = self.max_epochs - 1

        print(f'epochs = {epochs}')

        for k in tqdm(range(self.max_epochs), desc=desc):
            # Every index is used for training (there is no validation split).
            # The generator reshuffles again at the start of each pass.
            training_idxs = np.random.permutation(dataset_length)
            training_steps_per_epoch = int(np.ceil(len(training_idxs) / float(batch_size)))
            train_gen = MyDataset(self.generator(dataset, training_idxs, fixed_permutation=True, batch_size=batch_size), len(training_idxs), batch_size, max_queue_size = 10, workers = 4)

            in_final_iterations = k >= first_final_iteration
            if in_final_iterations:
                self.Q_k.model.optimizer.learning_rate = 0.0001

            # NOTE(review): self.more_callbacks monitor 'val_loss', but no
            # validation data is passed here. Unless fit_generator supplies some,
            # those callbacks cannot trigger. Confirm this is intended.
            self.fit_generator(train_gen,
                               steps_per_epoch=training_steps_per_epoch,
                               epochs=epochs,
                               epsilon=epsilon,
                               evaluate=False,
                               verbose=0,
                               additional_callbacks = self.more_callbacks)

            # The freshly fitted weights become the target network for the next iteration.
            self.Q_k.copy_over_to(self.Q_k_minus_1)

            if in_final_iterations and exact is not None:
                # k only reaches max_epochs - 1, so that is the iteration to monitor.
                c, g, perf = exact.run(self.Q_k, to_monitor=(k == last_iteration))
                values.append([c, perf])

        return self.Q_k, values

    @threadsafe_generator
    def generator(self, dataset, training_idxs, fixed_permutation=False,  batch_size = 64):
        '''
        Yield (features, targets) batches indefinitely for Keras training.

        dataset: transitions (see run)
        training_idxs: 1-D array of transition indices to draw batches from
        fixed_permutation: if True, each pass over training_idxs visits every
            index exactly once, reshuffled per pass. If False, each batch is
            sampled with replacement.
        batch_size: number of transitions per batch (the last batch of a pass
            may be smaller when fixed_permutation is True)

        Targets are cost + br_gamma * max_a Q_{k-1}(x', a), with the bootstrap
        term zeroed for terminated or truncated transitions. Features are
        Q_{k-1}.representation of (x, a).

        Raises ValueError on the first iteration if training_idxs is empty.
        '''
        data_length = len(training_idxs)
        if data_length == 0:
            raise ValueError('generator requires at least one training index')
        steps = int(np.ceil(data_length / float(batch_size)))
        i = -1
        while True:
            i = (i + 1) % steps
            if fixed_permutation:
                if i == 0:
                    perm = np.random.permutation(training_idxs)
                batch_idxs = perm[(i * batch_size):((i + 1) * batch_size)]
            else:
                batch_idxs = np.random.choice(training_idxs, batch_size)

            X = np.rollaxis(dataset['frames'][dataset['prev_states'][batch_idxs]], 1, 4)
            actions = np.atleast_2d(dataset['a'][batch_idxs]).T
            x_prime = np.rollaxis(dataset['frames'][dataset['next_states'][batch_idxs]], 1, 4)
            dataset_costs = dataset['cost'][batch_idxs]

            terminateds = dataset['terminated'][batch_idxs]
            truncateds = dataset['truncated'][batch_idxs]

            # NOTE(review): truncated transitions are treated as terminal (no
            # bootstrap). Gymnasium semantics usually bootstrap through
            # truncation; confirm this is intended.
            costs = dataset_costs + self.br_gamma*self.Q_k_minus_1.max_over_a([x_prime], x_preprocessed=True)[0]*(1-(truncateds | terminateds).astype(int))

            X = self.Q_k_minus_1.representation([X], actions, x_preprocessed=True)

            yield (X, costs)

    def init_Q(self, epsilon=1e-10, **kw):
        '''
        Build a fresh CarNN Q-function model.

        epsilon: forwarded to CarNN as convergence_of_model_epsilon
        **kw: extra keyword arguments forwarded to CarNN (e.g. model_type, num_frame_stack)

        NOTE(review): the model is not warm-started from self.initialization;
        it always starts from CarNN's own initialisation. To warm-start, copy
        weights with self.initialization.Q.copy_over_to(model). Re-initialising
        trainable layers would need a Keras-2/3 compatible approach, not
        K.get_session().
        '''
        return CarNN(self.state_space_dim, self.dim_of_actions, self.gamma, convergence_of_model_epsilon=epsilon, freeze_cnn_layers=self.freeze_cnn_layers, **kw)
