import abc

from torch import nn


class StatefulModule(nn.Module, abc.ABC):
    """
    Abstract base class for ``nn.Module`` subclasses that carry a state value
    alongside their input.

    Subclasses are expected to accept two arguments in ``forward`` and to
    return two values in the same order:

    - ``input``: the data passed into the module
    - ``state``: the module's current state

    This class documents that convention but does not enforce it, because it
    does not declare ``forward`` itself. The only member a concrete subclass
    must implement is :meth:`get_initial_state`. Instantiating a subclass that
    omits it raises ``TypeError``.
    """

    @abc.abstractmethod
    def get_initial_state(self, batch_size: int):
        """
        Create the starting state for a batch.

        :param batch_size: How many samples the batch contains
        :return: The state the module begins from; its type is chosen by the
            concrete subclass
        """
        ...
