search.py
import time, re, heapq, math, random
from nltk.corpus import stopwords
from nltk.stem import SnowballStemmer
from collections import Counter, defaultdict


class QueryEvaluator:
    def __init__(self):
        self.stemmer = SnowballStemmer('english')
        self.stop_words = set(stopwords.words('english'))
        # total number of documents in the index, used for idf
        self.N_doc = 731739
    def OneWordQuery(self, query, term, fields, k):
        """Rank documents for a single-term query from its posting list."""
        posting_list = query.split(';')
        heap = []
        for posting in posting_list:
            nos = [int(u) for u in re.findall(r'\d+', posting)]
            if not nos:
                break
            if not fields:
                # unfielded query: boost postings that hit the title field
                title_factor = 10 if 't' in posting else 1
                entry = (-sum(nos[1:]) * title_factor, nos[0])
            else:
                # fielded query: weight matching fields up, the rest down
                field_p = re.findall(r'[a-z]', posting)
                sum_ = 0
                for i in range(len(nos) - 1):
                    if field_p[i] in fields[term]:
                        sum_ += 100 * nos[i + 1]
                    else:
                        sum_ += nos[i + 1] / 10
                entry = (-sum_, nos[0])
            heap.append(entry)
        # nsmallest on the negated scores yields the k highest-scoring documents
        top = heapq.nsmallest(k, heap)
        return [doc[1] for doc in top]
    def extractPosting(self, token):
        """Look up a token's posting line in the sharded index file."""
        fname = token[:3]  # index files are sharded by the token's first characters
        fil = 'final-index/' + fname
        with open(fil, 'r') as fp:
            for line in fp:
                parts = line.split(';', 1)
                if parts[0] == token:
                    return [parts[0], parts[1].strip()]
        return ''
    def MultiWordQuery(self, query_vector, query_tokens, posting_list, fields, k):
        """Rank documents for a multi-term query with tf-idf scoring."""
        posting_list = [u.split(';') for u in posting_list]
        # collect the set of documents that contain every query term
        docset = None
        idf = []
        for lis in posting_list:
            idf.append(self.N_doc / len(lis))
            docIDs = set()
            for posting in lis:
                if posting:
                    docIDs.add(int(re.split(r'([a-z]+)', posting)[0]))
            if docset is None:
                docset = docIDs
            else:
                docset = set.intersection(docset, docIDs)
        # per-term tf-idf scores, keyed by docID
        tfidf_scores = [{} for _ in range(len(query_tokens))]
        for i in range(len(posting_list)):
            for posting in posting_list[i]:
                l = re.split(r'([a-z]+)', posting)
                if l[0] != '' and int(l[0]) in docset:
                    # term frequency summed over all fields, then log-dampened
                    tf = sum([int(l[j]) for j in range(2, len(l), 2)])
                    tf = (1 + math.log10(tf)) ** 2
                    if not fields:
                        # heuristic field weighting: boost title ('t') and
                        # infobox ('i') hits
                        if 't' in l:
                            tf = tf ** 3
                        if 'i' in l:
                            tf *= 2
                        if 'l' in l:
                            tf = tf ** 0.5
                        else:
                            tf /= 3
                    else:
                        # fielded query: boost matching fields, dampen the rest
                        for j in range(1, len(l), 2):
                            if l[j] in fields[query_tokens[i]]:
                                tf = tf ** 3
                            else:
                                tf = tf ** 0.8
                    tfidf_scores[i][int(l[0])] = tf * idf[i]
        heap = []
        # calculate similarity scores as a dot product with the query vector
        for docID in docset:
            vector = [tfidf_scores[i][docID] for i in range(len(query_tokens))]
            docscore = sum([vector[i] * query_vector[i] for i in range(len(vector))])
            heap.append((-docscore, docID))
        # nsmallest on the negated scores yields the k highest-scoring documents
        top = heapq.nsmallest(k, heap)
        return [doc[1] for doc in top]
    def evaluateQuery(self, query, k):
        """Parse a (possibly fielded) query and return the top-k docIDs."""
        query = query.lower()
        query_fields = re.findall(r'[tbcirl]:', query)
        query_token_fields = defaultdict(list)
        query_vector = []
        if query_fields:
            # fielded query: walk the field markers from right to left and
            # remember which fields each token was requested in
            s = query
            for field in reversed(query_fields):
                splitted = s.split(field)
                tokens = self.processText(splitted[-1])
                for token in tokens:
                    query_token_fields[token].append(field[0])
                s = splitted[0]
            query_tokens = list(query_token_fields.keys())
            if len(query_tokens) > 1:
                query_vector = [len(query_token_fields[t]) for t in query_tokens]
        else:
            query_tokens = self.processText(query)
            q_count = Counter(query_tokens)
            query_token_fields = None
            query_tokens = list(set(query_tokens))
            if len(query_tokens) > 1:
                query_vector = [q_count[token] for token in query_tokens]
        if not query_tokens:
            return []
        query_pl = []  # posting lists for each query term
        for query_term in query_tokens:
            posting_list = self.extractPosting(query_term)
            if len(query_tokens) == 1:
                if posting_list != "":
                    return self.OneWordQuery(posting_list[1], posting_list[0], query_token_fields, k)
                return []  # term not found in the index
            elif posting_list != "":
                query_pl.append(posting_list[1])
            else:
                query_pl.append('')  # term not found in the index
        return self.MultiWordQuery(query_vector, query_tokens, query_pl, query_token_fields, k)
    def processText(self, text):
        """Tokenise the text, drop stop words, and stem the remaining tokens."""
        toks = re.findall(r"[\w']{3,}", text.replace("'", "").replace("_", ""))
        nsw = [word for word in toks if word not in self.stop_words]
        stemmed_toks = [self.stemmer.stem(word) for word in nsw]
        return stemmed_toks
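

# A minimal usage sketch, not part of the original module. It assumes the NLTK
# 'stopwords' corpus has already been downloaded and that the sharded posting
# files live in a 'final-index/' directory next to this script, as the paths
# above imply. The example query and the meaning of the 'b:' prefix (body
# field) are illustrative assumptions only.
if __name__ == '__main__':
    evaluator = QueryEvaluator()
    start = time.time()
    # 't:' restricts the preceding tokens to the title field
    top_docs = evaluator.evaluateQuery("t:world cup b:football", k=10)
    print(top_docs)
    print("answered in %.3f seconds" % (time.time() - start))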