import torch as th
import numpy as np
from utils.cpp_utils import load_c_lib, c_ptr
import copy
import ctypes


class MSTConstructor:
    """Batched spanning-tree construction delegated to a native library.

    The library handle is obtained once per instance by calling
    ``load_c_lib`` on ``'./utils/mst.cpp'``.
    NOTE(review): the path is relative, so it presumably resolves against the
    current working directory -- confirm against ``load_c_lib``.
    """

    def __init__(self):
        self.tree_lib = load_c_lib('./utils/mst.cpp')

    def construct(self, w: th.Tensor, device) -> th.Tensor:
        """Run the native ``maximum_spanning_tree`` routine on a batch of matrices.

        Args:
            w: Tensor of shape ``(bs, n, n)``. It is detached and copied to
                host memory before use, so the caller's tensor is never
                handed to native code.
                NOTE(review): presumably pairwise edge weights -- confirm
                against the native implementation.
            device: Target device for the returned tensor (anything accepted
                by ``Tensor.to``).

        Returns:
            A ``float32`` tensor of shape ``(bs, n, n)`` on ``device`` holding
            whatever the native routine wrote into its output buffer
            (presumably one tree adjacency matrix per batch item).

        Raises:
            ValueError: If ``w`` is not 3-D or its last two dimensions differ.
        """
        w = w.detach()
        bs, n, m = w.shape
        if m != n:
            # The native call only receives `n`; a non-square matrix would be
            # misinterpreted or read out of bounds.
            raise ValueError(
                f"expected w of shape (bs, n, n), got {tuple(w.shape)}"
            )

        # Always a fresh, C-contiguous float64 copy: the native side gets a
        # bare pointer, so strided/transposed views must be materialised in
        # row-major order, and the copy keeps `w` safe from native writes.
        _w = np.array(w.cpu().numpy(), dtype=np.float64, order='C')
        _best_graphs = np.zeros((bs, n, n), dtype=np.float64)

        self.tree_lib.maximum_spanning_tree(c_ptr(_w), c_ptr(_best_graphs), bs, n)

        # `.float()` allocates a new float32 tensor, so the result does not
        # alias `_best_graphs`; casting before the transfer avoids shipping
        # float64 data to devices that may not support it.
        return th.from_numpy(_best_graphs).float().to(device)
