import numpy as np
import torch.nn as nn


class LargerMLP(nn.Module):
    """Fully connected network with four 1024-unit hidden layers and ReLU.

    The input is flattened to ``(batch, num_features)`` and passed through
    five ``nn.Linear`` layers; every layer except the last is followed by a
    ReLU. Linear weights are initialised with Xavier-normal, using the ReLU
    gain for the hidden layers and the linear gain for the output layer.
    Biases keep PyTorch's default initialisation.
    """

    def __init__(self, obs_space, num_outputs: int, low, high):
        """Build the network.

        Args:
            obs_space: Anything ``np.asarray`` can convert to an array of
                dimension sizes; their product is the flattened input width.
                NOTE(review): presumably the observation shape tuple --
                confirm against callers.
            num_outputs: Width of the final linear layer.
            low: Accepted for interface compatibility; not used by this class.
            high: Accepted for interface compatibility; not used by this class.
        """
        super().__init__()

        # Cast to a plain Python int: ``prod()`` returns a NumPy scalar, which
        # would otherwise end up stored as ``nn.Linear.in_features``.
        self.num_features = int(np.asarray(obs_space).prod())

        self.net = nn.Sequential(
            nn.Linear(self.num_features, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, num_outputs),
        )

        # The output layer is not followed by a ReLU, so it gets the
        # "linear" gain; every other linear layer gets the "relu" gain.
        output_layer = self.net[-1]
        for layer in self.net:
            if not isinstance(layer, nn.Linear):
                continue
            non_linearity = "linear" if layer is output_layer else "relu"
            nn.init.xavier_normal_(
                layer.weight, gain=nn.init.calculate_gain(non_linearity)
            )

    def forward(self, input):
        """Flatten ``input`` per sample and run it through the network.

        Args:
            input: Tensor whose first dimension is the batch and whose
                remaining dimensions multiply to ``self.num_features``.

        Returns:
            Tensor of shape ``(len(input), num_outputs)``.
        """
        # ``reshape`` (unlike ``view``) also accepts non-contiguous tensors,
        # e.g. the result of a transpose/permute; it copies only when needed.
        input_flat = input.reshape(len(input), self.num_features)
        return self.net(input_flat)
