-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtta.py
29 lines (26 loc) · 1.3 KB
/
tta.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
import torch.nn.functional as F
from torchvision import transforms as T
def _default_tta_transforms():
    """Build the default list of test-time augmentation pipelines.

    Every pipeline ends in a 224x224 ``Resize`` or ``RandomCrop``; the
    ``RandomCrop`` variants require inputs of at least 224x224 pixels.
    """
    return [
        T.Compose([T.RandomRotation((-30, 30)), T.Resize((224, 224))]),
        T.RandomCrop((224, 224)),
        T.Compose([T.RandomHorizontalFlip(1), T.Resize((224, 224))]),
        T.Compose([T.RandomRotation((-30, 30)), T.RandomCrop((224, 224))]),
        T.Compose([T.RandomRotation((-30, 30)), T.RandomHorizontalFlip(1), T.Resize((224, 224))]),
        T.Compose([T.RandomHorizontalFlip(1), T.RandomCrop((224, 224))]),
        T.Compose([T.RandomRotation((-30, 30)), T.RandomHorizontalFlip(1), T.RandomCrop((224, 224))]),
    ]


def _apply_paired(transform, img, mask):
    """Apply ``transform`` to ``img`` and ``mask`` with the same random draws."""
    # torchvision's random transforms draw their parameters (rotation angle,
    # crop offset, flip decision) from torch's global RNG. Replaying the same
    # RNG state for the mask gives it the same geometric transform as the
    # image; calling the transform twice without this gives the two tensors
    # independent parameters. Crop offsets only match if img and mask share
    # the same spatial size.
    rng_state = torch.get_rng_state()
    augmented_image = transform(img)
    torch.set_rng_state(rng_state)
    augmented_mask = transform(mask)
    return augmented_image, augmented_mask


def ttaug(model, output_orig, img, mask, *, transforms=None, max_tries=10, tolerance=10):
    """Collect softmax predictions of ``model`` over test-time augmentations.

    For each augmentation, up to ``max_tries`` random draws are attempted.
    Each draw applies identical random parameters to ``img`` and ``mask``.
    The first draw whose augmented mask sum exceeds ``sum(mask) - tolerance``
    is fed to ``model``, and its softmax over the last dim is recorded.
    Augmentations for which no draw passes are skipped, so the number of
    stacked predictions can be smaller than ``len(transforms) + 1``.

    Args:
        model: Callable mapping an augmented image to logits; softmax is
            taken over the last dimension of its output.
        output_orig: Prediction for the un-augmented image. Presumably the
            softmax of ``model(img)``. It must have the same shape as each
            augmented prediction so they can be stacked — TODO confirm.
        img: Image tensor passed through each transform and then to ``model``.
        mask: Mask tensor, assumed to share ``img``'s spatial size. Its sum is
            used to check how much of the masked region survives.
        transforms: Optional list of callables applied to both ``img`` and
            ``mask``. Defaults to seven rotation/crop/flip torchvision
            pipelines with 224x224 outputs.
        max_tries: Random draws attempted per augmentation before skipping it.
        tolerance: A draw is accepted when the augmented mask sum is greater
            than ``sum(mask) - tolerance``.

    Returns:
        ``torch.stack([output_orig, *accepted_predictions], dim=1)``.
    """
    if transforms is None:
        transforms = _default_tta_transforms()
    # NOTE(review): T.Resize changes the mask's pixel count when the input is
    # not already 224x224. Comparing against the un-resized mask sum may
    # therefore systematically reject the Resize-based variants — confirm intent.
    mask_size = torch.sum(mask).item()
    threshold = mask_size - tolerance
    outputs = [output_orig]
    for transform in transforms:
        for _ in range(max_tries):
            augmented_image, augmented_mask = _apply_paired(transform, img, mask)
            if torch.sum(augmented_mask.squeeze()).item() > threshold:
                # Inference only: avoid building an autograd graph (replaces `.data`).
                with torch.no_grad():
                    model_output = F.softmax(model(augmented_image), dim=-1)
                outputs.append(model_output)
                break
    return torch.stack(outputs, dim=1)