-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
307 lines (233 loc) · 9.99 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
# Roughly based on https://blog.roboflow.com/how-to-train-segformer-on-a-custom-dataset-with-pytorch-lightning/
import argparse
import os
from torch.utils.data import Dataset, DataLoader
from transformers import SegformerFeatureExtractor
import pytorch_lightning as pl
from transformers import SegformerForSemanticSegmentation
from datasets import load_metric
import torch
from torch import nn
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from PIL import Image
import numpy as np
# Allow TF32 / reduced-precision float32 matmuls for speed on GPUs that
# support them. torch versions without this API raise AttributeError, which is
# deliberately ignored; any other error is allowed to surface.
try:
    torch.set_float32_matmul_precision('high')
except AttributeError:
    pass

# Pretrained SegFormer checkpoint used for both the feature extractor and the
# model that is fine-tuned.
MODEL_ID = "nvidia/segformer-b0-finetuned-ade-512-512"
class SemanticSegmentationDataset(Dataset):
    """Image (semantic) segmentation dataset.

    ``root_dir`` must contain:

    * a ``_classes.csv`` file with one header row, then ``<id>,<label>`` rows;
    * the ``.jpg`` images;
    * the ``.png`` segmentation masks.

    Images and masks are paired by position after each list is sorted by file
    name, so both naming schemes must sort in the same order.

    Args:
        root_dir: Directory holding one split of the dataset.
        feature_extractor: Callable invoked as
            ``feature_extractor(image, segmentation_map, return_tensors="pt")``
            that returns a dict-like of tensors (e.g. a SegFormer feature
            extractor).
        use_cache: When True, keep each encoded sample in memory after its
            first load.

    Raises:
        ValueError: If the number of images and masks differ.
    """

    def __init__(self, root_dir, feature_extractor, use_cache=False):
        self.root_dir = root_dir
        self.feature_extractor = feature_extractor
        self.classes_csv_file = os.path.join(self.root_dir, "_classes.csv")
        with open(self.classes_csv_file, 'r', encoding='utf-8') as fid:
            next(fid, None)  # skip the header row
            # Ignore blank lines (e.g. a trailing empty line), which would
            # otherwise fail on the missing label column.
            rows = [line.split(',') for line in fid if line.strip()]
        # Strip surrounding whitespace so labels do not keep the leading space
        # or the trailing newline of each CSV line.
        self.id2label = {row[0].strip(): row[1].strip() for row in rows}

        file_names = os.listdir(self.root_dir)
        self.images = sorted(f for f in file_names if f.endswith('.jpg'))
        self.masks = sorted(f for f in file_names if f.endswith('.png'))
        if len(self.images) != len(self.masks):
            # Pairing is positional, so a count mismatch would silently pair
            # images with the wrong masks.
            raise ValueError(
                f"Found {len(self.images)} images but {len(self.masks)} masks in {self.root_dir}"
            )

        self.use_cache = use_cache
        self.cache = {}

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the encoded inputs for the image/mask pair at ``idx``."""
        if self.use_cache and idx in self.cache:
            return self.cache[idx]
        image_path = os.path.join(self.root_dir, self.images[idx])
        mask_path = os.path.join(self.root_dir, self.masks[idx])
        # Close both file handles once encoding is done instead of leaving
        # them open until garbage collection.
        with Image.open(image_path) as image, Image.open(mask_path) as segmentation_map:
            encoded_inputs = self.feature_extractor(image, segmentation_map, return_tensors="pt")
        # Remove size-1 dimensions in place (presumably the batch dimension of
        # the single encoded sample) so the DataLoader can batch samples.
        for value in encoded_inputs.values():
            value.squeeze_()
        if self.use_cache:
            self.cache[idx] = encoded_inputs
        return encoded_inputs
class SegformerFinetuner(pl.LightningModule):
    """PyTorch Lightning module that fine-tunes a pretrained SegFormer model.

    Args:
        id2label: Mapping from class id to class name. Its length sets the
            number of output classes of the segmentation head.
        train_dataloader: Loader returned by ``train_dataloader()``.
        val_dataloader: Loader returned by ``val_dataloader()``.
        test_dataloader: Loader returned by ``test_dataloader()``.
        metrics_interval: Training mean IoU / accuracy are computed and
            logged every ``metrics_interval`` batches.
    """

    def __init__(self, id2label, train_dataloader=None, val_dataloader=None,
                 test_dataloader=None, metrics_interval=100):
        super().__init__()
        self.id2label = id2label
        self.metrics_interval = metrics_interval
        self.train_dl = train_dataloader
        self.val_dl = val_dataloader
        self.test_dl = test_dataloader
        self.num_classes = len(id2label)
        self.label2id = {v: k for k, v in self.id2label.items()}
        self.model = SegformerForSemanticSegmentation.from_pretrained(
            MODEL_ID,
            return_dict=False,
            num_labels=self.num_classes,
            id2label=self.id2label,
            label2id=self.label2id,
            # The pretrained head has a different class count; allow those
            # weights to be re-initialised for this label set.
            ignore_mismatched_sizes=True,
        )
        self.train_mean_iou = load_metric("mean_iou", trust_remote_code=True)
        self.val_mean_iou = load_metric("mean_iou", trust_remote_code=True)
        self.test_mean_iou = load_metric("mean_iou", trust_remote_code=True)

    def forward(self, images, masks=None):
        """Run the SegFormer model.

        ``masks`` is optional so the module can be called with images only,
        as ``to_onnx(path, input_sample)`` does. With ``return_dict=False`` the
        model returns a tuple. When labels are given, the loss is expected to
        be its first element, followed by the logits.
        """
        outputs = self.model(pixel_values=images, labels=masks)
        return outputs

    def _shared_step(self, batch, metric):
        """Compute the loss for ``batch``, add its predictions to ``metric`` and return the loss."""
        images, masks = batch['pixel_values'], batch['labels']
        outputs = self(images, masks)
        loss, logits = outputs[0], outputs[1]
        # Resize logits to the mask resolution before the per-pixel argmax.
        upsampled_logits = nn.functional.interpolate(
            logits,
            size=masks.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )
        predicted = upsampled_logits.argmax(dim=1)
        metric.add_batch(
            predictions=predicted.detach().cpu().numpy(),
            references=masks.detach().cpu().numpy(),
        )
        return loss

    def _compute_metrics(self, metric):
        """Compute mean IoU / accuracy from the batches accumulated in ``metric``."""
        return metric.compute(
            num_labels=self.num_classes,
            ignore_index=255,
            reduce_labels=False,
        )

    def _log_epoch_end(self, outputs, metric, prefix):
        """Log and return ``<prefix>_loss``, ``<prefix>_mean_iou`` and ``<prefix>_mean_accuracy``."""
        scores = self._compute_metrics(metric)
        avg_loss = torch.stack([x[f"{prefix}_loss"] for x in outputs]).mean()
        metrics = {
            f"{prefix}_loss": avg_loss,
            f"{prefix}_mean_iou": scores["mean_iou"],
            f"{prefix}_mean_accuracy": scores["mean_accuracy"],
        }
        for k, v in metrics.items():
            self.log(k, v)
        return metrics

    def training_step(self, batch, batch_nb):
        """Return the training loss. Every ``metrics_interval`` batches, also log and return IoU metrics."""
        loss = self._shared_step(batch, self.train_mean_iou)
        if batch_nb % self.metrics_interval != 0:
            return {'loss': loss}
        scores = self._compute_metrics(self.train_mean_iou)
        metrics = {'loss': loss, "mean_iou": scores["mean_iou"], "mean_accuracy": scores["mean_accuracy"]}
        for k, v in metrics.items():
            self.log(k, v)
        return metrics

    def validation_step(self, batch, batch_nb):
        """Accumulate validation predictions and return the batch loss."""
        return {'val_loss': self._shared_step(batch, self.val_mean_iou)}

    def validation_epoch_end(self, outputs):
        """Log averaged validation loss and IoU metrics for the epoch."""
        return self._log_epoch_end(outputs, self.val_mean_iou, "val")

    def test_step(self, batch, batch_nb):
        """Accumulate test predictions and return the batch loss."""
        return {'test_loss': self._shared_step(batch, self.test_mean_iou)}

    def test_epoch_end(self, outputs):
        """Log averaged test loss and IoU metrics."""
        return self._log_epoch_end(outputs, self.test_mean_iou, "test")

    def configure_optimizers(self):
        """Use Adam over all trainable parameters."""
        return torch.optim.Adam([p for p in self.parameters() if p.requires_grad], lr=2e-05, eps=1e-08)

    def train_dataloader(self):
        return self.train_dl

    def val_dataloader(self):
        return self.val_dl

    def test_dataloader(self):
        return self.test_dl
def main():
    """Fine-tune SegFormer on a Roboflow segmentation export and save ``best.onnx``.

    The ``input`` directory must contain ``train``, ``valid`` and ``test``
    sub-directories in the layout read by ``SemanticSegmentationDataset``.
    Training stops early on ``val_loss``. The checkpoint with the lowest
    ``val_loss`` is restored before export.
    """
    parser = argparse.ArgumentParser(description="Train SegFormer models for use with GeoDeep")
    parser.add_argument(
        "input",
        type=str,
        help="Path to Roboflow's dataset directory with segmentation masks"
    )
    parser.add_argument(
        "--epochs", "-e",
        type=int,
        default=400,
        help="Max epochs"
    )
    parser.add_argument(
        "--batch-size", "-b",
        type=int,
        default=8,
        help="Batch size"
    )
    parser.add_argument(
        "--workers", "-w",
        type=int,
        default=1,
        help="Number of data loader workers"
    )
    parser.add_argument(
        "--in-memory",
        action="store_true",
        default=False,
        help="Load dataset in memory"
    )
    args = parser.parse_args()

    feature_extractor = SegformerFeatureExtractor.from_pretrained(MODEL_ID)
    feature_extractor.reduce_labels = False
    feature_extractor.size = 128

    train_dataset = SemanticSegmentationDataset(os.path.join(args.input, "train"), feature_extractor, use_cache=args.in_memory)
    val_dataset = SemanticSegmentationDataset(os.path.join(args.input, "valid"), feature_extractor, use_cache=args.in_memory)
    test_dataset = SemanticSegmentationDataset(os.path.join(args.input, "test"), feature_extractor, use_cache=args.in_memory)

    # prefetch_factor and persistent_workers only apply to worker processes.
    # DataLoader rejects them with ValueError when num_workers == 0.
    loader_kwargs = {"batch_size": args.batch_size, "num_workers": args.workers}
    if args.workers > 0:
        loader_kwargs.update(prefetch_factor=8, persistent_workers=True)
    train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_dataloader = DataLoader(val_dataset, **loader_kwargs)
    test_dataloader = DataLoader(test_dataset, **loader_kwargs)

    segformer_finetuner = SegformerFinetuner(
        train_dataset.id2label,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        test_dataloader=test_dataloader,
        metrics_interval=10,
    )

    early_stop_callback = EarlyStopping(
        monitor="val_loss",
        min_delta=0.00,
        patience=3,
        verbose=False,
        mode="min",
    )
    checkpoint_callback = ModelCheckpoint(save_top_k=1, monitor="val_loss")
    trainer = pl.Trainer(
        gpus=1,
        callbacks=[early_stop_callback, checkpoint_callback],
        max_epochs=args.epochs,
        val_check_interval=len(train_dataloader),
    )
    trainer.fit(segformer_finetuner)

    # After early stopping, the in-memory weights are from the last epoch,
    # not the best one. Reload the best checkpoint so "best.onnx" holds the
    # weights with the lowest val_loss.
    best_path = checkpoint_callback.best_model_path
    if best_path:
        checkpoint = torch.load(best_path, map_location="cpu")
        segformer_finetuner.load_state_dict(checkpoint["state_dict"])

    input_sample = torch.randn((1, 3, 512, 512))
    segformer_finetuner.to_onnx("best.onnx", input_sample, export_params=True)


if __name__ == "__main__":
    main()