forked from eriklindernoren/PyTorch-YOLOv3
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path train.py
24 lines (16 loc) · 764 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from pytorchyolo import train, test
import os
import argparse
import torch
def _load_model(model_path):
    """Unpickle and return the model object stored at *model_path*.

    The loaded object is passed straight to ``train.run``, so the checkpoint
    must hold a whole pickled model rather than a ``state_dict``. This is
    assumed to match the format this script itself writes with
    ``torch.save(model, ...)``; TODO confirm against the pruning script that
    produces the input checkpoints.

    PyTorch 2.6 changed the ``torch.load`` default to ``weights_only=True``,
    which refuses to unpickle arbitrary module objects. We therefore opt out
    explicitly whenever the installed torch accepts that keyword. Older
    versions without the keyword already unpickle fully.

    SECURITY: this unpickles arbitrary Python objects, so it can execute code
    embedded in the file. Only load checkpoints you trust.
    """
    import inspect

    if "weights_only" in inspect.signature(torch.load).parameters:
        return torch.load(model_path, weights_only=False)
    return torch.load(model_path)


def _output_path(model_path, epochs, out_dir="trained_pruned_models"):
    """Return ``<out_dir>/<stem>_<epochs>.pth`` for the fine-tuned model.

    ``stem`` is the input file name without its directory or final extension.
    Using ``os.path`` respects the platform's path separators. It also keeps
    inner dots (``a.b.pth`` -> ``a.b``) rather than truncating at the first one.
    """
    stem = os.path.splitext(os.path.basename(model_path))[0]
    return os.path.join(out_dir, "{}_{}.pth".format(stem, epochs))


def main():
    """Fine-tune a pruned model, run ``test.run`` on it, and save the result.

    CLI options:
        -m/--model   path to the pickled input model
        -e/--epochs  number of epochs passed to ``train.run``

    The trained model is written to
    ``trained_pruned_models/<input-stem>_<epochs>.pth``.
    """
    parser = argparse.ArgumentParser(description="train pruned model")
    parser.add_argument("-m", "--model", type=str, default="custom_pruned_models/cluster_prune_10.pth", help="Path to model")
    parser.add_argument("-e", '--epochs', type=int, default=100, help='epochs')
    args = parser.parse_args()

    out_dir = "trained_pruned_models"
    # Create the output directory after parsing, so --help or a bad argument
    # leaves nothing behind. It must still exist before the (long) training
    # run, so the final save cannot fail. exist_ok avoids the
    # isdir()/makedirs() check-then-act race.
    os.makedirs(out_dir, exist_ok=True)

    model = _load_model(args.model)
    model = train.run(model, epochs=args.epochs)
    test.run(model)  # return value is not used by this script
    torch.save(model, _output_path(args.model, args.epochs, out_dir))


if __name__ == "__main__":
    main()