-
Notifications
You must be signed in to change notification settings - Fork 92
/
Copy pathobject_detector.py
2418 lines (2195 loc) · 98 KB
/
object_detector.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import copy
import os
import os.path as osp
import functools
import numpy as np
import paddle
from paddle.static import InputSpec
import paddlers
import paddlers.models.ppdet as ppdet
from paddlers.models.ppdet.modeling.proposal_generator.target_layer import BBoxAssigner, MaskAssigner
from paddlers.transforms import decode_image, construct_sample
from paddlers.transforms.operators import _NormalizeBox, _PadBox, _BboxXYXY2XYWH, Resize, Pad
from paddlers.transforms.batch_operators import BatchCompose, _BatchPad, _Gt2YoloTarget, BatchPadRGT, BatchNormalizeImage
from paddlers.models.ppdet.optimizer import ModelEMA
import paddlers.utils.logging as logging
from paddlers.utils.checkpoint import det_pretrain_weights_dict
from .base import BaseModel
from .utils.det_metrics import VOCMetric, COCOMetric, RBoxMetric
# Public detector classes exported by `from <this module> import *`.
__all__ = [
    "YOLOv3", "FasterRCNN", "PPYOLO", "PPYOLOTiny", "PPYOLOv2", "MaskRCNN",
    "FCOSR", "PPYOLOE_R"
]
# TODO: Prune and decoupling
class BaseDetector(BaseModel):
    """Common base class for the object detectors in this module.

    Subclasses provide `supported_backbones`, `_default_batch_transforms()`
    and `_fix_transforms_shape()`. This class implements the shared training,
    evaluation, prediction and export-input logic.
    """

    # Sample fields kept for each annotation format / evaluation metric.
    data_fields = {
        'coco':
        {'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class', 'is_crowd'},
        'voc':
        {'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class', 'difficult'},
        'rbox':
        {'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class', 'gt_poly'},
    }
    # Backbone names accepted by `_check_backbone()`; set by subclasses.
    supported_backbones = None

    def __init__(self, model_name, num_classes=80, **params):
        # `self.init_params` is expected to be created by the subclass
        # constructor beforehand (e.g. `self.init_params = locals()`).
        self.init_params.update(locals())
        if 'with_net' in self.init_params:
            del self.init_params['with_net']
        # NOTE: keep the explicit `super(...)` form; the zero-argument form
        # would add `__class__` to `locals()` above.
        super(BaseDetector, self).__init__('detector')
        if not hasattr(ppdet.modeling, model_name):
            raise ValueError("ERROR: There is no model named {}.".format(
                model_name))
        self.model_name = model_name
        self.num_classes = num_classes
        self.labels = None
        if params.get('with_net', True):
            params.pop('with_net', None)
            self.net = self.build_net(**params)

    def build_net(self, **params):
        """Instantiate `ppdet.modeling.<model_name>` under a fresh name scope."""
        with paddle.utils.unique_name.guard():
            net = ppdet.modeling.__dict__[self.model_name](**params)
        return net

    @classmethod
    def set_data_fields(cls, data_name, data_fields):
        """Register (or override) the sample fields used for `data_name`."""
        cls.data_fields[data_name] = data_fields

    def _is_backbone_weight(self):
        """Return True if 'IMAGENET' weights for this backbone are backbone-only."""
        # Must match the spelling of the backbone names (e.g. 'ESNet_m').
        target_backbone = ['ESNet_', 'CSPResNet_']
        return any(b in self.backbone_name for b in target_backbone)

    def _build_inference_net(self):
        """Return the network switched to evaluation mode."""
        infer_net = self.net
        infer_net.eval()
        return infer_net

    def _fix_transforms_shape(self, image_shape):
        raise NotImplementedError("_fix_transforms_shape: not implemented!")

    def _define_input_spec(self, image_shape):
        """Build the `InputSpec`s used when exporting the model.

        Args:
            image_shape (list): [N, C, H, W] shape of the image input.
        """
        input_spec = [{
            "image": InputSpec(
                shape=image_shape, name='image', dtype='float32'),
            "im_shape": InputSpec(
                shape=[image_shape[0], 2], name='im_shape', dtype='float32'),
            "scale_factor": InputSpec(
                shape=[image_shape[0], 2], name='scale_factor', dtype='float32')
        }]
        return input_spec

    def _check_backbone(self, backbone):
        """Raise ValueError if `backbone` is not in `supported_backbones`."""
        if backbone not in self.supported_backbones:
            raise ValueError(
                "backbone: {} is not supported. Please choose one of "
                "{}.".format(backbone, self.supported_backbones))

    def _check_image_shape(self, image_shape):
        """Validate `image_shape`, expanding [H, W] to [1, 3, H, W].

        Raises:
            ValueError: If the height or width is not a multiple of 32.
        """
        if len(image_shape) == 2:
            # `list()` lets callers pass a tuple such as (H, W) as well.
            image_shape = [1, 3] + list(image_shape)
        if image_shape[-2] % 32 > 0 or image_shape[-1] % 32 > 0:
            raise ValueError(
                "Height and width in fixed_input_shape must be a multiple of 32, but received {}.".
                format(image_shape[-2:]))
        return image_shape

    def _get_test_inputs(self, image_shape):
        """Return export input specs; a None shape means dynamic H and W."""
        if image_shape is not None:
            image_shape = self._check_image_shape(image_shape)
            self._fix_transforms_shape(image_shape[-2:])
        else:
            image_shape = [None, 3, -1, -1]
        self.fixed_input_shape = image_shape
        return self._define_input_spec(image_shape)

    def _get_backbone(self, backbone_name, **params):
        """Instantiate a backbone from `ppdet.modeling` by name.

        Names of the form 'ResNet<depth>[_<x><variant>]' (e.g. 'ResNet50_vd')
        are mapped to `ResNet(depth=<depth>, variant=<variant>)`.
        """
        # parse ResNetxx_xx
        if backbone_name.startswith('ResNet'):
            name = backbone_name.split('_')
            fixed_kwargs = {}
            depth = name[0]
            fixed_kwargs['depth'] = int(depth[6:])
            if len(name) > 1:
                fixed_kwargs['variant'] = name[1][1]
            backbone = getattr(ppdet.modeling, 'ResNet')
            backbone = functools.partial(backbone, **fixed_kwargs)
        else:
            backbone = getattr(ppdet.modeling, backbone_name)
        backbone_module = backbone(**params)
        return backbone_module

    def run(self, net, inputs, mode):
        """Run `net` on `inputs`.

        In 'train'/'eval' mode the raw network output is returned; otherwise
        every output value is converted with `.numpy()`.
        """
        net_out = net(inputs)
        if mode in ['train', 'eval']:
            outputs = net_out
        else:
            outputs = dict()
            for key in net_out:
                outputs[key] = net_out[key].numpy()
        return outputs

    def default_optimizer(self,
                          parameters,
                          learning_rate,
                          warmup_steps,
                          warmup_start_lr,
                          lr_decay_epochs,
                          lr_decay_gamma,
                          num_steps_each_epoch,
                          reg_coeff=1e-04,
                          scheduler='Piecewise',
                          cosine_decay_num_epochs=1000,
                          clip_grad_by_norm=None,
                          num_epochs=None):
        """Build the default Momentum optimizer with an LR schedule.

        `scheduler` is 'Piecewise' or 'Cosine' (case-insensitive); a linear
        warm-up is prepended when `warmup_steps > 0`, and global-norm gradient
        clipping is enabled when `clip_grad_by_norm` is set.
        """
        if scheduler.lower() == 'piecewise':
            if warmup_steps > 0 and warmup_steps > lr_decay_epochs[
                    0] * num_steps_each_epoch:
                logging.error(
                    "In function train(), parameters must satisfy: "
                    "warmup_steps <= lr_decay_epochs[0] * num_samples_in_train_dataset. "
                    "See this doc for more information: "
                    "https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/parameters.md",
                    exit=False)
                logging.error(
                    "Either `warmup_steps` be less than {} or lr_decay_epochs[0] be greater than {} "
                    "must be satisfied, please modify 'warmup_steps' or 'lr_decay_epochs' in train function".
                    format(lr_decay_epochs[0] * num_steps_each_epoch,
                           warmup_steps // num_steps_each_epoch),
                    exit=True)
            boundaries = [b * num_steps_each_epoch for b in lr_decay_epochs]
            values = [(lr_decay_gamma**i) * learning_rate
                      for i in range(len(lr_decay_epochs) + 1)]
            scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries, values)
        elif scheduler.lower() == 'cosine':
            if num_epochs is None:
                # Fatal: `num_epochs` is used in arithmetic right below.
                logging.error(
                    "`num_epochs` must be set while using cosine annealing decay scheduler, but received {}".
                    format(num_epochs),
                    exit=True)
            if warmup_steps > 0 and warmup_steps > num_epochs * num_steps_each_epoch:
                logging.error(
                    "In function train(), parameters must satisfy: "
                    "warmup_steps <= num_epochs * num_samples_in_train_dataset. "
                    "See this doc for more information: "
                    "https://github.com/PaddlePaddle/PaddleRS/blob/develop/docs/parameters.md",
                    exit=False)
                logging.error(
                    "`warmup_steps` must be less than the total number of steps({}), "
                    "please modify 'num_epochs' or 'warmup_steps' in train function".
                    format(num_epochs * num_steps_each_epoch),
                    exit=True)
            T_max = cosine_decay_num_epochs * num_steps_each_epoch - warmup_steps
            scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
                learning_rate=learning_rate,
                T_max=T_max,
                eta_min=0.0,
                last_epoch=-1)
        else:
            logging.error(
                "Invalid learning rate scheduler: {}!".format(scheduler),
                exit=True)
        if warmup_steps > 0:
            scheduler = paddle.optimizer.lr.LinearWarmup(
                learning_rate=scheduler,
                warmup_steps=warmup_steps,
                start_lr=warmup_start_lr,
                end_lr=learning_rate)
        if clip_grad_by_norm is not None:
            grad_clip = paddle.nn.ClipGradByGlobalNorm(
                clip_norm=clip_grad_by_norm)
        else:
            grad_clip = None
        optimizer = paddle.optimizer.Momentum(
            scheduler,
            momentum=.9,
            weight_decay=paddle.regularizer.L2Decay(coeff=reg_coeff),
            parameters=parameters,
            grad_clip=grad_clip)
        return optimizer

    def train(self,
              num_epochs,
              train_dataset,
              train_batch_size=64,
              eval_dataset=None,
              optimizer=None,
              save_interval_epochs=1,
              log_interval_steps=10,
              save_dir='output',
              pretrain_weights='IMAGENET',
              learning_rate=.001,
              warmup_steps=0,
              warmup_start_lr=0.0,
              scheduler='Piecewise',
              lr_decay_epochs=(216, 243),
              lr_decay_gamma=0.1,
              cosine_decay_num_epochs=1000,
              metric=None,
              use_ema=False,
              early_stop=False,
              early_stop_patience=5,
              use_vdl=True,
              clip_grad_by_norm=None,
              reg_coeff=1e-4,
              resume_checkpoint=None,
              precision='fp32',
              amp_level='O1',
              custom_white_list=None,
              custom_black_list=None):
        """
        Train the model.

        Args:
            num_epochs (int): Number of epochs.
            train_dataset (paddlers.datasets.COCODetDataset|paddlers.datasets.VOCDetDataset):
                Training dataset.
            train_batch_size (int, optional): Total batch size among all cards used in
                training. Defaults to 64.
            eval_dataset (paddlers.datasets.COCODetDataset|paddlers.datasets.VOCDetDataset|None, optional):
                Evaluation dataset. If None, the model will not be evaluated during training
                process. Defaults to None.
            optimizer (paddle.optimizer.Optimizer|None, optional): Optimizer used for
                training. If None, a default optimizer will be used. Defaults to None.
            save_interval_epochs (int, optional): Epoch interval for saving the model.
                Defaults to 1.
            log_interval_steps (int, optional): Step interval for printing training
                information. Defaults to 10.
            save_dir (str, optional): Directory to save the model. Defaults to 'output'.
            pretrain_weights (str|None, optional): None or name/path of pretrained
                weights. If None, no pretrained weights will be loaded.
                Defaults to 'IMAGENET'.
            learning_rate (float, optional): Learning rate for training. Defaults to .001.
            warmup_steps (int, optional): Number of steps of warm-up training.
                Defaults to 0.
            warmup_start_lr (float, optional): Start learning rate of warm-up training.
                Defaults to 0..
            scheduler (str, optional): Learning rate scheduler of the default
                optimizer. Choices are {'Piecewise', 'Cosine'} (case-insensitive).
                Ignored if `optimizer` is given. Defaults to 'Piecewise'.
            lr_decay_epochs (list|tuple, optional): Epoch milestones for learning
                rate decay. Defaults to (216, 243).
            lr_decay_gamma (float, optional): Gamma coefficient of learning rate decay.
                Defaults to .1.
            cosine_decay_num_epochs (int, optional): Parameter to determine the annealing
                cycle when a cosine annealing learning rate scheduler is used.
                Defaults to 1000.
            metric (str|None, optional): Evaluation metric. Choices are
                {'VOC', 'COCO', 'RBOX', None}. If None, determine the metric according to
                the dataset format. Defaults to None.
            use_ema (bool, optional): Whether to use exponential moving average
                strategy. Defaults to False.
            early_stop (bool, optional): Whether to adopt early stop strategy.
                Defaults to False.
            early_stop_patience (int, optional): Early stop patience. Defaults to 5.
            use_vdl (bool, optional): Whether to use VisualDL to monitor the training
                process. Defaults to True.
            clip_grad_by_norm (float, optional): Maximum global norm for gradient clipping.
                Default to None.
            reg_coeff (float, optional): Coefficient for L2 weight decay regularization.
                Default to 1e-4.
            resume_checkpoint (str|None, optional): Path of the checkpoint to resume
                training from. If None, no training checkpoint will be resumed. At most
                one of `resume_checkpoint` and `pretrain_weights` can be set simultaneously.
                Defaults to None.
            precision (str, optional): Use AMP (auto mixed precision) training if `precision`
                is set to 'fp16'. Defaults to 'fp32'. Only 'fp32' is accepted by
                this base implementation.
            amp_level (str, optional): Auto mixed precision level. Accepted values are 'O1'
                and 'O2': At O1 level, the input data type of each operator will be casted
                according to a white list and a black list. At O2 level, all parameters and
                input data will be casted to FP16, except those for the operators in the black
                list, those without the support for FP16 kernel, and those for the batchnorm
                layers. Defaults to 'O1'.
            custom_white_list(set|list|tuple|None, optional): Custom white list to use when
                `amp_level` is set to 'O1'. Defaults to None.
            custom_black_list(set|list|tuple|None, optional): Custom black list to use in AMP
                training. Defaults to None.

        Raises:
            ValueError: If `precision` is not 'fp32'.
        """
        if precision != 'fp32':
            raise ValueError("Currently, {} does not support AMP training.".
                             format(self.__class__.__name__))
        # NOTE: `locals()` forwards every argument above to `_real_train`;
        # do not bind extra local names before this line.
        args = self._pre_train(locals())
        args.pop('self')
        return self._real_train(**args)

    def _pre_train(self, in_args):
        """Hook for subclasses to adjust training arguments; identity here."""
        return in_args

    @staticmethod
    def _infer_metric(dataset):
        """Infer the metric name from the dataset class name.

        Returns 'voc' for `VOCDetDataset`, 'coco' for `COCODetDataset` and
        None for anything else (including `dataset=None`).
        """
        name = dataset.__class__.__name__
        if name == 'VOCDetDataset':
            return 'voc'
        if name == 'COCODetDataset':
            return 'coco'
        return None

    def _real_train(
            self, num_epochs, train_dataset, train_batch_size, eval_dataset,
            optimizer, save_interval_epochs, log_interval_steps, save_dir,
            pretrain_weights, learning_rate, warmup_steps, warmup_start_lr,
            lr_decay_epochs, lr_decay_gamma, metric, use_ema, early_stop,
            early_stop_patience, use_vdl, resume_checkpoint, scheduler,
            cosine_decay_num_epochs, clip_grad_by_norm, reg_coeff, precision,
            amp_level, custom_white_list, custom_black_list):
        """Set up data, optimizer and weights, then run `self.train_loop`."""
        self.precision = precision
        self.amp_level = amp_level
        self.custom_white_list = custom_white_list
        self.custom_black_list = custom_black_list
        if self.status == 'Infer':
            logging.error(
                "Exported inference model does not support training.",
                exit=True)
        if pretrain_weights is not None and resume_checkpoint is not None:
            logging.error(
                "`pretrain_weights` and `resume_checkpoint` cannot be set simultaneously.",
                exit=True)
        if metric is None:
            inferred = self._infer_metric(eval_dataset)
            # Without a usable `eval_dataset` (e.g. None), fall back to the
            # training dataset unless a metric has already been set.
            if inferred is None and getattr(self, 'metric', None) is None:
                inferred = self._infer_metric(train_dataset)
            if inferred is not None:
                self.metric = inferred
        else:
            self.metric = metric.lower()
        if getattr(self, 'metric', None) is None:
            logging.error(
                "`metric` cannot be inferred from the dataset type. "
                "Please specify `metric` explicitly.",
                exit=True)
        assert self.metric in ['coco', 'voc', 'rbox'], \
            "Evaluation metric {} is not supported. Please choose from 'COCO', 'VOC' and 'RBOX'".format(
                self.metric)
        train_dataset.data_fields = self.data_fields[self.metric]
        self.labels = train_dataset.labels
        self.num_max_boxes = train_dataset.num_max_boxes
        train_batch_transforms = self._compose_batch_transforms(
            'train', train_dataset.batch_transforms)
        train_dataset.build_collate_fn(train_batch_transforms,
                                       self._default_collate_fn)

        # Build optimizer if not defined
        if optimizer is None:
            num_steps_each_epoch = len(train_dataset) // train_batch_size
            self.optimizer = self.default_optimizer(
                scheduler=scheduler,
                parameters=self.net.parameters(),
                learning_rate=learning_rate,
                warmup_steps=warmup_steps,
                warmup_start_lr=warmup_start_lr,
                lr_decay_epochs=lr_decay_epochs,
                lr_decay_gamma=lr_decay_gamma,
                num_steps_each_epoch=num_steps_each_epoch,
                num_epochs=num_epochs,
                clip_grad_by_norm=clip_grad_by_norm,
                cosine_decay_num_epochs=cosine_decay_num_epochs,
                reg_coeff=reg_coeff, )
        else:
            self.optimizer = optimizer

        # Initiate weights
        if pretrain_weights is not None:
            if not osp.exists(pretrain_weights):
                # Not a local file: treat it as the name of a registered
                # pretrained-weight set for this model/backbone pair.
                key = '_'.join([self.model_name, self.backbone_name])
                if key not in det_pretrain_weights_dict:
                    logging.warning(
                        "Path of pretrained weights ('{}') does not exist!".
                        format(pretrain_weights))
                    pretrain_weights = None
                elif pretrain_weights not in det_pretrain_weights_dict[key]:
                    logging.warning(
                        "Path of pretrained weights ('{}') does not exist!".
                        format(pretrain_weights))
                    pretrain_weights = det_pretrain_weights_dict[key][0]
                    logging.warning(
                        "`pretrain_weights` is forcibly set to '{}'. "
                        "If you don't want to use pretrained weights, "
                        "please set `pretrain_weights` to None.".format(
                            pretrain_weights))
            else:
                if osp.splitext(pretrain_weights)[-1] != '.pdparams':
                    logging.error(
                        "Invalid pretrained weights. Please specify a .pdparams file.",
                        exit=True)
        pretrained_dir = osp.join(save_dir, 'pretrain')
        self.initialize_net(
            pretrain_weights=pretrain_weights,
            save_dir=pretrained_dir,
            resume_checkpoint=resume_checkpoint,
            is_backbone_weights=pretrain_weights == 'IMAGENET' and
            self._is_backbone_weight())

        if use_ema:
            ema = ModelEMA(model=self.net, decay=.9998)
        else:
            ema = None
        # Start train loop
        self.train_loop(
            num_epochs=num_epochs,
            train_dataset=train_dataset,
            train_batch_size=train_batch_size,
            eval_dataset=eval_dataset,
            save_interval_epochs=save_interval_epochs,
            log_interval_steps=log_interval_steps,
            save_dir=save_dir,
            ema=ema,
            early_stop=early_stop,
            early_stop_patience=early_stop_patience,
            use_vdl=use_vdl)

    def _default_collate_fn(self, dataset):
        """Return a collate function that applies `dataset.batch_transforms`."""

        def _collate_fn(batch):
            # We drop `trans_info` as it is not required in detection tasks
            samples = [s[0] for s in batch]
            return dataset.batch_transforms(samples)

        return _collate_fn

    def _default_batch_transforms(self, mode):
        raise NotImplementedError

    def _filter_batch_transforms(self, defaults, targets):
        """Drop default transforms whose exact type also appears in `targets`."""
        # TODO: Warning message
        if targets is None:
            return defaults
        target_types = [type(i) for i in targets]
        filtered = [i for i in defaults if type(i) not in target_types]
        return filtered

    def _compose_batch_transforms(self, mode, batch_transforms):
        """Merge user batch transforms with the defaults for `mode`.

        User transforms come first, followed by the defaults not overridden
        by a user transform of the same type; `collate_batch` is taken from
        the defaults.
        """
        defaults = self._default_batch_transforms(mode)
        out = []
        if isinstance(batch_transforms, BatchCompose):
            batch_transforms = batch_transforms.batch_transforms
        if batch_transforms is not None:
            out.extend(batch_transforms)
        filtered = self._filter_batch_transforms(defaults.batch_transforms,
                                                 batch_transforms)
        out.extend(filtered)

        return BatchCompose(out, collate_batch=defaults.collate_batch)

    def quant_aware_train(self,
                          num_epochs,
                          train_dataset,
                          train_batch_size=64,
                          eval_dataset=None,
                          optimizer=None,
                          save_interval_epochs=1,
                          log_interval_steps=10,
                          save_dir='output',
                          learning_rate=.00001,
                          warmup_steps=0,
                          warmup_start_lr=0.0,
                          lr_decay_epochs=(216, 243),
                          lr_decay_gamma=0.1,
                          metric=None,
                          use_ema=False,
                          early_stop=False,
                          early_stop_patience=5,
                          use_vdl=True,
                          resume_checkpoint=None,
                          quant_config=None):
        """
        Quantization-aware training.

        Args:
            num_epochs (int): Number of epochs.
            train_dataset (paddlers.datasets.COCODetDataset|paddlers.datasets.VOCDetDataset):
                Training dataset.
            train_batch_size (int, optional): Total batch size among all cards used in
                training. Defaults to 64.
            eval_dataset (paddlers.datasets.COCODetDataset|paddlers.datasets.VOCDetDataset|None, optional):
                Evaluation dataset. If None, the model will not be evaluated during training
                process. Defaults to None.
            optimizer (paddle.optimizer.Optimizer or None, optional): Optimizer used for
                training. If None, a default optimizer will be used. Defaults to None.
            save_interval_epochs (int, optional): Epoch interval for saving the model.
                Defaults to 1.
            log_interval_steps (int, optional): Step interval for printing training
                information. Defaults to 10.
            save_dir (str, optional): Directory to save the model. Defaults to 'output'.
            learning_rate (float, optional): Learning rate for training.
                Defaults to .00001.
            warmup_steps (int, optional): Number of steps of warm-up training.
                Defaults to 0.
            warmup_start_lr (float, optional): Start learning rate of warm-up training.
                Defaults to 0..
            lr_decay_epochs (list or tuple, optional): Epoch milestones for learning rate
                decay. Defaults to (216, 243).
            lr_decay_gamma (float, optional): Gamma coefficient of learning rate decay.
                Defaults to .1.
            metric (str|None, optional): Evaluation metric. Choices are {'VOC', 'COCO', 'RBOX', None}.
                If None, determine the metric according to the dataset format.
                Defaults to None.
            use_ema (bool, optional): Whether to use exponential moving average strategy.
                Defaults to False.
            early_stop (bool, optional): Whether to adopt early stop strategy.
                Defaults to False.
            early_stop_patience (int, optional): Early stop patience. Defaults to 5.
            use_vdl (bool, optional): Whether to use VisualDL to monitor the training
                process. Defaults to True.
            quant_config (dict or None, optional): Quantization configuration. If None,
                a default rule of thumb configuration will be used. Defaults to None.
            resume_checkpoint (str|None, optional): Path of the checkpoint to resume
                quantization-aware training from. If None, no training checkpoint will
                be resumed. Defaults to None.
        """
        self._prepare_qat(quant_config)
        # Pretrained weights are never loaded here; the current weights of
        # `self.net` are the starting point.
        self.train(
            num_epochs=num_epochs,
            train_dataset=train_dataset,
            train_batch_size=train_batch_size,
            eval_dataset=eval_dataset,
            optimizer=optimizer,
            save_interval_epochs=save_interval_epochs,
            log_interval_steps=log_interval_steps,
            save_dir=save_dir,
            pretrain_weights=None,
            learning_rate=learning_rate,
            warmup_steps=warmup_steps,
            warmup_start_lr=warmup_start_lr,
            lr_decay_epochs=lr_decay_epochs,
            lr_decay_gamma=lr_decay_gamma,
            metric=metric,
            use_ema=use_ema,
            early_stop=early_stop,
            early_stop_patience=early_stop_patience,
            use_vdl=use_vdl,
            resume_checkpoint=resume_checkpoint)

    def evaluate(self,
                 eval_dataset,
                 batch_size=1,
                 metric=None,
                 return_details=False):
        """
        Evaluate the model.

        Args:
            eval_dataset (paddlers.datasets.COCODetDataset|paddlers.datasets.VOCDetDataset):
                Evaluation dataset.
            batch_size (int, optional): Total batch size among all cards used for
                evaluation. Values greater than 1 are forcibly reset to 1.
                Defaults to 1.
            metric (str|None, optional): Evaluation metric. Choices are
                {'VOC', 'COCO', 'RBOX', None}. If None, reuse the metric set
                previously, or determine it according to the dataset format.
                Defaults to None.
            return_details (bool, optional): Whether to return evaluation details.
                Defaults to False.

        Returns:
            If `return_details` is False, return collections.OrderedDict with key-value pairs:
                {"bbox_mmap": mean average precision (0.50, 11point)}.
            If `return_details` is True, return a tuple of the above dict and
                the evaluation details.
            On non-zero ranks of a multi-card run, nothing is evaluated and
                None is returned.
        """
        if metric is None:
            if not hasattr(self, 'metric'):
                inferred = self._infer_metric(eval_dataset)
                if inferred is not None:
                    self.metric = inferred
        else:
            self.metric = metric.lower()
        if getattr(self, 'metric', None) is None:
            logging.error(
                "`metric` cannot be inferred from the dataset type. "
                "Please specify `metric` explicitly.",
                exit=True)
        assert self.metric.lower() in ['coco', 'voc', 'rbox'], \
            "Evaluation metric {} is not supported. Please choose from 'COCO', 'VOC' and 'RBOX'.".format(
                self.metric)

        eval_dataset.data_fields = self.data_fields[self.metric]

        eval_batch_transforms = self._compose_batch_transforms(
            'eval', eval_dataset.batch_transforms)
        eval_dataset.build_collate_fn(eval_batch_transforms,
                                      self._default_collate_fn)

        self._check_transforms(eval_dataset.transforms)

        self.net.eval()
        nranks = paddle.distributed.get_world_size()
        local_rank = paddle.distributed.get_rank()
        if nranks > 1:
            # Initialize parallel environment if not done.
            if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
            ):
                paddle.distributed.init_parallel_env()

        if batch_size > 1:
            logging.warning(
                "Detector only supports single card evaluation with batch_size=1 "
                "during evaluation, so batch_size is forcibly set to 1.")
            batch_size = 1

        if nranks < 2 or local_rank == 0:
            self.eval_data_loader = self.build_data_loader(
                eval_dataset,
                batch_size=batch_size,
                mode='eval',
                collate_fn=eval_dataset.collate_fn)
            is_bbox_normalized = False
            batch_transforms = getattr(eval_dataset, 'batch_transforms', None)
            if batch_transforms is not None:
                is_bbox_normalized = any(
                    isinstance(t, _NormalizeBox)
                    for t in batch_transforms.batch_transforms)
            if self.metric == 'voc':
                eval_metric = VOCMetric(
                    labels=eval_dataset.labels,
                    coco_gt=copy.deepcopy(eval_dataset.coco_gt),
                    is_bbox_normalized=is_bbox_normalized,
                    classwise=False)
            elif self.metric == 'coco':
                eval_metric = COCOMetric(
                    coco_gt=copy.deepcopy(eval_dataset.coco_gt),
                    classwise=False)
            else:
                assert hasattr(eval_dataset, 'get_anno_path')
                eval_metric = RBoxMetric(
                    anno_file=eval_dataset.get_anno_path(), classwise=False)
            scores = collections.OrderedDict()
            # batch_size is 1 here, so the number of steps equals the number
            # of samples.
            logging.info(
                "Start to evaluate (total_samples={}, total_steps={})...".
                format(eval_dataset.num_samples, eval_dataset.num_samples))
            with paddle.no_grad():
                for data in self.eval_data_loader:
                    if self.precision == 'fp16':
                        with paddle.amp.auto_cast(
                                level=self.amp_level,
                                enable=True,
                                custom_white_list=self.custom_white_list,
                                custom_black_list=self.custom_black_list):
                            outputs = self.run(self.net, data, 'eval')
                    else:
                        outputs = self.run(self.net, data, 'eval')
                    eval_metric.update(data, outputs)
                eval_metric.accumulate()
                self.eval_details = eval_metric.details
                scores.update(eval_metric.get())
                eval_metric.reset()

            if return_details:
                return scores, self.eval_details
            return scores

    @paddle.no_grad()
    def predict(self, img_file, transforms=None):
        """
        Do inference.

        Args:
            img_file (list[np.ndarray|str] | str | np.ndarray): Image path or decoded
                image data, which also could constitute a list, meaning all images to be
                predicted as a mini-batch.
            transforms (paddlers.transforms.Compose|None, optional): Transforms for
                inputs. If None, the transforms for evaluation process will be used.
                Defaults to None.

        Returns:
            If `img_file` is a string or np.array, the result is a list of dict with
                the following key-value pairs:
                category_id (int): Predicted category ID. 0 represents the first
                    category in the dataset, and so on.
                category (str): Category name.
                bbox (list): Bounding box in [x, y, w, h] format (for rotated
                    boxes, the 8 polygon coordinates instead).
                score (float): Confidence.
                mask (np.ndarray): Only for instance segmentation task. Mask of
                    the object (uint8).
            If `img_file` is a list, the result is a list composed of list of dicts
                with the above keys.

        Raises:
            ValueError: If `transforms` is None and no test transforms are set.
        """
        if transforms is None and not hasattr(self, 'test_transforms'):
            raise ValueError("transforms need to be defined, now is None.")
        if transforms is None:
            transforms = self.test_transforms
        if isinstance(img_file, (str, np.ndarray)):
            images = [img_file]
        else:
            images = img_file

        batch_samples, _ = self.preprocess(images, transforms)
        self.net.eval()
        outputs = self.run(self.net, batch_samples, 'test')
        prediction = self.postprocess(outputs)

        if isinstance(img_file, (str, np.ndarray)):
            prediction = prediction[0]
        return prediction

    def preprocess(self, images, transforms, to_tensor=True):
        """Decode and transform `images` into one batch.

        String entries are decoded from disk. Returns `(batch_samples, None)`;
        values are converted to paddle tensors when `to_tensor` is True.
        """
        self._check_transforms(transforms)
        batch_samples = list()
        for im in images:
            if isinstance(im, str):
                im = decode_image(im, read_raw=True)
            sample = construct_sample(image=im)
            data = transforms(sample)
            batch_samples.append(data[0])
        batch_transforms = self._default_batch_transforms('test')
        batch_samples = batch_transforms(batch_samples)
        if to_tensor:
            for k in batch_samples:
                batch_samples[k] = paddle.to_tensor(batch_samples[k])

        return batch_samples, None

    def postprocess(self, batch_pred):
        """Split raw network outputs into per-image lists of result dicts.

        Args:
            batch_pred (dict): Output of `self.run()` in test mode. 'bbox' holds
                one row per detection, either
                [class_id, score, xmin, ymin, xmax, ymax] or
                [class_id, score, <8 polygon coordinates>]; 'bbox_num' holds
                the number of rows belonging to each image, in order. An
                optional 'mask' entry holds one mask per 'bbox' row.

        Returns:
            list[list[dict]]: One list per image. Each dict has 'category_id',
                'category', 'bbox' ([x, y, w, h] or the polygon) and 'score',
                plus 'mask' (uint8) when masks are present. Rows with a
                negative class id are skipped.
        """
        bboxes = batch_pred['bbox']
        bbox_nums = batch_pred['bbox_num']
        masks = batch_pred['mask'] if 'mask' in batch_pred else None

        results = []
        k = 0
        for num in bbox_nums:
            # Build each image's list directly so that skipped rows cannot
            # shift detections (or masks) into a neighbouring image.
            curr_res = []
            for _ in range(num):
                row = k
                k += 1
                dt = bboxes[row].tolist()
                if len(dt) == 6:
                    # Generic object detection
                    num_id, score, xmin, ymin, xmax, ymax = dt
                    bbox = [xmin, ymin, xmax - xmin, ymax - ymin]
                elif len(dt) == 10:
                    # Rotated object detection
                    num_id, score, *pts = dt
                    bbox = list(pts)
                else:
                    raise AssertionError(
                        "Unexpected length of detection row: {}.".format(
                            len(dt)))
                label = int(num_id)
                if label < 0:
                    continue
                dt_res = {
                    'category_id': label,
                    'category': self.labels[label],
                    'bbox': bbox,
                    'score': score
                }
                if masks is not None:
                    dt_res['mask'] = masks[row].astype(np.uint8)
                curr_res.append(dt_res)
            results.append(curr_res)
        return results

    def get_pruning_info(self):
        """Return the base pruning info with `pruner_inputs` as plain lists."""
        info = super().get_pruning_info()
        info['pruner_inputs'] = {
            k: v.tolist()
            for k, v in info['pruner_inputs'][0].items()
        }
        return info
