import torch
import torch.nn as nn


class TabularGradient(nn.Module):
    """A learnable lookup table addressed by integer observation coordinates.

    The table has one axis per observation dimension plus a trailing axis of
    size ``num_outputs``. ``forward`` gathers the ``num_outputs`` entries
    addressed by each observation, so gradients flow only into the looked-up
    table cells.
    """

    def __init__(self, obs_space, num_outputs, low, high):
        """Create a zero-initialised lookup table.

        Args:
            obs_space: Accepted for interface compatibility; not used here.
            num_outputs: Size of the trailing (output) axis of the table.
            low: Iterable of per-dimension lower bounds.
            high: Iterable of per-dimension upper bounds. Axis ``i`` of the
                table gets ``high[i] - low[i]`` entries (``zip`` stops at the
                shorter of ``low`` / ``high``).
        """
        super().__init__()
        # int() lets bounds arrive as numpy/torch scalars (e.g. float-typed
        # bounds), which torch.zeros would otherwise reject as sizes.
        dim = (*(int(h) - int(l) for h, l in zip(high, low)), num_outputs)
        self.table = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, input):
        """Look up table rows for a batch of observations.

        Args:
            input: Tensor whose last axis holds one coordinate per table
                dimension, e.g. shape ``(batch, d)`` or a single ``(d,)``
                observation. Values are truncated to integers via ``long()``.

        Returns:
            Tensor of shape ``input.shape[:-1] + (num_outputs,)``.
        """
        # NOTE(review): coordinates are used as raw indices, not offset by
        # ``low``, and ``high`` is treated as exclusive. This is only consistent
        # with the table size when low == 0 and inputs lie in [0, high). Confirm
        # against callers before relying on nonzero ``low``.
        # Indexing with tensors (instead of a numpy round-trip) works for CUDA
        # inputs and for inputs that require grad.
        idx = input.to(device=self.table.device, dtype=torch.long)
        return self.table[idx.unbind(-1)]
