Adding a new align_to_words param to qa pipeline. (huggingface#18010)
* Adding a new `align_to_words` param to qa pipeline.

* Update src/transformers/pipelines/question_answering.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Import protection.

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
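
For context, a minimal usage sketch of the new parameter. The checkpoint name, question, and context below are illustrative placeholders, not taken from this commit; `align_to_words` defaults to `True` in the pipeline, so it is passed explicitly here only to show the knob.

```python
from transformers import pipeline

# Illustrative checkpoint; any extractive QA model with a fast tokenizer works.
qa = pipeline("question-answering", model="distilbert-base-cased-distilled-squad")

# align_to_words=True (the default) snaps the predicted span to word boundaries
# using the fast tokenizer's word mappings; set it to False to keep raw token
# offsets, e.g. for languages without whitespace-delimited words.
out = qa(
    question="Where does Wolfgang live?",
    context="My name is Wolfgang and I live in Berlin.",
    align_to_words=True,
)
print(out["answer"])  # expected: "Berlin"
```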
2 people authored and amyeroberts committed Oct 18, 2022
1 parent 3860839 commit 5034072
Showing 1 changed file with 42 additions and 0 deletions.
42 changes: 42 additions & 0 deletions src/transformers/pipelines/question_answering.py
@@ -603,6 +603,48 @@ def get_indices(
            end_index = enc.offsets[e][1]
        return start_index, end_index

    def decode(
        self, start: np.ndarray, end: np.ndarray, topk: int, max_answer_len: int, undesired_tokens: np.ndarray
    ) -> Tuple:
        """
        Takes the output of any `ModelForQuestionAnswering` and generates probabilities for each span to be the
        actual answer.

        In addition, it filters out some unwanted/impossible cases, like the answer length being greater than
        `max_answer_len` or the answer end position coming before the start position. The method supports
        outputting the k best answers through the `topk` argument.

        Args:
            start (`np.ndarray`): Individual start probabilities for each token.
            end (`np.ndarray`): Individual end probabilities for each token.
            topk (`int`): Indicates how many possible answer span(s) to extract from the model output.
            max_answer_len (`int`): Maximum size of the answer to extract from the model's output.
            undesired_tokens (`np.ndarray`): Mask determining tokens that can be part of the answer.
        """
        # Ensure we have batch axis
        if start.ndim == 1:
            start = start[None]

        if end.ndim == 1:
            end = end[None]

        # Compute the score of each tuple(start, end) to be the real answer
        outer = np.matmul(np.expand_dims(start, -1), np.expand_dims(end, 1))

        # Remove candidates with end < start and end - start > max_answer_len
        candidates = np.tril(np.triu(outer), max_answer_len - 1)

        # Inspired by Chen & al. (https://github.com/facebookresearch/DrQA)
        scores_flat = candidates.flatten()
        if topk == 1:
            idx_sort = [np.argmax(scores_flat)]
        elif len(scores_flat) < topk:
            idx_sort = np.argsort(-scores_flat)
        else:
            # Keep only the topk best scores, then sort that subset
            idx = np.argpartition(-scores_flat, topk)[0:topk]
            idx_sort = idx[np.argsort(-scores_flat[idx])]

        # Convert flat indices back to (start, end) pairs and drop spans touching undesired tokens
        starts, ends = np.unravel_index(idx_sort, candidates.shape)[1:]
        desired_spans = np.isin(starts, undesired_tokens.nonzero()) & np.isin(ends, undesired_tokens.nonzero())
        starts = starts[desired_spans]
        ends = ends[desired_spans]
        scores = candidates[0, starts, ends]

        return starts, ends, scores

    def span_to_answer(self, text: str, start: int, end: int) -> Dict[str, Union[str, int]]:
        """
        When decoding from token probabilities, this method maps token indexes to actual words in the initial context.
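The span-scoring math described in the `decode` docstring above can also be illustrated standalone. The sketch below uses made-up probabilities, drops the batch axis, and only mirrors the `topk == 1` path with no `undesired_tokens` mask.

```python
import numpy as np

start = np.array([0.05, 0.80, 0.10, 0.05])  # P(token i starts the answer), toy values
end = np.array([0.05, 0.10, 0.70, 0.15])    # P(token j ends the answer), toy values
max_answer_len = 2

# Score every (start, end) pair: outer[i, j] = start[i] * end[j]
# (the pipeline does the same with an extra batch axis)
outer = np.matmul(np.expand_dims(start, -1), np.expand_dims(end, 0))

# Keep only j >= i (triu) and j - i < max_answer_len (tril with offset),
# i.e. spans that end after they start and are at most max_answer_len tokens long
candidates = np.tril(np.triu(outer), max_answer_len - 1)

best = np.unravel_index(np.argmax(candidates), candidates.shape)
print(best, candidates[best])  # (1, 2) 0.56 -> best answer spans tokens 1..2
```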
