-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
51 lines (42 loc) · 1.24 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import os
import argparse
import jittor as jt
from jnerf.runner import Runner
from jnerf.utils.config import init_cfg
# Optional debugging toggle, intentionally left disabled. Presumably it turns
# off Jittor's graph optimizer (inferred from the flag name only — confirm in
# the Jittor docs before relying on it):
# jt.flags.gopt_disable=1
# Ask Jittor to use CUDA. This runs at module level, i.e. at import time,
# before main() is invoked; main() additionally rejects GPUs older than sm_61.
jt.flags.use_cuda = 1
def _parse_args(argv=None):
    """Build the CLI parser and parse *argv* (``None`` means ``sys.argv[1:]``)."""
    parser = argparse.ArgumentParser(description="Jittor NGP Detection Training")
    parser.add_argument(
        "--config-file",
        default="",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument(
        "--task",
        default="train",
        # `choices` is enforced by argparse itself, unlike an `assert`, which
        # is stripped when Python runs with -O.
        choices=["train", "test", "render"],
        help="train,render,test",
        type=str,
    )
    parser.add_argument(
        "--save_dir",
        default="",
        help="save path for rendering video",
        type=str,
    )
    return parser.parse_args(argv)


def _check_cuda_arch(min_arch=61):
    """Raise RuntimeError unless Jittor reports a CUDA arch of at least *min_arch*.

    Only the first entry of ``jt.flags.cuda_archs`` is checked, matching the
    original behaviour.
    """
    archs = jt.flags.cuda_archs
    # An empty arch list previously surfaced as an opaque IndexError.
    if not archs:
        raise RuntimeError(
            f"Failed: no CUDA arch detected! A GPU with sm_{min_arch} or newer is required."
        )
    if archs[0] < min_arch:
        raise RuntimeError(
            f"Failed: Sm arch version is too low! Sm arch version must not be lower than sm_{min_arch}!"
        )


def main(argv=None):
    """Parse CLI arguments, verify the GPU, then run the selected task.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the plain command-line behaviour.

    Raises:
        SystemExit: From argparse on ``--help`` or invalid arguments,
            including an unknown ``--task``.
        RuntimeError: If Jittor reports no CUDA arch, or one below sm_61.
    """
    # Parse before the hardware check so `--help` works on machines without
    # a suitable GPU.
    args = _parse_args(argv)
    _check_cuda_arch()
    # NOTE(review): without --config-file, init_cfg is skipped and Runner() is
    # built from whatever config is already active — confirm this is intended.
    if args.config_file:
        init_cfg(args.config_file)
    runner = Runner()
    if args.task == "train":
        runner.train()
    elif args.task == "test":
        runner.test(True)
    else:  # "render": argparse `choices` rules out any other value
        runner.render(True, args.save_dir)


if __name__ == "__main__":
    main()