# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/02_entropy_reg.ipynb.

# %% auto 0
__all__ = ['sx', 'sy', 'alpha', 'dist', 'd2', 'p', 'l', 'd3_l', 'd3', 'labels', 'a', 'b', 'c', 'stats', 'add_label']

# %% ../nbs/02_entropy_reg.ipynb 8
'''
Ways to combine two categorical distributions
'''
# Shape of the parameter tensor handed to OneHotDist: (sx, sy).
# NOTE(review): presumably sx independent categorical variables with sy
# classes each -- confirm against tools.OneHotDist.
sx = 64
sy = 2
# Mixing weight given to `dist` when the two sets of logits are blended
# further down; `d2` receives the remaining 1 - alpha.
alpha = 0.5

# First distribution: built from a constant all-ones tensor. The tensor is
# OneHotDist's first positional argument (whether that means logits or probs
# is defined in tools.OneHotDist, which is not visible here).
dist = torch.distributions.Independent(
    tools.OneHotDist(torch.ones(sx, sy, device='cuda')),
    reinterpreted_batch_ndims=1,
)
# Alternative second distribution, identical in construction to `dist`:
# d2 = torch.distributions.Independent(tools.OneHotDist(torch.ones_like(dist.base_dist.probs, device='cuda')), 1)
# Second distribution: same shape as `dist`'s probs, but filled by
# torch.rand_like, so it differs from run to run unless the RNG is seeded.
d2 = torch.distributions.Independent(
    tools.OneHotDist(torch.rand_like(dist.base_dist.probs, device='cuda')),
    reinterpreted_batch_ndims=1,
)

from fastcore.all import *
# Host-side NumPy copies of the first distribution's probabilities and logits.
p = dist.base_dist.probs.detach().cpu().numpy()
l = dist.base_dist.logits.detach().cpu().numpy()
# Notebook display tuple carried over by the export. In a script the values
# are discarded, but every element is still evaluated, including the
# test_close call (presumably fastcore's closeness assertion -- confirm),
# which checks that the logits agree with log(probs).
(
    l,
    np.log(p),
    p,
    np.exp(l),
    test_close(l, np.log(p)),
)

# Blend the two distributions as an alpha-weighted sum of their logits.
d3_l = alpha * dist.base_dist.logits + (1 - alpha) * d2.base_dist.logits
# d3_p = dist.base_dist.probs * d2.base_dist.probs; test_close(d3_l, torch.log(d3_p))

# Wrap the blended tensor the same way `dist` and `d2` were built.
d3 = torch.distributions.Independent(
    tools.OneHotDist(d3_l),
    reinterpreted_batch_ndims=1,
)

# Plot the three distributions


import matplotlib.pyplot as plt, matplotlib.patches as mpatches
# Accumulates (legend handle, legend text) pairs, one per plotted violin set.
labels = []
def add_label(violin, label):
    """Register a legend entry for a violin plot.

    Reads the face colour of the first body in ``violin["bodies"]``, builds
    a ``mpatches.Patch`` of that colour and appends ``(patch, label)`` to the
    module-level ``labels`` list. Returns ``None``.

    violin: mapping with a ``"bodies"`` sequence, as returned by
        ``plt.violinplot``.
    label: text to show for this entry in the legend.
    """
    first_body = violin["bodies"][0]
    face_color = first_body.get_facecolor().flatten()
    swatch = mpatches.Patch(color=face_color)
    labels.append((swatch, label))
# Pull each distribution's probabilities to host memory as NumPy arrays once;
# the same arrays feed both the violin plots here and the summary statistics
# printed at the end of the script.
a = dist.base_dist.probs.detach().cpu().numpy()
b = d2.base_dist.probs.detach().cpu().numpy()
c = d3.base_dist.probs.detach().cpu().numpy()

# One violinplot call per distribution, all drawn on the current axes.
# add_label records the colour of each so the legend can tell them apart.
add_label(plt.violinplot(a, showmeans=True), 'dist')
add_label(plt.violinplot(b, showmeans=True), 'd2')
add_label(plt.violinplot(c, showmeans=True), 'd3')

# `labels` holds (handle, text) pairs; unzip them into the two positional
# sequences plt.legend expects. loc=2 is matplotlib's code for 'upper left'.
plt.legend(*zip(*labels), loc=2)

def stats(x):
    """Return ``(min, mean, max, std)`` of ``x`` as a tuple.

    ``x`` may be any object exposing ``min()``, ``mean()``, ``max()`` and
    ``std()`` methods (for example a NumPy array). Each method is called
    with no arguments, so the object's own defaults apply.
    """
    return (x.min(), x.mean(), x.max(), x.std())

# Print one (min, mean, max, std) tuple per line, in the order dist, d2, d3.
for _probs in (a, b, c):
    print(stats(_probs))
