diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index e056a902..ba74bb96 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -595,6 +595,10 @@ namespace hnswlib { std::ifstream input(location, std::ios::binary); + if (!input.is_open()) + throw std::runtime_error("Cannot open file"); + + // get file size: input.seekg(0,input.end); std::streampos total_filesize=input.tellg(); @@ -625,16 +629,15 @@ namespace hnswlib { fstdistfunc_ = s->get_dist_func(); dist_func_param_ = s->get_dist_func_param(); - /// Legacy, check that everything is ok - - bool old_index=false; - auto pos=input.tellg(); + + + /// Optional - check if index is ok: + input.seekg(cur_element_count * size_data_per_element_,input.cur); for (size_t i = 0; i < cur_element_count; i++) { if(input.tellg() < 0 || input.tellg()>=total_filesize){ - old_index = true; - break; + throw std::runtime_error("Index seems to be corrupted or unsupported"); } unsigned int linkListSize; @@ -644,23 +647,21 @@ namespace hnswlib { } } - // check if file is ok, if not this is either corrupted or old index + // throw exception if it is either a corrupted or an old index if(input.tellg()!=total_filesize) - old_index = true; + throw std::runtime_error("Index seems to be corrupted or unsupported"); - if (old_index) { - std::cerr << "Warning: loading of old indexes will be deprecated before 2019.\n" - << "Please resave the index in the new format.\n"; - } input.clear(); + + /// Optional check end + input.seekg(pos,input.beg); data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_); input.read(data_level0_memory_, cur_element_count * size_data_per_element_); - if(old_index) - input.seekg(((max_elements_-cur_element_count) * size_data_per_element_), input.cur); + size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); @@ -691,6 +692,14 @@ namespace hnswlib { input.read(linkLists_[i], linkListSize); } } + + has_deletions_=false; + + for (size_t i = 0; i < cur_element_count; i++) { + 
if(isMarkedDeleted(i)) + has_deletions_=true; + } + input.close(); return; diff --git a/python_bindings/tests/bindings_test_labels.py b/python_bindings/tests/bindings_test_labels.py index b351935b..f629ab29 100644 --- a/python_bindings/tests/bindings_test_labels.py +++ b/python_bindings/tests/bindings_test_labels.py @@ -3,10 +3,12 @@ class RandomSelfTestCase(unittest.TestCase): def testRandomSelf(self): + for idx in range(16): print("\n**** Index save-load test ****\n") import hnswlib import numpy as np - + + np.random.seed(idx) dim = 16 num_elements = 10000 @@ -95,8 +97,8 @@ def testRandomSelf(self): p.mark_deleted(l[0]) labels2, _ = p.knn_query(data2, k=1) items=p.get_items(labels2) - diff_with_gt_labels=np.max(np.abs(data2-items)) - self.assertAlmostEqual(diff_with_gt_labels, 0, delta = 1e-4) # console + diff_with_gt_labels=np.mean(np.abs(data2-items)) + self.assertAlmostEqual(diff_with_gt_labels, 0, delta = 1e-3) # console labels1_after, _ = p.knn_query(data1, k=1) @@ -106,6 +108,18 @@ def testRandomSelf(self): self.assertTrue(False) print("All the data in data1 are removed") + # checking saving/loading index with elements marked as deleted + p.save_index("with_deleted.bin") + p = hnswlib.Index(space='l2', dim=dim) + p.load_index("with_deleted.bin") + p.set_ef(100) + + labels1_after, _ = p.knn_query(data1, k=1) + for la in labels1_after: + for lb in labels1: + if la[0] == lb[0]: + self.assertTrue(False) + if __name__ == "__main__": diff --git a/python_bindings/tests/bindings_test_resize.py b/python_bindings/tests/bindings_test_resize.py index 5e798164..9411af64 100644 --- a/python_bindings/tests/bindings_test_resize.py +++ b/python_bindings/tests/bindings_test_resize.py @@ -3,7 +3,7 @@ class RandomSelfTestCase(unittest.TestCase): def testRandomSelf(self): - for idx in range(32): + for idx in range(16): print("\n**** Index resize test ****\n") import hnswlib import numpy as np