Skip to content

Commit

Permalink
option to use pretrained weights
Browse files Browse the repository at this point in the history
  • Loading branch information
iiSeymour committed Dec 9, 2020
1 parent 78b0e01 commit 1ec1d2b
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions bonito/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from argparse import ArgumentParser
from argparse import ArgumentDefaultsHelpFormatter

from bonito.util import load_data, load_symbol, init, default_config, default_data
from bonito.util import load_data, load_model, load_symbol, init, default_config, default_data
from bonito.training import ChunkDataSet, load_state, train, test, func_scheduler, cosine_decay_schedule, CSVLogger

import toml
Expand Down Expand Up @@ -54,7 +54,11 @@ def main(args):
toml.dump({**config, **argsdict, **chunk_config}, open(os.path.join(workdir, 'config.toml'), 'w'))

print("[loading model]")
model = load_symbol(config, 'Model')(config)
if args.pretrained:
print("[using pretrained model {}]".format(args.pretrained))
model = load_model(args.pretrained, device, half=False)
else:
model = load_symbol(config, 'Model')(config)
optimizer = AdamW(model.parameters(), amsgrad=False, lr=args.lr)

last_epoch = load_state(workdir, args.device, model, optimizer, use_amp=args.amp)
Expand Down Expand Up @@ -127,4 +131,5 @@ def argparser():
parser.add_argument("--amp", action="store_true", default=False)
parser.add_argument("--multi-gpu", action="store_true", default=False)
parser.add_argument("-f", "--force", action="store_true", default=False)
parser.add_argument("--pretrained", default="")
return parser

0 comments on commit 1ec1d2b

Please sign in to comment.