-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathinstance_segmentation.py
78 lines (58 loc) · 1.75 KB
/
instance_segmentation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""Zero Shot Instance Segmentation.
| Copyright 2017-2023, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""
from importlib.util import find_spec
import os
from fiftyone.core.utils import add_sys_path
import fiftyone.zoo as foz
# SAM (Segment Anything) backbone architectures that can be selected.
SAM_ARCHS = ("ViT-B", "ViT-H", "ViT-L")

# One submodel tuple per architecture, in SAM_ARCHS order.
# NOTE(review): the layout looks like (pretrained, architecture) with the
# pretrained slot left empty -- confirm against the consumer of ``submodels``.
SAM_MODELS = [("", arch_label) for arch_label in SAM_ARCHS]
def SAM_activator():
    """Report whether the SAM segmentation backend should be offered.

    This check is unconditional: it always returns ``True``.

    NOTE(review): ``find_spec`` is imported at the top of this module but
    never used; confirm whether a package-availability check was intended.
    """
    sam_available = True
    return sam_available
def build_instance_segmentation_models_dict():
    """Assemble the registry of instance-segmentation backends.

    Returns:
        a new dict keyed by backend name. If ``SAM_activator()`` is truthy,
        it holds a single ``"SAM"`` entry with the keys ``activator``,
        ``model``, ``name`` and ``submodels``. Otherwise it is empty.
    """
    if not SAM_activator():
        return {}

    sam_spec = dict(
        activator=SAM_activator,
        model="N/A",
        name="SAM",
        submodels=SAM_MODELS,
    )
    return {"SAM": sam_spec}


# Registry built once, when this module is imported.
INSTANCE_SEGMENTATION_MODELS = build_instance_segmentation_models_dict()
def _get_segmentation_model(architecture):
    """Load the Segment Anything zoo model for the given SAM backbone.

    The architecture label is lowercased and its hyphens are removed before
    it is substituted into the zoo name. For example, ``"ViT-B"`` becomes
    ``"segment-anything-vitb-torch"``.

    Args:
        architecture: a SAM architecture label such as ``"ViT-B"``

    Returns:
        whatever ``fiftyone.zoo.load_zoo_model`` returns for that name
    """
    backbone = architecture.lower().replace("-", "")
    return foz.load_zoo_model(f"segment-anything-{backbone}-torch")
def _split_composite(value, what):
    """Split a ``"<detection> + <segmentation>"`` string into its two halves.

    Args:
        value: the composite string to split on ``" + "``
        what: name of the argument being parsed, used in the error message

    Returns:
        a ``(detection_part, segmentation_part)`` tuple of strings

    Raises:
        ValueError: if ``value`` does not contain exactly one ``" + "``
    """
    parts = value.split(" + ")
    if len(parts) != 2:
        raise ValueError(
            f"Expected {what} of the form '<detection> + <segmentation>', "
            f"got {value!r}"
        )
    detection_part, segmentation_part = parts
    return detection_part, segmentation_part


def run_zero_shot_instance_segmentation(
    dataset,
    model_name,
    label_field,
    categories,
    pretrained=None,
    architecture=None,
    **kwargs
):
    """Run zero-shot detection, then SAM segmentation prompted by its output.

    ``model_name``, ``pretrained`` and ``architecture`` are composite strings
    of the form ``"<detection> + <segmentation>"``:

    - The detection halves of ``model_name`` and ``pretrained`` configure the
      zero-shot detector.
    - The segmentation half of ``architecture`` selects the SAM zoo model.
    - The segmentation halves of ``model_name`` and ``pretrained`` are
      ignored.

    Args:
        dataset: the samples to process. Presumably a FiftyOne dataset or
            view; it must support ``apply_model``.
        model_name: ``"<detection model> + <segmentation model>"``
        label_field: field passed to the detection step. The segmentation
            step also uses it as both its prompt field and its output field.
        categories: categories passed to the zero-shot detection step
        pretrained (None): ``"<detection pretrained> + <ignored>"``. ``None``
            or an empty detection half means no pretrained override.
        architecture (None): ``"<ignored> + <SAM architecture>"``, e.g.
            ``"<detector arch> + ViT-B"``. Required despite the ``None``
            default.
        **kwargs: forwarded to the zero-shot detection step only

    Raises:
        ValueError: if a composite string is malformed, or if
            ``architecture`` is missing
    """
    # Parse every composite argument up front, so malformed input fails
    # before sys.path is touched or the detection pass runs.
    det_model_name, _ = _split_composite(model_name, "model_name")

    # pretrained may be None (its default). An empty detection half also
    # means "no pretrained override".
    det_pretrained = None
    if pretrained is not None:
        det_pretrained, _ = _split_composite(pretrained, "pretrained")
        det_pretrained = det_pretrained or None

    if architecture is None:
        raise ValueError(
            "architecture is required and must be of the form "
            "'<detection> + <segmentation>'"
        )
    _, seg_architecture = _split_composite(architecture, "architecture")

    # The sibling ``detection`` module lives next to this file, which is not
    # necessarily on sys.path, so add its directory just for the import.
    with add_sys_path(os.path.dirname(os.path.abspath(__file__))):
        # pylint: disable=no-name-in-module,import-error
        from detection import run_zero_shot_detection

    run_zero_shot_detection(
        dataset,
        det_model_name,
        label_field,
        categories,
        pretrained=det_pretrained,
        **kwargs
    )

    seg_model = _get_segmentation_model(seg_architecture)
    # Use label_field as the prompts and write the segmentation output back
    # into that same field.
    dataset.apply_model(
        seg_model, label_field=label_field, prompt_field=label_field
    )