QuestionAnsweringPipeline returns full context in Japanese #17706

Closed
KoichiYasuoka opened this issue Jun 15, 2022 · 6 comments · Fixed by #18010

@KoichiYasuoka
Contributor

System Info

- `transformers` version: 4.19.4
- Platform: Linux-5.10.0-13-amd64-x86_64-with-glibc2.31
- Python version: 3.9.2
- Huggingface_hub version: 0.1.0
- PyTorch version (GPU?): 1.11.0+cu102 (False)

Who can help?

@Narsil @sgugger

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

QuestionAnsweringPipeline (almost always) returns the full context in Japanese, for example:

from transformers import AutoTokenizer, AutoModelForQuestionAnswering, QuestionAnsweringPipeline
tokenizer = AutoTokenizer.from_pretrained("KoichiYasuoka/deberta-base-japanese-aozora-ud-head")
model = AutoModelForQuestionAnswering.from_pretrained("KoichiYasuoka/deberta-base-japanese-aozora-ud-head")
qap = QuestionAnsweringPipeline(tokenizer=tokenizer, model=model)
print(qap(question="国語", context="全学年にわたって小学校の国語の教科書に挿し絵が用いられている"))

returns {'score': 0.9999955892562866, 'start': 0, 'end': 30, 'answer': '全学年にわたって小学校の国語の教科書に挿し絵が用いられている'}. On the other hand, directly with torch.argmax

import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
tokenizer = AutoTokenizer.from_pretrained("KoichiYasuoka/deberta-base-japanese-aozora-ud-head")
model = AutoModelForQuestionAnswering.from_pretrained("KoichiYasuoka/deberta-base-japanese-aozora-ud-head")
question = "国語"
context = "全学年にわたって小学校の国語の教科書に挿し絵が用いられている"
inputs = tokenizer(question, context, return_tensors="pt", return_offsets_mapping=True)
offsets = inputs.pop("offset_mapping").tolist()[0]
outputs = model(**inputs)
# pick the highest-scoring start/end tokens and map them back to character offsets
start, end = torch.argmax(outputs.start_logits), torch.argmax(outputs.end_logits)
print(context[offsets[start][0]:offsets[end][-1]])

the model returns the answer "教科書" correctly.

Expected behavior

Return the right answer "教科書" instead of the full context.
@KoichiYasuoka
Contributor Author

I suspect that the "encoding" for Japanese models does not work at https://github.com/huggingface/transformers/blob/main/src/transformers/pipelines/question_answering.py#L452
but I'm not sure how to fix it.
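
For what it's worth, here is a minimal check (my own sketch, not the pipeline's code) that illustrates the suspicion, assuming the fast tokenizer exposes word_ids(): if every context token maps to the same word index, realigning the predicted span to word boundaries would expand it to the whole context.

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("KoichiYasuoka/deberta-base-japanese-aozora-ud-head")
# hypothetical check: how does the tokenizer group the context tokens into "words"?
enc = tokenizer("国語", "全学年にわたって小学校の国語の教科書に挿し絵が用いられている")
# if all context tokens share a single word id, word-level alignment of the
# predicted span covers the full context instead of "教科書"
print(enc.word_ids())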

@gante
Member

gante commented Jun 15, 2022

Hi @KoichiYasuoka 👋 As per our issues guidelines, we reserve GitHub issues for bugs in the repository and/or feature requests. For any other requests, we'd like to invite you to use our forum 🤗

(Since the issue is about the quality of the output, it's probably model-related, and not a bug per se. In any case, if you suspect it is due to a bug in transformers, please add more information here)

@Narsil
Contributor

Narsil commented Jul 4, 2022

Hi @KoichiYasuoka ,

This seems to be linked to the pipeline's attempt to align answers on "words". The problem is that this Japanese tokenizer never cuts on "words", so the whole context is a single word, and the realignment therefore forgets all about the actual answer, which is a bit sad.

I created a PR that adds a new parameter to disable this behaviour so it can work on your use case (I personally think it should be the default, but we cannot change that because of backward compatibility).
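
A sketch of what the reproduction above could look like once the parameter lands (assuming align_to_words is accepted as a call-time argument, as the PR proposes):

from transformers import AutoTokenizer, AutoModelForQuestionAnswering, QuestionAnsweringPipeline
tokenizer = AutoTokenizer.from_pretrained("KoichiYasuoka/deberta-base-japanese-aozora-ud-head")
model = AutoModelForQuestionAnswering.from_pretrained("KoichiYasuoka/deberta-base-japanese-aozora-ud-head")
qap = QuestionAnsweringPipeline(tokenizer=tokenizer, model=model)
# align_to_words=False skips the word-boundary realignment, so the predicted
# character span is returned as-is (expected answer: "教科書")
print(qap(question="国語", context="全学年にわたって小学校の国語の教科書に挿し絵が用いられている", align_to_words=False))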

@KoichiYasuoka
Contributor Author

Thank you @Narsil for creating the new PR with the align_to_words=False option. Well, can I use the option in the widget on the deberta-base-japanese-aozora-ud-head page?

@Narsil
Contributor

Narsil commented Jul 15, 2022

Hi, the PR is not merged yet, and it will take a few days before it lands on the API (the API doesn't run master).

Afterwards, although the parameter will be undocumented and could therefore be deactivated at any time (though we rarely do this), you could send align_to_words: false within the parameters part of your query to the API (see the sketch below).

Unfortunately the widget itself will not use parameters.

Does that answer your question?
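
A sketch of what such a request could look like, assuming the standard hosted Inference API endpoint and that entries under "parameters" are forwarded to the pipeline call (the token is a placeholder):

import requests
API_URL = "https://api-inference.huggingface.co/models/KoichiYasuoka/deberta-base-japanese-aozora-ud-head"
headers = {"Authorization": "Bearer <your_api_token>"}  # placeholder token
payload = {
    "inputs": {"question": "国語", "context": "全学年にわたって小学校の国語の教科書に挿し絵が用いられている"},
    "parameters": {"align_to_words": False},  # undocumented parameter discussed above
}
print(requests.post(API_URL, headers=headers, json=payload).json())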

@github-actions

github-actions bot commented Aug 8, 2022

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
