"""Zero Shot Semantic Segmentation.
| Copyright 2017-2023, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""
from importlib.util import find_spec

from PIL import Image
import torch

import fiftyone as fo
from fiftyone.core.models import Model


class CLIPSegZeroShotModel(Model):
    """Zero-shot semantic segmentation model powered by CLIPSeg."""

    def __init__(self, config):
        self.candidate_labels = config.get("categories", None)

        # Imported lazily so the module loads even when transformers is not
        # installed; CLIPSeg_activator() gates actual usage
        from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

        self.processor = CLIPSegProcessor.from_pretrained(
            "CIDAS/clipseg-rd64-refined"
        )
        self.model = CLIPSegForImageSegmentation.from_pretrained(
            "CIDAS/clipseg-rd64-refined"
        )

    @property
    def media_type(self):
        return "image"

    def _predict(self, image):
        # CLIPSeg scores one text prompt per image, so the image is repeated
        # once for each candidate label
        inputs = self.processor(
            text=self.candidate_labels,
            images=[image] * len(self.candidate_labels),
            padding="max_length",
            return_tensors="pt",
        )

        with torch.no_grad():
            outputs = self.model(**inputs)

        # (num_labels, H, W) logits; argmax over the label dimension yields a
        # per-pixel index into candidate_labels
        preds = outputs.logits.unsqueeze(1)

        # pylint: disable=no-member
        mask = torch.argmax(preds, dim=0).squeeze().numpy()
        return fo.Segmentation(mask=mask)

    def predict(self, args):
        image = Image.fromarray(args)
        return self._predict(image)

    def predict_all(self, samples, args):
        # Batch prediction is not implemented for this model
        pass


def CLIPSeg_activator():
    return find_spec("transformers") is not None
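

# Illustrative usage sketch (assumes a loaded FiftyOne dataset `dataset`;
# not part of this module's original API):
#
#   model = CLIPSegZeroShotModel({"categories": ["cat", "dog", "other"]})
#   dataset.apply_model(model, label_field="clipseg_seg")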


class GroupViTZeroShotModel(Model):
    """Zero-shot semantic segmentation model powered by GroupViT."""

    def __init__(self, config):
        cats = config.get("categories", None)
        self.candidate_labels = [f"a photo of a {cat}" for cat in cats]

        # Imported lazily so the module loads even when transformers is not
        # installed; GroupViT_activator() gates actual usage
        from transformers import AutoProcessor, GroupViTModel

        self.processor = AutoProcessor.from_pretrained(
            "nvidia/groupvit-gcc-yfcc"
        )
        self.model = GroupViTModel.from_pretrained("nvidia/groupvit-gcc-yfcc")

    @property
    def media_type(self):
        return "image"

    def _predict(self, image):
        inputs = self.processor(
            text=self.candidate_labels,
            images=image,
            padding="max_length",
            return_tensors="pt",
        )

        with torch.no_grad():
            outputs = self.model(**inputs, output_segmentation=True)

        # (num_labels, H, W) segmentation logits; argmax over the label
        # dimension yields a per-pixel index into candidate_labels
        preds = outputs.segmentation_logits.squeeze()

        # pylint: disable=no-member
        mask = torch.argmax(preds, dim=0).numpy()
        return fo.Segmentation(mask=mask)

    def predict(self, args):
        image = Image.fromarray(args)

        # GroupViT operates on 224x224 inputs, so the returned mask is at
        # that resolution rather than the original image size
        image = image.resize((224, 224))

        return self._predict(image)

    def predict_all(self, samples, args):
        # Batch prediction is not implemented for this model
        pass


def GroupViT_activator():
    return find_spec("transformers") is not None


SEMANTIC_SEGMENTATION_MODELS = {
    "CLIPSeg": {
        "activator": CLIPSeg_activator,
        "model": CLIPSegZeroShotModel,
        "name": "CLIPSeg",
    },
    "GroupViT": {
        "activator": GroupViT_activator,
        "model": GroupViTZeroShotModel,
        "name": "GroupViT",
    },
}
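
# A caller (e.g., a plugin UI) might use the activators to filter this
# registry to the models whose dependencies are installed; an illustrative
# sketch, not part of this module's original API:
#
#   available = {
#       name: entry
#       for name, entry in SEMANTIC_SEGMENTATION_MODELS.items()
#       if entry["activator"]()
#   }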


def _get_model(model_name, config):
    return SEMANTIC_SEGMENTATION_MODELS[model_name]["model"](config)


def run_zero_shot_semantic_segmentation(
    dataset, model_name, label_field, categories, **kwargs
):
    # Add a catch-all class so pixels that match none of the target
    # categories have somewhere to go
    if "other" not in categories:
        categories = categories + ["other"]  # avoid mutating the caller's list

    config = {"categories": categories}
    model = _get_model(model_name, config)
    dataset.apply_model(model, label_field=label_field)

    # Mask values are indices into `categories`, in the same order the
    # prompts were passed to the model
    dataset.mask_targets[label_field] = {
        i: label for i, label in enumerate(categories)
    }
    dataset.save()
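

if __name__ == "__main__":
    # Minimal end-to-end sketch, not part of the original module. Assumes
    # `transformers` is installed and that the FiftyOne dataset zoo's
    # "quickstart" dataset can be downloaded.
    import fiftyone.zoo as foz

    dataset = foz.load_zoo_dataset("quickstart", max_samples=3)
    run_zero_shot_semantic_segmentation(
        dataset, "CLIPSeg", "clipseg_seg", ["person", "car"]
    )
    print(dataset.first()["clipseg_seg"])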