# extract_claims.py
import re
import logging
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
from .stat_calculator import StatCalculator
from lm_polygraph.utils.openai_chat import OpenAIChat
from lm_polygraph.utils.model import WhiteboxModel
from .claim_level_prompts import CLAIM_EXTRACTION_PROMPTS, MATCHING_PROMPTS
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor

log = logging.getLogger("lm_polygraph")

@dataclass
class Claim:
    claim_text: str
    # The sentence of the generation from which the claim was extracted
    sentence: str
    # Indices of the tokens in the original generation that relate to the current claim
    aligned_token_ids: List[int]
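
# Illustrative instance (hypothetical values): for the generation
# "Lanny Flaherty is an American actor.", one extracted claim might look like
#   Claim(
#       claim_text="Lanny Flaherty is an actor.",
#       sentence="Lanny Flaherty is an American actor.",
#       aligned_token_ids=[0, 1, 2, 3, 6],
#   )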

class ClaimsExtractor(StatCalculator):
    """
    Extracts claims from the text of the model generation.
    """

    def __init__(
        self,
        openai_chat: OpenAIChat,
        sent_separators: str = ".?!。?!\n",
        language: str = "en",
        progress_bar: bool = False,
        extraction_prompts: Dict[str, str] = CLAIM_EXTRACTION_PROMPTS,
        matching_prompts: Dict[str, str] = MATCHING_PROMPTS,
        n_threads: int = 1,
    ):
        super().__init__()
        log.info(f"Initializing ClaimsExtractor with language={language}")
        self.language = language
        self.openai_chat = openai_chat
        self.sent_separators = sent_separators
        self.progress_bar = progress_bar
        self.extraction_prompts = extraction_prompts
        self.matching_prompts = matching_prompts
        self.n_threads = n_threads

    @staticmethod
    def meta_info() -> Tuple[List[str], List[str]]:
        return (
            [
                "claims",
                "claim_texts_concatenated",
                "claim_input_texts_concatenated",
            ],
            [
                "greedy_texts",
                "greedy_tokens",
            ],
        )

    def __call__(
        self,
        dependencies: Dict[str, object],
        texts: List[str],
        model: WhiteboxModel,
        *args,
        **kwargs,
    ) -> Dict[str, List]:
        """
        Extracts claims from each generation text.

        Parameters:
            dependencies (Dict[str, object]): input statistics, which include:
                * 'greedy_texts' (List[str]): model generations for the input texts;
                * 'greedy_tokens' (List[List[int]]): tokenized model generations.
            texts (List[str]): Input texts batch used for model generation.
            model (WhiteboxModel): Model used for generation.
        Returns:
            Dict[str, List]: dictionary with:
                * 'claims' (List[List[lm_polygraph.stat_calculators.extract_claims.Claim]]):
                  list of claims for each input text;
                * 'claim_texts_concatenated' (List[str]): list of all textual claims extracted;
                * 'claim_input_texts_concatenated' (List[str]): for each claim in
                  claim_texts_concatenated, the corresponding input text.
        """
        greedy_texts = dependencies["greedy_texts"]
        greedy_tokens = dependencies["greedy_tokens"]
        claims: List[List[Claim]] = []
        claim_texts_concatenated: List[str] = []
        claim_input_texts_concatenated: List[str] = []

        with ThreadPoolExecutor(max_workers=self.n_threads) as executor:
            claims = list(
                tqdm(
                    executor.map(
                        self.claims_from_text,
                        greedy_texts,
                        greedy_tokens,
                        [model.tokenizer] * len(greedy_texts),
                    ),
                    total=len(greedy_texts),
                    desc="Extracting claims",
                    disable=not self.progress_bar,
                )
            )

        # Flatten the claims, pairing each claim with the input text it came from
        for input_text, text_claims in zip(texts, claims):
            for claim in text_claims:
                claim_texts_concatenated.append(claim.claim_text)
                claim_input_texts_concatenated.append(input_text)

        return {
            "claims": claims,
            "claim_texts_concatenated": claim_texts_concatenated,
            "claim_input_texts_concatenated": claim_input_texts_concatenated,
        }
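
    # Illustrative return value (hypothetical): for texts = ["Who is X?", "Who is Y?"],
    # with two claims extracted from the first generation and one from the second:
    # {
    #     "claims": [[claim_a, claim_b], [claim_c]],
    #     "claim_texts_concatenated": [claim_a.claim_text, claim_b.claim_text, claim_c.claim_text],
    #     "claim_input_texts_concatenated": ["Who is X?", "Who is X?", "Who is Y?"],
    # }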

    def claims_from_text(self, text: str, tokens: List[int], tokenizer) -> List[Claim]:
        sentences = []
        for s in re.split(f"[{self.sent_separators}]", text):
            if len(s) > 0:
                sentences.append(s)
        if len(text) > 0 and text[-1] not in self.sent_separators:
            # Remove the last, unfinished sentence: extracting claims from an
            # unfinished sentence may lead to hallucinated claims.
            sentences = sentences[:-1]

        sent_start_token_idx, sent_end_token_idx = 0, 0
        sent_start_idx, sent_end_idx = 0, 0
        claims = []
        for s in sentences:
            # Find the sentence location in text: text[sent_start_idx:sent_end_idx]
            while not text[sent_start_idx:].startswith(s):
                sent_start_idx += 1
            while not text[:sent_end_idx].endswith(s):
                sent_end_idx += 1

            # Find the sentence location in tokens:
            # tokens[sent_start_token_idx:sent_end_token_idx].
            # Iteratively decode the tokenized text until the decoded sequence length
            # reaches the start (resp. end) position of the current sentence.
            while len(tokenizer.decode(tokens[:sent_start_token_idx])) < sent_start_idx:
                sent_start_token_idx += 1
            while len(tokenizer.decode(tokens[:sent_end_token_idx])) < sent_end_idx:
                sent_end_token_idx += 1

            # Extract claims from the current sentence
            for c in self._claims_from_sentence(
                s, tokens[sent_start_token_idx:sent_end_token_idx], tokenizer
            ):
                # Shift aligned token positions from sentence level to generation level
                for i in range(len(c.aligned_token_ids)):
                    c.aligned_token_ids[i] += sent_start_token_idx
                claims.append(c)
        return claims
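
    # Illustrative alignment (hypothetical tokenization): for
    # text = "Paris is in France. It is" with tokens decoding to
    # ["Paris", " is", " in", " France", ".", " It", " is"], the finished
    # sentence "Paris is in France" occupies text[0:18] and tokens[0:4],
    # while the trailing unfinished "It is" is dropped.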

    def _claims_from_sentence(
        self,
        sent: str,
        sent_tokens: List[int],
        tokenizer,
    ) -> List[Claim]:
        # Extract claims with a language-specific prompt
        extracted_claims = self.openai_chat.ask(
            self.extraction_prompts[self.language].format(sent=sent)
        )
        claims = []
        for claim_text in extracted_claims.split("\n"):
            # Example of a bad claim_text to skip:
            # - There aren't any claims in this sentence.
            if not claim_text.startswith("- "):
                continue
            if "there aren't any claims" in claim_text.lower():
                continue
            # Remove the leading '- '
            claim_text = claim_text[2:].strip()

            # Ask the LLM which words of the sentence match the claim.
            # Example:
            # sent = 'Lanny Flaherty is an American actor born on December 18, 1949, in Pensacola, Florida.'
            # claim = 'Lanny Flaherty was born on December 18, 1949.'
            # GPT response: 'Lanny, Flaherty, born, on, December, 18, 1949'
            # match_words = ['Lanny', 'Flaherty', 'born', 'on', 'December', '18', '1949']
            chat_ask = self.matching_prompts[self.language].format(
                sent=sent,
                claim=claim_text,
            )
            match_words = self.openai_chat.ask(chat_ask)
            # The comma has a different form in Chinese, so splitting on spaces works better
            if self.language == "zh":
                match_words = match_words.strip().split(" ")
            else:
                match_words = match_words.strip().split(",")
            match_words = list(map(lambda x: x.strip(), match_words))
            # Try to highlight the matched symbols in sent
            if self.language == "zh":
                match_string = self._match_string_zh(sent, match_words)
            else:
                match_string = self._match_string(sent, match_words)
            if match_string is None:
                continue
            # Get positions of the tokens that intersect the highlighted regions,
            # that is, the tokens corresponding to the claim
            aligned_token_ids = self._align(sent, match_string, sent_tokens, tokenizer)
            if len(aligned_token_ids) == 0:
                continue
            claims.append(
                Claim(
                    claim_text=claim_text,
                    sentence=sent,
                    aligned_token_ids=aligned_token_ids,
                )
            )
        return claims

    def _match_string(self, sent: str, match_words: List[str]) -> Optional[str]:
        """
        Greedily matches words from `match_words` to `sent`.

        Parameters:
            sent (str): sentence string
            match_words (List[str]): list of words from sent, in the same order they appear in it.
        Returns:
            Optional[str]: string of length len(sent) with, for each symbol in sent, '^' if it
                belongs to one of the match_words when aligned to sent, and ' ' otherwise.
                Returns None if the matching failed, e.g. because some of the match_words are
                not present in sent, or because the words are not in the same order they
                appear in the sentence.
        Example:
            sent = 'Lanny Flaherty is an American actor born on December 18, 1949, in Pensacola, Florida.'
            match_words = ['Lanny', 'Flaherty', 'born', 'on', 'December', '18', '1949']
            return '^^^^^ ^^^^^^^^                      ^^^^ ^^ ^^^^^^^^ ^^  ^^^^                        '
        """
        sent_pos = 0  # pointer into the sentence
        match_words_pos = 0  # pointer into the match_words list
        # Iteratively construct match_str with highlighted symbols, starting from an empty string
        match_str = ""
        while sent_pos < len(sent):
            # Check whether the current word cur_match_word can be located at
            # sent[sent_pos:sent_pos + len(cur_match_word)]:
            # 1. check that the symbols around the word position are not letters
            check_boundaries = False
            if sent_pos == 0 or not sent[sent_pos - 1].isalpha():
                check_boundaries = True
            if check_boundaries and match_words_pos < len(match_words):
                cur_match_word = match_words[match_words_pos]
                right_idx = sent_pos + len(cur_match_word)
                if right_idx < len(sent):
                    check_boundaries = not sent[right_idx].isalpha()
                # 2. check that the symbols at the word position spell cur_match_word
                if check_boundaries and sent[sent_pos:].startswith(cur_match_word):
                    # Found a match for cur_match_word at sent[sent_pos]
                    len_w = len(cur_match_word)
                    sent_pos += len_w
                    # Highlight this position in the match string
                    match_str += "^" * len_w
                    match_words_pos += 1
                    continue
            # No match at sent[sent_pos], continue with the next position
            sent_pos += 1
            match_str += " "

        if match_words_pos < len(match_words):
            # Didn't match all the words to the sentence, possibly because the match
            # words are in the wrong order or are not present in the sentence.
            return None
        return match_str

    def _match_string_zh(self, sent: str, match_words: List[str]) -> Optional[str]:
        # Greedily matches characters from `match_words` to `sent` for Chinese.
        # Returns None if the matching failed, e.g. because some characters of
        # match_words are not present in sent, or because the characters are not
        # in the same order they appear in the sentence.
        #
        # Example:
        # sent = '爱因斯坦也是一位和平主义者。'
        # match_words = ['爱因斯坦', '是', '和平', '主义者']
        # return '^^^^ ^  ^^^^^ '
        last = 0  # pointer into the current word of match_words
        last_match = 0  # pointer into the match_words list
        match_str = ""
        # Iterate over each character of the input Chinese text
        for char in sent:
            # Check whether the current character matches the next expected
            # character, match_words[last_match][last]
            if last_match < len(match_words) and char == match_words[last_match][last]:
                # Match found: update the pointers and match_str
                match_str += "^"
                last += 1
                if last == len(match_words[last_match]):
                    last = 0
                    last_match += 1
            else:
                # No match: append a space to match_str
                match_str += " "
        # Check that all characters of match_words have been matched
        if last_match < len(match_words):
            return None  # Didn't match all the characters to the sentence
        return match_str

    def _align(
        self,
        sent: str,
        match_str: str,
        sent_tokens: List[int],
        tokenizer,
    ) -> List[int]:
        """
        Identifies the indices of tokens in `sent_tokens` that align with matching characters
        (marked '^') in `match_str`. All tokens whose textual representations intersect any
        of the matching characters are included. Partial intersections should be uncommon
        in practice.

        Args:
            sent: the original sentence.
            match_str: a string of the same length as `sent` in which '^' characters indicate matches.
            sent_tokens: a list of token ids representing the tokenized version of `sent`.
            tokenizer: the tokenizer used to decode the tokens.
        Returns:
            A list of integers: the indices of the tokens in `sent_tokens` that align with
            matching characters in `match_str`.
        """
        sent_pos = 0
        cur_token_i = 0
        # Iteratively find the position of each new token.
        aligned_token_ids = []
        while sent_pos < len(sent) and cur_token_i < len(sent_tokens):
            cur_token_text = tokenizer.decode(sent_tokens[cur_token_i])
            # Try to find cur_token_text in the sentence, starting at sent[sent_pos]
            if len(cur_token_text) == 0:
                # Skip non-informative token
                cur_token_i += 1
                continue
            if sent[sent_pos:].startswith(cur_token_text):
                # If the match string corresponding to the token contains matches, add it to the answer
                if any(
                    t == "^"
                    for t in match_str[sent_pos : sent_pos + len(cur_token_text)]
                ):
                    aligned_token_ids.append(cur_token_i)
                cur_token_i += 1
                sent_pos += len(cur_token_text)
            else:
                # Continue with the same token and the next position in the sentence.
                sent_pos += 1
        return aligned_token_ids
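
# Minimal usage sketch (illustrative only): exercises claims_from_text on a
# single generation. "gpt-4o" and the GPT-2 tokenizer are assumptions, not
# choices made by this module; OpenAIChat's constructor arguments may differ
# across lm_polygraph versions, an OpenAI API key must be configured, and the
# relative imports above require running this module inside its package.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    chat = OpenAIChat("gpt-4o")  # hypothetical model choice
    extractor = ClaimsExtractor(chat, language="en")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # hypothetical tokenizer

    text = "Lanny Flaherty is an American actor born on December 18, 1949."
    tokens = tokenizer.encode(text)
    for claim in extractor.claims_from_text(text, tokens, tokenizer):
        print(claim.claim_text, "->", claim.aligned_token_ids)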