-
Notifications
You must be signed in to change notification settings - Fork 46
/
most_popular.py
40 lines (29 loc) · 1.05 KB
/
most_popular.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# -*- coding: utf-8 -*-
import fire
from tqdm import tqdm
from arena_util import load_json
from arena_util import write_json
from arena_util import remove_seen
from arena_util import most_popular
class MostPopular:
    """Baseline that answers every question from global popularity lists.

    The same popularity-ranked song and tag candidates (as returned by
    ``most_popular``) are offered to every question; each question's own
    ``"songs"`` / ``"tags"`` are passed to ``remove_seen`` together with
    those candidates before the result is truncated.
    """

    def _generate_answers(self, train, questions):
        """Build one answer dict per question.

        Args:
            train: Training data, forwarded unchanged to ``most_popular``.
            questions: Iterable of dicts, each providing the keys ``"id"``,
                ``"songs"`` and ``"tags"``.

        Returns:
            A list with one dict per question, in iteration order, holding
            ``"id"``, ``"songs"`` (at most 100 entries) and ``"tags"``
            (at most 10 entries).
        """
        # Only the second element of each most_popular() result is needed.
        # NOTE(review): the candidate pools (200 songs / 100 tags) are larger
        # than the answer sizes (100 / 10), presumably so that enough
        # candidates survive remove_seen() -- confirm against arena_util.
        _, song_mp = most_popular(train, "songs", 200)
        _, tag_mp = most_popular(train, "tags", 100)

        # tqdm only wraps the iterable to display progress.
        return [
            {
                "id": q["id"],
                "songs": remove_seen(q["songs"], song_mp)[:100],
                "tags": remove_seen(q["tags"], tag_mp)[:10],
            }
            for q in tqdm(questions)
        ]

    def run(self, train_fname, question_fname,
            result_fname="results/results.json"):
        """Load the inputs, generate the answers and write them out as JSON.

        Args:
            train_fname: Path handed to ``load_json`` for the training data.
            question_fname: Path handed to ``load_json`` for the questions.
            result_fname: Path handed to ``write_json`` for the answers.
                Defaults to ``"results/results.json"``, the location that
                was previously hard-coded.
        """
        print("Loading train file...")
        train = load_json(train_fname)

        print("Loading question file...")
        questions = load_json(question_fname)

        print("Writing answers...")
        answers = self._generate_answers(train, questions)
        write_json(answers, result_fname)
if __name__ == "__main__":
    # Script entry point: hand the class to Fire, which builds the
    # command-line interface from it (e.g. the ``run`` method and its
    # file-name arguments).
    fire.Fire(MostPopular)