Skip to content
This repository was archived by the owner on Oct 31, 2023. It is now read-only.

Commit 7338be2

Browse files
Miguel Varela Ramosfmassa
Miguel Varela Ramos
authored andcommitted
Save full configuration in output dir (#835)
* Merge branch 'master' of /home/braincreator/projects/maskrcnn-benchmark with conflicts. * update Dockerfile * save config in output dir * replace string format with os.path.join
1 parent d269847 commit 7338be2

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

maskrcnn_benchmark/utils/miscellaneous.py

+7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
22
import errno
33
import os
4+
from .comm import is_main_process
45

56

67
def mkdir(path):
@@ -9,3 +10,9 @@ def mkdir(path):
910
except OSError as e:
1011
if e.errno != errno.EEXIST:
1112
raise
13+
14+
15+
def save_config(cfg, path):
16+
if is_main_process():
17+
with open(path, 'w') as f:
18+
f.write(cfg.dump())

tools/train_net.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from maskrcnn_benchmark.utils.comm import synchronize, get_rank
2424
from maskrcnn_benchmark.utils.imports import import_file
2525
from maskrcnn_benchmark.utils.logger import setup_logger
26-
from maskrcnn_benchmark.utils.miscellaneous import mkdir
26+
from maskrcnn_benchmark.utils.miscellaneous import mkdir, save_config
2727

2828
# See if we can use apex.DistributedDataParallel instead of the torch default,
2929
# and enable mixed-precision via apex.amp
@@ -176,6 +176,11 @@ def main():
176176
logger.info(config_str)
177177
logger.info("Running with config:\n{}".format(cfg))
178178

179+
output_config_path = os.path.join(cfg.OUTPUT_DIR, 'config.yml')
180+
logger.info("Saving config into: {}".format(output_config_path))
181+
# save overloaded model config in the output directory
182+
save_config(cfg, output_config_path)
183+
179184
model = train(cfg, args.local_rank, args.distributed)
180185

181186
if not args.skip_test:

0 commit comments

Comments
 (0)