Source code for gpytorch.nearest_neighbors.k_means_index

#!/usr/bin/env python3

from typing import List

import faiss
import numpy as np
import torch

from ._index import BaseIndex


[docs]class KMeansIndex(BaseIndex): """ This index performs K-Means clustering on a given feature set, computes neighboring blocks, enables evaluating block membership for test points, and enables reordering of the blocks based on block centroids. :param data: Features to cluster via K-Means, typically an n x 2 tensor of spatial lat-long coordinates. :param n_blocks: Number of desired clusters. Note that this does not guarantee similarly-sized clusters. :param n_neighbors: Number of neighboring clusters per cluster. """ def __init__(self, data: torch.tensor, n_blocks: int, n_neighbors: int, distance_metric): self.n_blocks = n_blocks self.n_neighbors = n_neighbors self.distance_metric = distance_metric self.centroids = None # this call executes set_blocks and set_neighbors, then superclass computes all dependent quantities super(KMeansIndex, self).__init__(set_blocks_kwargs={"data": data}, set_neighbors_kwargs={}) def _get_cluster_membership(self, data: torch.tensor) -> List[torch.LongTensor]: """ Determines which K-Means cluster each point in the provided data belongs to. :param data: Tensor for which to evaluate cluster membership. If any of these points are outside the domain of the points used to train the K-Means clusters, you may get nonsensical results. :return: List of tensors, where the ith tensor contains the indices of the points in data that belong to the ith K-Means cluster. """ blocks = [] block_per_point = torch.cdist(data, self.centroids).argsort()[:, 0] for block in range(len(self.centroids)): these_members = (block_per_point == block).nonzero().squeeze() blocks.append(these_members) return blocks def set_blocks(self, data: torch.tensor) -> List[torch.LongTensor]: # create and train faiss k-means object kmeans = faiss.Kmeans(data.shape[1], self.n_blocks, niter=10) kmeans.train(np.array(data.float())) # k-means gives centroids directly, so save centroids self.centroids = torch.tensor(kmeans.centroids) # determine indices of data points that belong to each cluster block and return return self._get_cluster_membership(data) def set_neighbors(self) -> List[torch.LongTensor]: # if there are no neighbors, we want a list of empty tensors if self.n_neighbors == 0: return [torch.LongTensor([]) for _ in range(0, self.n_blocks)] else: # get distance matrix and find ordered distances sorter = self.distance_metric(self.centroids, self.centroids).argsort().long() return [sorter[i][sorter[i] < i][0 : self.n_neighbors] for i in range(0, len(sorter))] def set_test_blocks(self, new_data: torch.tensor) -> List[torch.LongTensor]: # determine indices of new data points that belong to each cluster block and return return self._get_cluster_membership(new_data) def reorder(self, ordering_strategy): # new order is defined as some reordering of the K-means block centroids new_order = ordering_strategy(self.centroids) # reorder the instance attributes that depend on the ordering self.centroids = self.centroids[new_order] # reorder superclass attributes and recompute neighbors under new ordering super()._reorder(new_order)