-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathmenu.py
97 lines (84 loc) · 2.53 KB
/
menu.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from typing import List, Type

from pydantic import BaseModel, ConfigDict

from auto_labeling_pipeline import mappings as mp
from auto_labeling_pipeline import models as mo
from auto_labeling_pipeline import task as t
class Option(BaseModel):
    """A named auto-labeling option.

    Bundles a task class, a request-model class and a mapping-template class
    under a display name and description.
    """

    # The fields below hold classes (not instances) of project types;
    # arbitrary_types_allowed lets pydantic accept these non-pydantic types.
    # ConfigDict is the pydantic v2 replacement for the deprecated inner
    # ``class Config``.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    name: str
    description: str
    task: Type[t.Task]
    model: Type[mo.RequestModel]
    template: Type[mp.MappingTemplate]

    def to_dict(self) -> dict:
        """Return a plain-dict summary of this option.

        Returns:
            A dict with keys ``name`` and ``description`` (copied from the
            option), ``schema`` (the request model's JSON schema) and
            ``template`` (the ``template`` attribute of a freshly constructed
            template instance).
        """
        return {
            'name': self.name,
            'description': self.description,
            # model_json_schema() is the pydantic v2 API, the same one
            # Options.register uses; .schema() is the deprecated v1 spelling.
            'schema': self.model.model_json_schema(),
            'template': self.template().template,
        }
class Options:
    """Registry of available :class:`Option` entries.

    ``options`` is a class attribute, so every caller shares one list; entries
    are appended by :meth:`register` and looked up by :meth:`find` /
    :meth:`filter_by_task`.
    """

    # Shared class-level registry. It is deliberately mutable and shared:
    # register() appends to it at module load.
    options: List[Option] = []

    @classmethod
    def filter_by_task(cls, task_name: str) -> List[Option]:
        """Return the registered options usable for the task ``task_name``.

        An option matches when its ``task`` equals the value returned by
        ``t.TaskFactory.create(task_name)`` (presumably a Task subclass —
        confirm against TaskFactory) or is ``t.GenericTask``, so generic
        options are included for every task. Registration order is kept.
        """
        task = t.TaskFactory.create(task_name)
        accepted = (task, t.GenericTask)
        return [option for option in cls.options if option.task in accepted]

    @classmethod
    def find(cls, option_name: str) -> Option:
        """Return the first registered option whose name is ``option_name``.

        Raises:
            ValueError: if no registered option has that name.
        """
        for option in cls.options:
            if option.name == option_name:
                return option
        raise ValueError('Option {} is not found.'.format(option_name))

    @classmethod
    def register(
        cls,
        task: Type[t.Task],
        model: Type[mo.RequestModel],
        template: Type[mp.MappingTemplate],
    ) -> None:
        """Build an :class:`Option` for ``model`` and append it to the registry.

        The option's name and description come from the model's JSON schema
        ``title`` and ``description``. Pydantic omits ``description`` when the
        model has no docstring. Passing ``None`` into the required ``str`` field
        would fail validation, and because registration runs at module load
        that would break importing this module. So missing values fall back to
        the class name and to an empty string.
        """
        schema = model.model_json_schema()
        cls.options.append(
            Option(
                name=schema.get('title') or model.__name__,
                description=schema.get('description') or '',
                task=task,
                model=model,
                template=template,
            )
        )
def _register_builtin_options() -> None:
    """Register the built-in auto-labeling options, in a fixed order.

    Order is significant: ``Options.find`` returns the first match by name and
    ``Options.filter_by_task`` preserves registration order.
    """
    # Each entry is (task class, request model class, mapping template class).
    builtin = (
        (t.GenericTask,
         mo.CustomRESTRequestModel,
         mp.MappingTemplate),
        (t.DocumentClassification,
         mo.AmazonComprehendSentimentRequestModel,
         mp.AmazonComprehendSentimentTemplate),
        (t.SequenceLabeling,
         mo.GCPEntitiesRequestModel,
         mp.GCPEntitiesTemplate),
        (t.SequenceLabeling,
         mo.AmazonComprehendEntityRequestModel,
         mp.AmazonComprehendEntityTemplate),
        # The PII entity model reuses the regular entity template.
        (t.SequenceLabeling,
         mo.AmazonComprehendPIIEntityRequestModel,
         mp.AmazonComprehendEntityTemplate),
        (t.ImageClassification,
         mo.GCPImageLabelDetectionRequestModel,
         mp.GCPImageLabelDetectionTemplate),
        (t.ImageClassification,
         mo.AmazonRekognitionLabelDetectionRequestModel,
         mp.AmazonRekognitionLabelDetectionTemplate),
        (t.SpeechToText,
         mo.GCPSpeechToTextRequestModel,
         mp.GCPSpeechToTextTemplate),
    )
    for task_cls, request_model, mapping_template in builtin:
        Options.register(task_cls, request_model, mapping_template)


_register_builtin_options()