import numpy as np
import torch
import random

class ReplayBuffer(object):
	"""Fixed-capacity circular buffer of (state, action, reward, next_state, done) transitions.

	Storage is preallocated as NumPy arrays with ``max_size`` rows. Once the
	buffer is full, each new transition overwrites the oldest row. Batches are
	returned as float32 torch tensors on ``self.device`` (CUDA if available,
	otherwise CPU). All randomness comes from NumPy's global RNG, so seeding
	``np.random`` makes both ``sample`` and ``make_mini_batch`` reproducible.
	"""

	def __init__(self, state_dim, action_dim, max_size=int(1e6)):
		"""Allocate zeroed storage for ``max_size`` transitions.

		Args:
			state_dim: number of columns per stored state / next_state row.
			action_dim: number of columns per stored action row.
			max_size: capacity; the oldest rows are overwritten beyond this.
		"""
		self.max_size = max_size
		self.ptr = 0   # row the next add() writes to (wraps around)
		self.size = 0  # number of valid rows, capped at max_size

		self.state = np.zeros((max_size, state_dim))
		self.action = np.zeros((max_size, action_dim))
		self.next_state = np.zeros((max_size, state_dim))
		self.reward = np.zeros((max_size, 1))
		self.done = np.zeros((max_size, 1))

		self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

	def add(self, state, action, reward, next_state, done):
		"""Store one transition at ``self.ptr``, then advance the write pointer circularly.

		Each argument is written into a single row of the matching array, so it
		must be assignable to a row of that shape (presumably array-likes of
		length state_dim / action_dim and scalars for reward / done — TODO
		confirm against callers).
		"""
		self.state[self.ptr] = state
		self.action[self.ptr] = action
		self.reward[self.ptr] = reward
		self.next_state[self.ptr] = next_state
		self.done[self.ptr] = done

		self.ptr = (self.ptr + 1) % self.max_size
		self.size = min(self.size + 1, self.max_size)

	def _gather(self, ind):
		"""Return the stored rows ``ind`` of all five arrays as float32 tensors on ``self.device``."""
		return [
			torch.as_tensor(arr[ind], dtype=torch.float32, device=self.device)
			for arr in (self.state, self.action, self.reward, self.next_state, self.done)
		]

	def sample(self, batch_size):
		"""Draw ``batch_size`` transitions uniformly *with* replacement.

		Returns:
			Tuple ``(state, action, reward, next_state, done)`` of float32
			tensors, each with ``batch_size`` rows.

		Raises:
			ValueError: if the buffer is empty.
		"""
		if self.size == 0:
			raise ValueError("cannot sample from an empty ReplayBuffer")
		ind = np.random.randint(0, self.size, size=batch_size)
		return tuple(self._gather(ind))

	def make_mini_batch(self, sample_size, batch_size, drop_last=True):
		"""Yield mini-batches covering a random subset of stored transitions.

		``min(self.size, sample_size)`` distinct rows are chosen uniformly
		*without* replacement and split into consecutive chunks of
		``batch_size``. Tensors are built lazily, one batch per iteration.

		Args:
			sample_size: maximum number of distinct transitions to draw.
			batch_size: rows per yielded mini-batch; must be positive.
			drop_last: if True (the original behavior), a trailing chunk
				smaller than ``batch_size`` is discarded; if False it is
				yielded as a final, shorter batch.

		Yields:
			List ``[state, action, reward, next_state, done]`` of float32 tensors.

		Raises:
			ValueError: if ``batch_size`` is not positive or ``sample_size``
				is negative (raised when iteration starts).
		"""
		if batch_size <= 0:
			raise ValueError("batch_size must be positive")
		if sample_size < 0:
			raise ValueError("sample_size must be non-negative")

		n_selected = min(self.size, sample_size)
		# A prefix of a uniform permutation is a uniform subset without replacement.
		# It uses NumPy's RNG (same source as sample()) instead of building a
		# Python list of every index for random.sample.
		selected = np.random.permutation(self.size)[:n_selected]

		n_full, remainder = divmod(n_selected, batch_size)
		n_batches = n_full + 1 if (not drop_last and remainder) else n_full
		for i in range(n_batches):
			yield self._gather(selected[batch_size * i : batch_size * (i + 1)])


	