class PicoDet(BaseDetector):
supported_backbones = ('ESNet_s', 'ESNet_m', 'ESNet_l', 'LCNet',
'MobileNetV3', 'ResNet18_vd')
def __init__(self,
             num_classes=80,
             backbone='ESNet_m',
             nms_score_threshold=.025,
             nms_topk=1000,
             nms_keep_topk=100,
             nms_iou_threshold=.6,
             **params):
    """Create a PicoDet detector.

    Args:
        num_classes (int, optional): Number of target classes. Defaults to 80.
        backbone (str, optional): Backbone name; one of
            `PicoDet.supported_backbones`. Defaults to 'ESNet_m'.
        nms_score_threshold (float, optional): `score_threshold` of
            `MultiClassNMS`. Defaults to .025.
        nms_topk (int, optional): `nms_top_k` of `MultiClassNMS`.
            Defaults to 1000.
        nms_keep_topk (int, optional): `keep_top_k` of `MultiClassNMS`.
            Defaults to 100.
        nms_iou_threshold (float, optional): `nms_threshold` of
            `MultiClassNMS`. Defaults to .6.
        **params: Forwarded to `BaseDetector.__init__`. When
            `params['with_net']` is False, no network parts are built here.
    """
    # Must stay the first statement: it snapshots the constructor arguments.
    self.init_params = locals()
    self._check_backbone(backbone)
    self.backbone_name = backbone

    if params.get('with_net', True):
        esnet_ratios_s = [
            0.875, 0.5, 0.5, 0.5, 0.625, 0.5, 0.625, 0.5, 0.5, 0.5, 0.5, 0.5,
            0.5
        ]
        esnet_ratios_ml = [
            0.875, 0.5, 1.0, 0.625, 0.5, 0.75, 0.625, 0.625, 0.5, 0.625, 1.0,
            0.625, 0.75
        ]
        # name -> (scale, channel_ratio, neck_out_channels, head_num_convs)
        esnet_settings = {
            'ESNet_s': (.75, esnet_ratios_s, 96, 2),
            'ESNet_m': (1.0, esnet_ratios_ml, 128, 4),
            'ESNet_l': (1.25, esnet_ratios_ml, 160, 4),
        }

        if backbone in esnet_settings:
            scale, ratios, neck_out_channels, head_num_convs = \
                esnet_settings[backbone]
            backbone_net = self._get_backbone(
                'ESNet',
                scale=scale,
                feature_maps=[4, 11, 14],
                act="hard_swish",
                channel_ratio=ratios)
        else:
            # Every non-ESNet backbone uses the same neck/head sizes.
            neck_out_channels, head_num_convs = 128, 4
            if backbone == 'LCNet':
                backbone_kwargs = {'scale': 1.5}
            elif backbone == 'MobileNetV3':
                backbone_kwargs = {
                    'extra_block_filters': [],
                    'feature_maps': [7, 13, 16]
                }
            else:
                backbone_kwargs = {
                    'return_idx': [1, 2, 3],
                    'freeze_at': -1,
                    'freeze_norm': False,
                    'norm_decay': 0.
                }
            backbone_net = self._get_backbone(backbone, **backbone_kwargs)

        neck = ppdet.modeling.CSPPAN(
            in_channels=[shape.channels for shape in backbone_net.out_shape],
            out_channels=neck_out_channels,
            num_features=4,
            num_csp_blocks=1,
            use_depthwise=True)

        conv_feat = ppdet.modeling.PicoFeat(
            feat_in=neck_out_channels,
            feat_out=neck_out_channels,
            num_fpn_stride=4,
            num_convs=head_num_convs,
            norm_type='bn',
            share_cls_reg=True, )
        varifocal_loss = ppdet.modeling.VarifocalLoss(
            use_sigmoid=True, iou_weighted=True, loss_weight=1.0)
        dfl_loss = ppdet.modeling.DistributionFocalLoss(loss_weight=.25)
        giou_loss = ppdet.modeling.GIoULoss(loss_weight=2.0)
        simota = ppdet.modeling.SimOTAAssigner(
            candidate_topk=10, iou_weight=6, num_classes=num_classes)
        multiclass_nms = ppdet.modeling.MultiClassNMS(
            nms_top_k=nms_topk,
            keep_top_k=nms_keep_topk,
            score_threshold=nms_score_threshold,
            nms_threshold=nms_iou_threshold)
        head = ppdet.modeling.PicoHead(
            conv_feat=conv_feat,
            num_classes=num_classes,
            fpn_stride=[8, 16, 32, 64],
            prior_prob=0.01,
            reg_max=7,
            cell_offset=.5,
            loss_class=varifocal_loss,
            loss_dfl=dfl_loss,
            loss_bbox=giou_loss,
            assigner=simota,
            feat_in_chan=neck_out_channels,
            nms=multiclass_nms)

        params.update(backbone=backbone_net, neck=neck, head=head)

    # Explicit `super(...)` form on purpose (see `locals()` above).
    super(PicoDet, self).__init__(
        model_name='PicoDet', num_classes=num_classes, **params)
def _default_batch_transforms(self, mode='train'):
    """Return PicoDet's default batch transforms for `mode`.

    The pipeline is a single `_BatchPad(pad_to_stride=32)`; samples are
    collated into a batch only when `mode` is 'eval'.
    """
    pad_op = _BatchPad(pad_to_stride=32)
    return BatchCompose([pad_op], collate_batch=(mode == 'eval'))
def _fix_transforms_shape(self, image_shape):
if getattr(self, 'test_transforms', None):
has_resize_op = False
resize_op_idx = -1
normalize_op_idx = len(self.test_transforms.transforms)
for idx, op in enumerate(self.test_transforms.transforms):
name = op.__class__.__name__
if name == 'Resize':
has_resize_op = True
resize_op_idx = idx
if name == 'Normalize':
normalize_op_idx = idx
if not has_resize_op:
self.test_transforms.transforms.insert(
normalize_op_idx,
Resize(
target_size=image_shape, interp='CUBIC'))
else: