From 1ec1d2bd63fcbb081487f4e06884c4dc6247bba2 Mon Sep 17 00:00:00 2001 From: Christopher Seymour Date: Wed, 9 Dec 2020 19:02:20 +0000 Subject: [PATCH] option to use pretrained weights --- bonito/cli/train.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/bonito/cli/train.py b/bonito/cli/train.py index fe893e31..026a5b2d 100644 --- a/bonito/cli/train.py +++ b/bonito/cli/train.py @@ -12,7 +12,7 @@ from argparse import ArgumentParser from argparse import ArgumentDefaultsHelpFormatter -from bonito.util import load_data, load_symbol, init, default_config, default_data +from bonito.util import load_data, load_model, load_symbol, init, default_config, default_data from bonito.training import ChunkDataSet, load_state, train, test, func_scheduler, cosine_decay_schedule, CSVLogger import toml @@ -54,7 +54,11 @@ def main(args): toml.dump({**config, **argsdict, **chunk_config}, open(os.path.join(workdir, 'config.toml'), 'w')) print("[loading model]") - model = load_symbol(config, 'Model')(config) + if args.pretrained: + print("[using pretrained model {}]".format(args.pretrained)) + model = load_model(args.pretrained, device, half=False) + else: + model = load_symbol(config, 'Model')(config) optimizer = AdamW(model.parameters(), amsgrad=False, lr=args.lr) last_epoch = load_state(workdir, args.device, model, optimizer, use_amp=args.amp) @@ -127,4 +131,5 @@ def argparser(): parser.add_argument("--amp", action="store_true", default=False) parser.add_argument("--multi-gpu", action="store_true", default=False) parser.add_argument("-f", "--force", action="store_true", default=False) + parser.add_argument("--pretrained", default="") return parser