-
Notifications
You must be signed in to change notification settings - Fork 99
/
evaluate.py
86 lines (79 loc) · 3.2 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
import requests
from requests.adapters import HTTPAdapter, Retry
import argparse
import pandas
def search(query_group):
    """Rank the labelled products of one query and record the returned order.

    Builds one Vespa query whose ``recall`` parameter lists the id of every
    product in ``query_group``, posts it to ``args.endpoint`` and appends one
    record per returned hit to the module-level ``responses`` list.

    Reads the module-level names ``args`` (``ranking``, ``hits``,
    ``endpoint``), ``session`` and ``responses``; ``main()`` assigns them
    before this function is called.

    Args:
        query_group: DataFrame holding the rows of a single query. The
            columns ``product_id``, ``query`` and ``query_id`` are read.

    Raises:
        AssertionError: if the ``totalCount`` reported by Vespa differs from
            the number of product ids sent in ``recall``.
    """
    if len(query_group) == 0:
        # Nothing to rank and no row to read the query text from.
        return

    # The list of products we have labels for, sent as the 'recall' parameter
    # (presumably restricting matching to these ids; the totalCount check
    # below relies on that).
    ids = ["id:{}".format(product_id) for product_id in query_group['product_id']]
    recall = "+({})".format(" ".join(ids))

    # Query text and query id are taken from the last row of the group.
    last_row = query_group.iloc[-1]
    query = last_row['query']
    query_id = last_row['query_id']

    # NOTE(review): the query text is spliced into embed(...) unescaped, so a
    # query containing a double quote presumably breaks the expression --
    # confirm how Vespa expects quotes to be escaped here.
    query_request = {
        'yql': 'select id from product where userQuery() or ({targetHits:100}nearestNeighbor(title_embedding, q_title)) or ({targetHits:100}nearestNeighbor(description_embedding, q_description))',
        'query': query,
        'input.query(q_title)': 'embed(title, "%s")' % query,
        'input.query(q_description)': 'embed(description, "%s")' % query,
        'input.query(query_tokens)': 'embed(tokenizer, "%s")' % query,
        'ranking': args.ranking,
        'hits': args.hits,
        'timeout': '15s',
        'recall': recall,
        'ranking.softtimeout.enable': 'false',
    }
    response = session.post(args.endpoint, json=query_request, timeout=120)

    if not response.ok:
        # Error bodies are not guaranteed to be JSON; fall back to the raw
        # text so that reporting the failure cannot itself raise.
        try:
            detail = str(response.json())
        except ValueError:
            detail = response.text
        print("request failed " + detail)
        return

    root = response.json()['root']
    total_count = root['fields']['totalCount']
    if total_count != len(ids):
        # Make sure we rank all labelled products. Raised explicitly (rather
        # than via ``assert``) so the check is not stripped under ``-O``.
        raise AssertionError(
            "query_id {}: expected {} matched products, got totalCount {}".format(
                query_id, len(ids), total_count
            )
        )

    # 'children' may be missing when no hits are returned; positions are 1-based.
    for position, hit in enumerate(root.get('children', []), start=1):
        responses.append({
            "query_id": query_id,
            "iteration": "Q0",
            "product_id": hit['fields']['id'],
            "position": position,
            "score": hit['relevance'],
            "runid": args.ranking,
        })
def main():
    """Run the test queries of an example file against Vespa and write a run file.

    Parses the command line, configures a retrying HTTP session, keeps only
    the examples with ``split == "test"``, ``product_locale == "us"`` and
    ``small_version == 1``, calls ``search()`` once per ``query_id`` and
    writes the collected hits to ``<ranking>.run`` as space-separated rows
    without header or index.

    Publishes ``args``, ``session`` and ``responses`` as module-level globals
    because ``search()`` reads them.
    """
    global args, session, responses

    parser = argparse.ArgumentParser()
    parser.add_argument("--endpoint", type=str, required=True)
    parser.add_argument("--ranking", type=str, required=True)
    parser.add_argument("--example_file", type=str, required=True)
    parser.add_argument("--hits", type=int, default=400)
    parser.add_argument("--certificate", type=str)
    parser.add_argument("--key", type=str)
    args = parser.parse_args()

    session = requests.Session()
    retries = Retry(
        total=10,
        connect=10,
        backoff_factor=0.3,
        status_forcelist=[500, 503, 504, 429],
        # By default urllib3 applies status_forcelist to idempotent methods
        # only; the queries are sent with POST, so allow retries on any verb.
        allowed_methods=None,
        # Hand the last response back to search() (which reports it) instead
        # of raising once the retries are exhausted.
        raise_on_status=False,
    )
    session.mount('https://', HTTPAdapter(max_retries=retries))
    session.mount('http://', HTTPAdapter(max_retries=retries))
    if args.certificate and args.key:
        # Client certificate and key for mutual TLS.
        session.cert = (args.certificate, args.key)

    responses = []

    df_examples = pandas.read_parquet(args.example_file)
    selected = (
        (df_examples['split'] == "test")
        & (df_examples['product_locale'] == "us")
        & (df_examples['small_version'] == 1)
    )
    df_examples = df_examples[selected]

    # Iterate the groups explicitly instead of groupby(...).apply(search):
    # search() is called only for its side effect, and newer pandas versions
    # exclude the grouping column from the frames handed to apply(), which
    # would remove the 'query_id' column search() reads.
    for _, query_group in df_examples.groupby("query_id"):
        search(query_group)

    df_result = pandas.DataFrame.from_records(responses)
    df_result.to_csv(args.ranking + ".run", index=False, header=False, sep=' ')


if __name__ == "__main__":
    main()