-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.py
120 lines (99 loc) · 2.85 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import argparse
import json
import os

from ultralytics import YOLO
def create_data_yaml(
    save_path="/valohai/outputs/data.yaml",
    inputs_dir="/valohai/inputs",
    class_names=("ship",),
):
    """Write the Ultralytics dataset config and tag it with a Valohai alias.

    Ultralytics needs a data.yaml describing the dataset splits and classes
    to train and predict. Next to it, a ``<save_path>.metadata.json`` sidecar
    is written so the file is published under the Valohai alias ``data_yaml``,
    which lets the evaluation step access it by name.

    Args:
        save_path: Where to write the YAML file.
        inputs_dir: Directory holding the ``train``, ``valid`` and ``test``
            split folders.
        class_names: Class names, indexed by class id. ``nc`` is derived from
            its length so the two can never disagree.

    Returns:
        The path the YAML file was written to (``save_path``).
    """
    import yaml  # imported lazily; only this function needs PyYAML

    names = list(class_names)
    data_yaml = {
        "train": os.path.join(inputs_dir, "train"),
        "val": os.path.join(inputs_dir, "valid"),
        "test": os.path.join(inputs_dir, "test"),
        "nc": len(names),
        "names": names,
    }
    with open(save_path, "w", encoding="utf-8") as file:
        yaml.dump(data_yaml, file)

    metadata_path = f"{save_path}.metadata.json"
    print("metadata_path ", metadata_path)
    with open(metadata_path, "w", encoding="utf-8") as outfile:
        json.dump({"valohai.alias": "data_yaml"}, outfile)
    return save_path
def save_model_alias(project_path, alias="model-current-best"):
    """Tag the best YOLO checkpoint with a Valohai alias.

    YOLO writes the run's weights to ``<project_path>/train/weights``. Writing
    a ``best.pt.metadata.json`` sidecar next to ``best.pt`` makes Valohai
    create (or update) the data alias ``alias`` pointing at that checkpoint.

    Args:
        project_path: The Ultralytics ``project`` directory used for training.
            Works with or without a trailing path separator.
        alias: Valohai alias name to assign to ``best.pt``.

    Returns:
        Path of the metadata file that was written.

    Raises:
        FileNotFoundError: If ``<project_path>/train/weights`` does not exist.
    """
    metadata = {
        "valohai.alias": alias,  # creates or updates a Valohai data alias to point to this output file
    }
    # os.path.join inserts the separator. Plain string concatenation produced
    # "/valohai/outputstrain/..." for a project path without a trailing slash.
    metadata_path = os.path.join(
        project_path, "train", "weights", "best.pt.metadata.json"
    )
    with open(metadata_path, "w", encoding="utf-8") as outfile:
        json.dump(metadata, outfile)
    return metadata_path
def train_yolo(
    yolo_name,
    data_yaml,
    image_size,
    epochs,
    seed,
    batch_size,
    project,
    optimizer=None,
):
    """Fine-tune a YOLO model and tag its best checkpoint with a Valohai alias.

    Args:
        yolo_name: Model weights/config passed to ``YOLO`` (e.g. "yolov8x.pt").
        data_yaml: Path to the Ultralytics dataset YAML.
        image_size: Training image size (``imgsz``).
        epochs: Number of training epochs.
        seed: Random seed for reproducibility.
        batch_size: Training batch size.
        project: Ultralytics ``project`` directory. Results land in
            ``<project>/train``.
        optimizer: Optimizer name forwarded to Ultralytics. When ``None``
            the library default is used.
    """
    model = YOLO(yolo_name)
    train_kwargs = dict(
        data=data_yaml,
        epochs=epochs,
        imgsz=image_size,
        seed=seed,
        batch=batch_size,
        project=project,
        # Pin the run directory to <project>/train: save_model_alias looks for
        # the weights there, and Ultralytics would otherwise switch to
        # train2, train3, ... if that directory already exists.
        name="train",
        exist_ok=True,
    )
    if optimizer is not None:
        train_kwargs["optimizer"] = optimizer
    model.train(**train_kwargs)
    save_model_alias(project)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Training parameters for ship aerial images",
    )
    parser.add_argument(
        "--yolo_model_name",
        type=str,
        default="yolov8x.pt",
        help="Model name",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=10,
        help="Number of training epochs",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="Batch size for training",
    )
    parser.add_argument(
        "--image_size",
        type=int,
        default=768,
        help="Size of the images",
    )
    parser.add_argument(
        "--optimizer",
        type=str,
        default="SGD",
        help="Optimizer used for training (forwarded to Ultralytics)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for reproducibility",
    )
    parser.add_argument(
        "--project",
        type=str,
        default="/valohai/outputs",
        help="Save path for the training logs",
    )
    args = parser.parse_args()

    yaml_path = create_data_yaml()
    train_yolo(
        args.yolo_model_name,
        yaml_path,
        args.image_size,
        args.epochs,
        args.seed,
        args.batch_size,
        args.project,
        # Previously parsed but never used, so --optimizer was silently ignored.
        optimizer=args.optimizer,
    )