# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/01_md_gauss.ipynb.

# %% auto 0
__all__ = ['update_mean_covariance', 'update_multivariate_gaussian_prior', 'Drone']

# %% ../nbs/01_md_gauss.ipynb 2
import matplotlib as mpl,numpy as np,matplotlib.pyplot as plt
import scipy.stats as stats
import collections
import fastcore.all as fc

# %% ../nbs/01_md_gauss.ipynb 3
def update_mean_covariance(old_mean, old_covariance, new_data, n):
    """
    openai generated function
    Update the running mean and sample covariance matrix with a new data point.

    Uses Welford's online update. The covariance is the unbiased sample
    covariance (divisor ``n - 1``), i.e. the same quantity ``np.cov`` returns
    with its default ``ddof=1`` over all ``n`` points seen so far:

        delta = x - old_mean
        new_mean = old_mean + delta / n
        new_cov  = (n - 2) / (n - 1) * old_cov + outer(delta, delta) / n

    :param old_mean: The mean vector of the previous ``n - 1`` data points.
    :param old_covariance: The sample covariance matrix of the previous ``n - 1`` data points.
    :param new_data: The new data point (as an array or array-like).
    :param n: The number of data points seen so far, including the new one.
    :return: Updated mean and covariance matrix.
    :raises ValueError: If ``n < 1``.
    """
    if n < 1:
        raise ValueError(f"n must be >= 1 (number of points including the new one), got {n}")

    old_mean = np.asarray(old_mean)
    new_data = np.asarray(new_data)

    if n == 1:
        # The first point is its own mean and has no spread yet. np.array copies,
        # so the returned mean never aliases the caller's input array.
        return np.array(new_data, dtype=float), np.zeros_like(old_covariance)

    # Deviation from the *previous* mean; using it on both sides of the outer
    # product gives the correct increment for the sample covariance.
    delta = new_data - old_mean
    new_mean = old_mean + delta / n
    new_covariance = ((n - 2) / (n - 1)) * old_covariance + np.outer(delta, delta) / n

    return new_mean, new_covariance

import numpy as np

def update_multivariate_gaussian_prior(prior_mean, prior_covariance, data_mean, data_covariance, sample_size):
    """
    openai generated function
    Update a multivariate Gaussian prior given new data.

    Combines prior and data in precision (inverse-covariance) form:

        posterior_precision = inv(prior_covariance) + sample_size * inv(data_covariance)
        posterior_covariance = inv(posterior_precision)
        posterior_mean = posterior_covariance @ (inv(prior_covariance) @ prior_mean
                                                 + sample_size * inv(data_covariance) @ data_mean)

    This is the conjugate update for the mean of a Gaussian whose observation
    covariance is taken to be ``data_covariance``.

    :param prior_mean: Mean vector of the prior Gaussian distribution.
    :param prior_covariance: Covariance matrix of the prior Gaussian distribution (must be invertible).
    :param data_mean: Mean vector of the observed data.
    :param data_covariance: Covariance matrix of the observed data (must be invertible).
    :param sample_size: The number of data points observed.
    :return: Updated mean vector and covariance matrix.
    :raises numpy.linalg.LinAlgError: If a covariance matrix or the combined precision is singular.

    Example::

        prior_mean = np.array([0.0, 0.0])
        prior_covariance = np.array([[1.0, 0.0], [0.0, 1.0]])
        data_mean = np.array([0.5, 0.5])
        data_covariance = np.array([[0.25, 0.1], [0.1, 0.25]])
        updated_mean, updated_covariance = update_multivariate_gaussian_prior(
            prior_mean, prior_covariance, data_mean, data_covariance, 10
        )
    """
    # Precisions (inverse covariances); the data term scales with the sample count.
    prior_precision = np.linalg.inv(prior_covariance)
    data_precision = np.linalg.inv(data_covariance) * sample_size

    # Invert the combined precision once and reuse it for both the covariance and the mean.
    updated_covariance = np.linalg.inv(prior_precision + data_precision)
    updated_mean = updated_covariance.dot(prior_precision.dot(prior_mean) + data_precision.dot(data_mean))

    # inv() of a symmetric matrix is symmetric only up to rounding; symmetrize so
    # consumers that require an exactly symmetric covariance accept the result.
    updated_covariance = (updated_covariance + updated_covariance.T) / 2

    return updated_mean, updated_covariance

class Drone:
    """Keep a running Gaussian estimate (mean vector and covariance matrix) of ``dim``-dimensional data.

    The estimate starts at a zero mean and identity covariance. ``history`` holds
    the most recent 1000 mean vectors, beginning with the initial zero mean.

    NOTE(review): ``update`` treats ``self.cov`` as the sample covariance of the
    points seen so far, while ``update_with_batch`` treats it as the prior
    covariance of a Bayesian mean update. ``self.n`` only counts points passed to
    ``update``. Mixing the two update paths is therefore not coherent; confirm
    the intended semantics.
    """

    def __init__(self, dim=2) -> None:
        fc.store_attr()  # fastcore helper; presumably stores ``dim`` as ``self.dim``
        self.mean = np.zeros(dim); self.cov = np.eye(dim); self.n = 0
        self.history = collections.deque(maxlen=1000); self.history.append(self.mean)

    def update(self, new_data):
        """Fold a single observation into the running sample mean and covariance.

        :param new_data: One data point of length ``dim``.
        """
        self.n += 1
        self.mean, self.cov = update_mean_covariance(self.mean, self.cov, new_data, self.n)
        self.history.append(self.mean)

    def update_with_batch(self, new_data):
        """Bayesian update of the current estimate with a batch of observations.

        The batch's sample mean and sample covariance are combined with the
        current ``mean``/``cov`` as the prior.

        :param new_data: Array-like of shape ``(num_samples, dim)`` with ``num_samples >= 2``.
        :raises ValueError: If the batch is not 2-D or has fewer than two rows
            (a single row has no defined sample covariance and would yield NaNs).
        """
        new_data = np.asarray(new_data)
        if new_data.ndim != 2 or new_data.shape[0] < 2:
            raise ValueError(
                f"update_with_batch expects a 2-D array with at least two rows, got shape {new_data.shape}"
            )
        num_samples = new_data.shape[0]
        batch_mean = new_data.mean(axis=0)
        # rowvar=False: rows are observations. atleast_2d keeps dim == 1 as a (1, 1)
        # matrix, because np.cov returns a 0-d array there, which np.linalg.inv rejects.
        batch_cov = np.atleast_2d(np.cov(new_data, rowvar=False))
        self.mean, self.cov = update_multivariate_gaussian_prior(self.mean, self.cov, batch_mean, batch_cov, num_samples)
        self.history.append(self.mean)

