import numpy as np
import torch
import torch.nn as nn
from torch.nn import GRU

from centralized_verification.models.stateful_module import StatefulModule


class SimpleGRU(StatefulModule):
    """MLP encoder -> single-layer GRU -> MLP head over observation sequences.

    Each per-step observation is flattened to ``num_features`` values, encoded
    by a two-layer MLP into 128 features, passed through a batch-first GRU with
    a 128-unit hidden state, and projected to ``num_outputs`` values per step.
    """

    def __init__(self, obs_space, num_outputs: int, low, high):
        """Build the encoder, recurrent core and output head.

        Args:
            obs_space: Array-like whose element product is the flattened
                per-step input size (presumably the observation shape).
            num_outputs: Number of output values produced per time step.
            low: Not used by this class (presumably kept so the constructor
                matches a shared model signature).
            high: Not used by this class (see ``low``).
        """
        super().__init__()

        # Cast to a plain Python int: torch size arguments (nn.Linear,
        # Tensor.reshape) are given an int rather than a NumPy scalar.
        self.num_features = int(np.asarray(obs_space).prod())

        self.net1 = nn.Sequential(
            nn.Linear(self.num_features, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
        )

        self.gru = GRU(128, 128, batch_first=True)

        self.net2 = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, num_outputs),
        )

    def forward(self, input, hidden):
        """Run the network over a batch of observation sequences.

        Args:
            input: Tensor shaped ``(batch, seq_len, *obs_dims)`` whose trailing
                dimensions hold ``num_features`` elements in total.
            hidden: GRU hidden state shaped ``(1, batch, 128)`` (for example
                from ``get_initial_state``), or ``None`` for a zero state.

        Returns:
            ``(output, hidden)``: per-step outputs shaped
            ``(batch, seq_len, num_outputs)`` and the updated GRU hidden state.
        """
        # Only the leading two dims are fixed; everything after them is one
        # observation and gets flattened below.
        batch_size, sequence_len = input.shape[:2]
        # reshape (unlike view) also accepts non-contiguous inputs.
        input_flat = input.reshape(batch_size, sequence_len, self.num_features)
        output = self.net1(input_flat)
        output, hidden = self.gru(output, hidden)
        output = self.net2(output)
        return output, hidden

    def get_initial_state(self, batch_size: int) -> torch.Tensor:
        """Return an all-zero GRU hidden state for ``batch_size`` sequences.

        The tensor is shaped ``(num_layers, batch_size, hidden_size)``, which is
        ``(1, batch_size, 128)`` for this network. It is allocated with the
        device and dtype of the GRU weights so it can be passed straight to
        ``forward``.
        """
        # Derive size/device/dtype from the GRU itself so they cannot drift
        # from the layer definition or from where the module was moved.
        weights = self.gru.weight_hh_l0
        return torch.zeros(
            (self.gru.num_layers, batch_size, self.gru.hidden_size),
            device=weights.device,
            dtype=weights.dtype,
        )
