diff --git a/maskrcnn_benchmark/config/paths_catalog.py b/maskrcnn_benchmark/config/paths_catalog.py
index 5b72320d8..eb12e653f 100644
--- a/maskrcnn_benchmark/config/paths_catalog.py
+++ b/maskrcnn_benchmark/config/paths_catalog.py
@@ -33,6 +33,12 @@ class DatasetCatalog(object):
         "bdd100k_val": (
             "bdd100k/images/100k/val", "bdd100k/labels/bdd100k_labels_images_val.json"
         ),
+        "bdd100k_det_train": (
+            "bdd100k/images/100k/train", "bdd100k/labels/bdd100k_labels_images_det_coco_train.json"
+        ),
+        "bdd100k_det_val": (
+            "bdd100k/images/100k/val", "bdd100k/labels/bdd100k_labels_images_det_coco_val.json"
+        ),
         "kitti_tracking_train": (
             "kitti_tracking/train", ""
         ),
@@ -46,7 +52,7 @@ def get(name):
         factory_names = {
             "coco": "COCODataset",
             "kitti": "KittiDataset",
-            "bdd100k": "Bdd100kDataset"
+            "bdd100k": "COCODataset"
         }
         for k in factory_names.keys():
             if (k in name):
diff --git a/maskrcnn_benchmark/engine/trainer.py b/maskrcnn_benchmark/engine/trainer.py
index 4471fa410..457f264d1 100644
--- a/maskrcnn_benchmark/engine/trainer.py
+++ b/maskrcnn_benchmark/engine/trainer.py
@@ -74,6 +74,13 @@ def do_train(
 
         optimizer.zero_grad()
         losses.backward()
+
+        accum_grad = 0
+        for p in list(filter(lambda p: p.grad is not None, model.parameters())):
+            accum_grad += p.grad.data.norm(2).item()
+
+        if iteration > 500 and accum_grad > 200:
+            torch.nn.utils.clip_grad_norm_(model.parameters(), 200)
         optimizer.step()
 
         batch_time = time.time() - end
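
Below is a minimal, self-contained sketch (not part of the patch itself) of the conditional gradient clipping that the trainer.py hunk introduces. The toy nn.Linear model, SGD optimizer, and random data are placeholders for illustration only; the warm-up iteration count (500) and the norm threshold (200) are taken directly from the diff.

    import torch
    import torch.nn as nn

    model = nn.Linear(8, 2)  # stand-in for the detection model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    for iteration in range(1, 1001):
        x = torch.randn(4, 8)
        target = torch.randn(4, 2)
        loss = nn.functional.mse_loss(model(x), target)

        optimizer.zero_grad()
        loss.backward()

        # Same accumulation as the patch: sum of per-parameter L2 gradient norms.
        accum_grad = sum(
            p.grad.data.norm(2).item()
            for p in model.parameters()
            if p.grad is not None
        )

        # Clip only after a warm-up period and only when gradients blow up.
        if iteration > 500 and accum_grad > 200:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 200)

        optimizer.step()

One caveat worth noting: accum_grad sums the per-parameter gradient norms, whereas torch.nn.utils.clip_grad_norm_ rescales against the L2 norm of all gradients treated as a single concatenated vector, so the trigger value and the clipping threshold are not measured on exactly the same quantity.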