Skip to content

Commit

Permalink
Save image_processor while saving pipeline (ImageSegmentationPipeline) (
Browse files Browse the repository at this point in the history
huggingface#25884)

* Save image_processor while saving pipeline (ImageSegmentationPipeline)

* Fix black issues
  • Loading branch information
raghavanone authored and EduardoPach committed Nov 18, 2023
1 parent 4da85fa commit 4dd52ec
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/transformers/pipelines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,9 @@ def save_pretrained(self, save_directory: str, safe_serialization: bool = False)
if self.feature_extractor is not None:
self.feature_extractor.save_pretrained(save_directory)

if self.image_processor is not None:
self.image_processor.save_pretrained(save_directory)

if self.modelcard is not None:
self.modelcard.save_pretrained(save_directory)

Expand Down
15 changes: 15 additions & 0 deletions tests/pipelines/test_pipelines_image_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import hashlib
import tempfile
import unittest
from typing import Dict

Expand Down Expand Up @@ -714,3 +715,17 @@ def test_oneformer(self):
},
],
)

def test_save_load(self):
model_id = "hf-internal-testing/tiny-detr-mobilenetsv3-panoptic"

model = AutoModelForImageSegmentation.from_pretrained(model_id)
image_processor = AutoImageProcessor.from_pretrained(model_id)
image_segmenter = pipeline(
task="image-segmentation",
model=model,
image_processor=image_processor,
)
with tempfile.TemporaryDirectory() as tmpdirname:
image_segmenter.save_pretrained(tmpdirname)
pipeline(task="image-segmentation", model=tmpdirname)

0 comments on commit 4dd52ec

Please sign in to comment.