From 9016ecb1d9e58f3a4b0a8cb9b3917c5672866d1c Mon Sep 17 00:00:00 2001 From: LukasWestholt Date: Fri, 4 Oct 2024 19:51:43 +0200 Subject: [PATCH] add recall metric qdrant/vector-db-benchmark#138 --- engine/base_client/search.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/engine/base_client/search.py b/engine/base_client/search.py index 3626191e..cb110633 100644 --- a/engine/base_client/search.py +++ b/engine/base_client/search.py @@ -47,11 +47,18 @@ def _search_one(cls, query: Query, top: Optional[int] = None): end = time.perf_counter() precision = 1.0 + recall = 1.0 if query.expected_result: ids = set(x[0] for x in search_res) - precision = len(ids.intersection(query.expected_result[:top])) / top - return precision, end - start + # precision equals True Positives / (True Positives + False Positives) + # recall equals True Positives / (True Positives + False Negatives) + # See https://en.wikipedia.org/wiki/Precision_and_recall + true_positives = len(ids.intersection(query.expected_result[:top])) + precision = true_positives / top # 1 means that the results consist of expected elements. + recall = true_positives / len(query.expected_result) # 1 means that all expected elements are in the results. + + return precision, recall, end - start def search_all( self, @@ -71,7 +78,7 @@ def search_all( if parallel == 1: start = time.perf_counter() - precisions, latencies = list( + precisions, recalls, latencies = list( zip(*[search_one(query) for query in tqdm.tqdm(queries)]) ) else: @@ -90,7 +97,7 @@ def search_all( if parallel > 10: time.sleep(15) # Wait for all processes to start start = time.perf_counter() - precisions, latencies = list( + precisions, recalls, latencies = list( zip(*pool.imap_unordered(search_one, iterable=tqdm.tqdm(queries))) ) @@ -109,6 +116,7 @@ def search_all( "p95_time": np.percentile(latencies, 95), "p99_time": np.percentile(latencies, 99), "precisions": precisions, + "recalls": recalls, "latencies": latencies, }