-
Notifications
You must be signed in to change notification settings - Fork 0
/
usage.py
152 lines (126 loc) · 4.28 KB
/
usage.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import sys
from typing import List
import torch
sys.path.append("../")
from tta_pytorch import FLIP_MODES, TYPES, Chain, Compose, Flip, Merger, Rescale, Resize, Rotate
# Interpolation settings shared by the two resampling transforms below
# (Rescale and Resize take identical image/mask interpolation options).
_RESAMPLE_KWARGS = dict(
    image_mode="bilinear",
    image_align_corners=False,
    mask_mode="bilinear",
    mask_align_corners=False,
)

# Test-time augmentation pipeline used by every example in this script.
# Iterating it yields the individual transform chains (annotated as ``Chain``
# in the examples).
tta_trans = Compose(
    [
        Rescale(scales=[0.5], **_RESAMPLE_KWARGS),
        Resize(sizes=[128], **_RESAMPLE_KWARGS),
        Flip(
            flip_modes=[
                FLIP_MODES.Horizontal,
                FLIP_MODES.HorizontalVertical,
                FLIP_MODES.Identity,
            ]
        ),
        # Rotate only exposes interpolation modes, not align_corners.
        Rotate(angles=[15, 45], image_mode="bilinear", mask_mode="bilinear"),
    ],
    verbose=True,
)

# Random float32 input of shape (3, 1, 50, 50) shared by all examples.
image = torch.randn(3, 1, 50, 50, dtype=torch.float32)
def base_usage():
    """Round-trip ``image`` through every TTA chain and merge the results.

    Each chain applies its forward image transform (``do_image``) and then its
    inverse (``undo_image``); the restored tensors are collected in a
    default-constructed ``Merger``.

    Returns:
        The merged tensor from ``Merger.result`` (its shape is also printed).
    """
    merger = Merger()
    for chain in tta_trans:
        chain: Chain
        restored = chain.undo_image(chain.do_image(image))
        merger.append(restored)
    merged = merger.result
    print(merged.shape)
    return merged
def enhanced_usage1():
    """Same round trip as ``base_usage`` but via the generic ``do_all``/``undo_all`` API.

    The forward pass declares the single input as ``TYPES.IMAGE``; the single
    output is inverted as ``TYPES.MASK`` (presumably standing in for a model
    whose output is a segmentation mask).

    Returns:
        The merged tensor from ``Merger.result`` (its shape is also printed).
    """
    merger = Merger()
    for chain in tta_trans:
        chain: Chain
        augmented: List[torch.Tensor] = chain.do_all(
            inputs=[image], input_types=[TYPES.IMAGE]
        )
        restored = chain.undo_all(outputs=augmented, output_types=[TYPES.MASK])
        # Only one output was passed through, so keep the first (and only) one.
        merger.append(restored[0])
    merged = merger.result
    print(merged.shape)
    return merged
def enhanced_usage2():
    """Merge segmentation and classification outputs with two mean-mode ``Merger``s.

    For every chain, the augmented image stands in for a segmentation output
    and a random ``(3, 1000)`` tensor stands in for classification scores.
    Both are inverted by the chain and accumulated in separate mergers.

    Returns:
        The merged segmentation tensor.
    """
    # A Merger is just a list paired with a merging function.
    seg_merger, cls_merger = Merger(mode="mean"), Merger(mode="mean")
    for merger in (seg_merger, cls_merger):
        merger.reset()

    for chain in tta_trans:
        chain: Chain
        # Simulated network outputs.
        fake_mask = chain.do_image(image)  # segmentation, [B,K,H,W]
        fake_scores = torch.randn(3, 1000, dtype=torch.float32)  # classification, [B,K]
        seg_merger.append(chain.undo_image(fake_mask))
        cls_merger.append(chain.undo_label(fake_scores))

    seg_results = seg_merger.result
    seg_mask = seg_results.argmax(dim=1)  # [B,H,W]
    cls_score, cls_index = cls_merger.result.max(dim=1)  # [B], [B]
    print(seg_mask.shape, cls_score.shape, cls_index.shape)
    return seg_results
def enhanced_usage3():
    """Same as the Merger-based example, but averaging with plain Python lists.

    Outputs from every chain are collected in lists and averaged with
    ``sum(...) / len(...)``.

    Returns:
        The averaged segmentation tensor.
    """

    def _mean(tensors):
        # Element-wise arithmetic mean of a non-empty list of tensors.
        return sum(tensors) / len(tensors)

    seg_outputs, cls_outputs = [], []
    for chain in tta_trans:
        chain: Chain
        # Simulated network outputs.
        fake_mask = chain.do_image(image)  # segmentation, [B,K,H,W]
        fake_scores = torch.randn(3, 1000, dtype=torch.float32)  # classification, [B,K]
        seg_outputs.append(chain.undo_image(fake_mask))
        cls_outputs.append(chain.undo_label(fake_scores))

    seg_results = _mean(seg_outputs)
    seg_mask = seg_results.argmax(dim=1)  # [B,H,W]
    cls_score, cls_index = _mean(cls_outputs).max(dim=1)  # [B], [B]
    print(seg_mask.shape, cls_score.shape, cls_index.shape)
    return seg_results
def enhanced_usage4():
    """Let ``tta_trans.decorate`` wrap a model-like function.

    The decorator is configured with one ``TYPES.IMAGE`` input named
    ``image`` and two outputs, ``mask`` (``TYPES.MASK``) and ``label``
    (``TYPES.LABEL``), merged with ``merge_mode="mean"``. Presumably it runs
    every chain on the input, calls the wrapped function, inverts and merges
    the outputs. TODO: confirm against ``Compose.decorate``.

    Returns:
        The ``"mask"`` entry of the decorated call's result.
    """
    merge_with_tta = tta_trans.decorate(
        input_infos={"image": TYPES.IMAGE},
        output_infos={"mask": TYPES.MASK, "label": TYPES.LABEL},
        merge_mode="mean",
    )

    @merge_with_tta
    def do_something(image=None):
        # Fake model: echo the input as the mask and emit random [B,K] scores.
        scores = torch.randn(3, 1000, dtype=torch.float32)
        return {"mask": image, "label": scores}

    merged = do_something(image=image)
    print({key: value.shape for key, value in merged.items()})
    return merged["mask"]
if __name__ == "__main__":
    # Run every example in order; each returns a merged tensor.
    results0 = base_usage()
    results1 = enhanced_usage1()
    results2 = enhanced_usage2()
    results3 = enhanced_usage3()
    results4 = enhanced_usage4()

    # Every example round-trips the same ``image`` through the same chains, so
    # their merged outputs are expected to agree with ``base_usage``.
    # Raise explicitly instead of using ``assert``: ``python -O`` strips
    # asserts, which would print "All tests passed!" without checking anything.
    for example_name, example_result in (
        ("enhanced_usage1", results1),
        ("enhanced_usage2", results2),
        ("enhanced_usage3", results3),
        ("enhanced_usage4", results4),
    ):
        if not torch.allclose(results0, example_result):
            raise AssertionError(f"{example_name}() does not match base_usage()")
    print("All tests passed!")
    print(tta_trans)