From 724127a63acd9a819742cafc0f7638dc1e07b3a7 Mon Sep 17 00:00:00 2001 From: Matthijs Douze Date: Thu, 3 Aug 2023 02:20:36 -0700 Subject: [PATCH] Access graph structure for NSG Summary: It is not entirely trivial to access the NSG graph structure from Python (although it is a fixed size N-by-K matrix of vector ids). This diff adds an inspect_tools function to do that. Differential Revision: D48026775 fbshipit-source-id: e4179f0df92299bbe12bcd92eed201c46b6d435d --- contrib/inspect_tools.py | 13 +++++++++++++ faiss/impl/NSG.h | 2 +- faiss/python/swigfaiss.swig | 13 +++++++++++++ tests/test_contrib.py | 10 ++++++++++ 4 files changed, 37 insertions(+), 1 deletion(-) diff --git a/contrib/inspect_tools.py b/contrib/inspect_tools.py index 87928f4bb9..cc22ff5368 100644 --- a/contrib/inspect_tools.py +++ b/contrib/inspect_tools.py @@ -96,3 +96,16 @@ def get_flat_data(index): """ copy and return the data matrix in an IndexFlat """ xb = faiss.vector_to_array(index.codes).view("float32") return xb.reshape(index.ntotal, index.d) + + +def get_NSG_neighbors(nsg): + """ get the neighbor list for the vectors stored in the NSG structure, as + a N-by-K matrix of indices """ + graph = nsg.get_final_graph() + neighbors = np.zeros((graph.N, graph.K), dtype='int32') + faiss.memcpy( + faiss.swig_ptr(neighbors), + graph.data, + neighbors.nbytes + ) + return neighbors diff --git a/faiss/impl/NSG.h b/faiss/impl/NSG.h index e115b317fb..641a42f8cf 100644 --- a/faiss/impl/NSG.h +++ b/faiss/impl/NSG.h @@ -54,7 +54,7 @@ namespace nsg { template <class node_t> struct Graph { - node_t* data; ///< the flattened adjacency matrix + node_t* data; ///< the flattened adjacency matrix, size N-by-K int K; ///< nb of neighbors per node int N; ///< total nb of nodes bool own_fields; ///< the underlying data owned by itself or not diff --git a/faiss/python/swigfaiss.swig b/faiss/python/swigfaiss.swig index 7ebc6624e5..852690622b 100644 --- a/faiss/python/swigfaiss.swig +++ b/faiss/python/swigfaiss.swig @@ -454,7 +454,20 @@
void gpu_sync_all_devices() %include %include + +%warnfilter(509) faiss::nsg::Graph< int >::at(int,int); + %include + +%template(NSG_Graph_int) faiss::nsg::Graph<int>; + +// not using %shared_ptr to avoid mem leaks +%extend faiss::NSG { + faiss::nsg::Graph<int>* get_final_graph() { + return $self->final_graph.get(); + } +} + %include #ifndef SWIGWIN diff --git a/tests/test_contrib.py b/tests/test_contrib.py index 1982241142..057b043573 100644 --- a/tests/test_contrib.py +++ b/tests/test_contrib.py @@ -219,6 +219,16 @@ def test_make_LT(self): Ynew = lt.apply(X) np.testing.assert_equal(Yref, Ynew) + def test_NSG_neighbors(self): + # FIXME number of elements to add should be >> 100 + ds = datasets.SyntheticDataset(32, 0, 200, 10) + index = faiss.index_factory(ds.d, "NSG") + index.add(ds.get_database()) + neighbors = inspect_tools.get_NSG_neighbors(index.nsg) + # neighbors should be either valid indexes or -1 + np.testing.assert_array_less(-2, neighbors) + np.testing.assert_array_less(neighbors, ds.nb) + class TestRangeEval(unittest.TestCase):