Skip to content

Commit

Permalink
Fix test fetcher (doctest) + Idefics2's doc example (#30274)
Browse files Browse the repository at this point in the history
fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
  • Loading branch information
2 people authored and Ita Zaporozhets committed May 14, 2024
1 parent 2817288 commit 63d6267
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 9 deletions.
13 changes: 5 additions & 8 deletions src/transformers/models/idefics2/modeling_idefics2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1786,17 +1786,13 @@ def forward(
>>> from transformers import AutoProcessor, AutoModelForVision2Seq
>>> from transformers.image_utils import load_image
>>> DEVICE = "cuda:0"
>>> # Note that passing the image urls (instead of the actual pil images) to the processor is also possible
>>> image1 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg")
>>> image2 = load_image("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg")
>>> image3 = load_image("https://cdn.britannica.com/68/170868-050-8DDE8263/Golden-Gate-Bridge-San-Francisco.jpg")
>>> processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b-base")
>>> model = AutoModelForVision2Seq.from_pretrained(
... "HuggingFaceM4/idefics2-8b-base",
>>> ).to(DEVICE)
>>> model = AutoModelForVision2Seq.from_pretrained("HuggingFaceM4/idefics2-8b-base", device_map="auto")
>>> BAD_WORDS_IDS = processor.tokenizer(["<image>", "<fake_token_around_image>"], add_special_tokens=False).input_ids
>>> EOS_WORDS_IDS = [processor.tokenizer.eos_token_id]
Expand All @@ -1805,15 +1801,16 @@ def forward(
>>> prompts = [
... "<image>In this image, we can see the city of New York, and more specifically the Statue of Liberty.<image>In this image,",
... "In which city is that bridge located?<image>",
>>> ]
... ]
>>> images = [[image1, image2], [image3]]
>>> inputs = processor(text=prompts, padding=True, return_tensors="pt").to(DEVICE)
>>> inputs = processor(text=prompts, padding=True, return_tensors="pt").to("cuda")
>>> # Generate
>>> generated_ids = model.generate(**inputs, bad_words_ids=BAD_WORDS_IDS, max_new_tokens=500)
>>> generated_ids = model.generate(**inputs, bad_words_ids=BAD_WORDS_IDS, max_new_tokens=20)
>>> generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
>>> print(generated_texts)
['In this image, we can see the city of New York, and more specifically the Statue of Liberty. In this image, we can see the city of New York, and more specifically the Statue of Liberty.\n\n', 'In which city is that bridge located?\n\nThe bridge is located in the city of Pittsburgh, Pennsylvania.\n\n\nThe bridge is']
```"""

output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
Expand Down
2 changes: 1 addition & 1 deletion utils/tests_fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ def get_all_doctest_files() -> List[str]:
# change to use "/" as path separator
test_files_to_run = ["/".join(Path(x).parts) for x in test_files_to_run]
# don't run doctest for files in `src/transformers/models/deprecated`
test_files_to_run = [x for x in test_files_to_run if "models/deprecated" not in test_files_to_run]
test_files_to_run = [x for x in test_files_to_run if "models/deprecated" not in x]

# only include files in `src` or `docs/source/en/`
test_files_to_run = [x for x in test_files_to_run if x.startswith(("src/", "docs/source/en/"))]
Expand Down

0 comments on commit 63d6267

Please sign in to comment.