from torch import nn
import torch as th
from IPython import embed as ipshell

class AutoAdapt(nn.Module):
  """Regularizer loss whose multiplicative scale adapts toward a target.

  Each update averages ``reg`` over its leading dims and compares the result
  against a dead-band ``[target / (1 + thres), target * (1 + thres)]``:
  above the band the scale is multiplied by ``1 + vel``, below it the scale is
  divided by ``1 + vel``, inside it the scale is unchanged; the result is
  clipped to ``[min, max]``. With ``inverse=True`` the two directions are
  swapped. The returned loss is ``scale * reg`` (``scale * -reg`` when
  ``inverse``) with the scale detached from autograd.

  Args:
    shape: Shape of the adapted scale; ``()`` for a single scalar. The
      trailing ``len(shape)`` dims of ``reg`` are kept when averaging, all
      leading dims are averaged over.
    impl: ``'mult'`` (multiplicative adaptation, starting from 1) or
      ``'fixed'`` (constant ``scale``, never updated).
    scale: Constant scale used by ``impl='fixed'``; ignored by ``'mult'``.
    target: Desired average value of ``reg``.
    min: Lower clipping bound for the adapted scale.
    max: Upper clipping bound for the adapted scale.
    vel: Relative step size of one adaptation step.
    thres: Relative half-width of the dead-band around ``target``.
    inverse: Negate the loss and flip the adaptation direction.
    device: Device for a non-scalar scale tensor. A scalar scale is stored as
      a Python float (until the first update) and wrapped on demand.

  Raises:
    NotImplementedError: If ``impl`` is not ``'mult'`` or ``'fixed'``.

  NOTE(review): ``self._scale`` is a plain attribute, not a registered buffer,
  so it is not part of ``state_dict()`` and is not moved by ``.to()``.
  """

  def __init__(
      self, shape, impl, scale, target, min, max,
      vel=0.1, thres=0.1, inverse=False, device='cuda'):
    super().__init__()
    if impl not in ('mult', 'fixed'):
      raise NotImplementedError(impl)
    self._shape = shape
    self._impl = impl
    self._target = target
    self._min = min
    self._max = max
    self._vel = vel
    self._inverse = inverse
    self._thres = thres
    # 'mult' always starts from 1; 'fixed' holds the given constant forever.
    init = float(scale) if impl == 'fixed' else 1.0
    if len(shape) == 0:
      # Scalar scale stays a Python float; scale() wraps it in a tensor.
      self._scale = init
    else:
      self._scale = th.full(
          tuple(shape), init, dtype=th.float32).to(device)

  def __call__(self, reg, minent=None, maxent=None, update=True):
    """Optionally adapt the scale, then return the scaled loss and metrics.

    Args:
      reg: Regularizer values; leading dims are averaged over when adapting.
      minent: Optional lower bound; used only together with ``maxent``.
      maxent: Optional upper bound. When both bounds are given, ``reg`` is
        first rescaled to ``(reg - lo) / (hi - lo)`` where ``lo``/``hi`` are
        ``minent``/``maxent`` divided by ``reg.shape[-1]``.
      update: Whether to adapt the scale from this ``reg`` before use.

    Returns:
      ``(loss, metrics)``: ``loss`` is ``scale * reg`` (negated when
      ``inverse``); ``metrics`` maps ``'mean'``, ``'std'``, ``'scale_mean'``,
      ``'scale_std'`` to scalar tensors. Stds are population stds, so a
      single-element input (e.g. a scalar scale) reports 0 instead of NaN.
    """
    if minent is not None and maxent is not None:
      lo = minent / reg.shape[-1]
      hi = maxent / reg.shape[-1]
      reg = (reg - lo) / (hi - lo)
    if update:
      self.update(reg)
    scale = self.scale()
    loss = scale * (-reg if self._inverse else reg)
    metrics = {
        'mean': reg.mean(), 'std': reg.std(unbiased=False),
        'scale_mean': scale.mean(), 'scale_std': scale.std(unbiased=False)}
    return loss, metrics

  def scale(self):
    """Return the current scale as a float32 tensor detached from autograd."""
    if isinstance(self._scale, th.Tensor):
      return self._scale.detach()
    # Scalar scale that update() has not replaced yet: still a Python float.
    return th.tensor(self._scale, dtype=th.float32)

  def update(self, reg):
    """Take one adaptation step of the scale based on the average of ``reg``.

    A no-op for ``impl='fixed'``.
    """
    if self._impl == 'fixed':
      return
    if self._impl != 'mult':
      raise NotImplementedError(self._impl)
    with th.no_grad():
      # Average over the leading dims, keeping the trailing len(shape) dims.
      # Skip the reduction when there is nothing to average: torch's mean
      # with an empty dim list reduces over *all* dims instead of none.
      dims = list(range(len(reg.shape) - len(self._shape)))
      avg = reg.mean(dims) if dims else reg
      below = avg < (1 / (1 + self._thres)) * self._target
      above = avg > (1 + self._thres) * self._target
      if self._inverse:
        below, above = above, below
      inside = ~below & ~above
      # Exactly one of the three masks is 1 per element, selecting
      # grow / shrink / keep for that element of the scale.
      adjusted = (
          above.float() * self._scale * (1 + self._vel) +
          below.float() * self._scale / (1 + self._vel) +
          inside.float() * self._scale)
      self._scale = th.clip(adjusted, self._min, self._max)
    
class Ratio:
  """Turn a step counter into a number of repeats at a fixed ratio.

  The first call anchors the counter and returns 1. Later calls return
  ``int((step - anchor) * ratio)`` (truncated toward zero) and advance the
  anchor only by the steps those repeats account for, so fractional leftovers
  carry over to the next call. A ratio of 0 always yields 0.
  Presumably callers pass non-decreasing steps; a decreasing step can yield a
  negative count.
  """

  def __init__(self, ratio):
    assert ratio >= 0, ratio
    self._ratio = ratio
    self._prev = None

  def __call__(self, step):
    current = int(step)
    ratio = self._ratio
    if ratio == 0:
      return 0
    anchor = self._prev
    if anchor is None:
      self._prev = current
      return 1
    count = int((current - anchor) * ratio)
    self._prev = anchor + count / ratio
    return count