From 50340722cd1445ded7be31f9efb860f7dee06571 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 9 Aug 2022 18:50:02 +0200 Subject: [PATCH] Adding a new `align_to_words` param to qa pipeline. (#18010) * Adding a new `align_to_words` param to qa pipeline. * Update src/transformers/pipelines/question_answering.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Import protection. Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- .../pipelines/question_answering.py | 42 +++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/src/transformers/pipelines/question_answering.py b/src/transformers/pipelines/question_answering.py index 6a1a0011c5efc1..0d3c511a807c46 100644 --- a/src/transformers/pipelines/question_answering.py +++ b/src/transformers/pipelines/question_answering.py @@ -603,6 +603,48 @@ def get_indices( end_index = enc.offsets[e][1] return start_index, end_index + def decode( + self, start: np.ndarray, end: np.ndarray, topk: int, max_answer_len: int, undesired_tokens: np.ndarray + ) -> Tuple: + """ + Take the output of any `ModelForQuestionAnswering` and will generate probabilities for each span to be the + actual answer. + + In addition, it filters out some unwanted/impossible cases like answer len being greater than max_answer_len or + answer end position being before the starting position. The method supports output the k-best answer through + the topk argument. + + Args: + start (`np.ndarray`): Individual start probabilities for each token. + end (`np.ndarray`): Individual end probabilities for each token. + topk (`int`): Indicates how many possible answer span(s) to extract from the model output. + max_answer_len (`int`): Maximum size of the answer to extract from the model's output. + undesired_tokens (`np.ndarray`): Mask determining tokens that can be part of the answer + """ + # Ensure we have batch axis + if start.ndim == 1: + start = start[None] + + if end.ndim == 1: + end = end[None] + + # Compute the score of each tuple(start, end) to be the real answer + outer = np.matmul(np.expand_dims(start, -1), np.expand_dims(end, 1)) + + # Remove candidate with end < start and end - start > max_answer_len + candidates = np.tril(np.triu(outer), max_answer_len - 1) + + # Inspired by Chen & al. (https://github.com/facebookresearch/DrQA) + scores_flat = candidates.flatten() + if topk == 1: + idx_sort = [np.argmax(scores_flat)] + elif len(scores_flat) < topk: + idx_sort = np.argsort(-scores_flat) + else: + start_index = enc.offsets[s][0] + end_index = enc.offsets[e][1] + return start_index, end_index + def span_to_answer(self, text: str, start: int, end: int) -> Dict[str, Union[str, int]]: """ When decoding from token probabilities, this method maps token indexes to actual word in the initial context.