This repository was archived by the owner on Oct 31, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2.5k
/
Copy pathtrainer.py
124 lines (107 loc) · 4.17 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import datetime
import logging
import time
import torch
import torch.distributed as dist
from maskrcnn_benchmark.utils.comm import get_world_size
from maskrcnn_benchmark.utils.metric_logger import MetricLogger
from apex import amp
def reduce_loss_dict(loss_dict):
    """
    Sum each entry of ``loss_dict`` across processes onto rank 0 and
    average it there.

    When fewer than two processes are running, ``loss_dict`` itself is
    returned unchanged. Otherwise a new dict with the same keys is
    returned; only on rank 0 are its values divided by the world size,
    because rank 0 is the destination of the reduction.
    """
    num_procs = get_world_size()
    if num_procs < 2:
        return loss_dict

    with torch.no_grad():
        # The sorted key order fixes which row of the stacked tensor
        # belongs to which loss.
        ordered_names = sorted(loss_dict.keys())
        stacked = torch.stack(
            [loss_dict[name] for name in ordered_names], dim=0
        )
        dist.reduce(stacked, dst=0)
        if dist.get_rank() == 0:
            # Rank 0 is the reduce destination, so the division by the
            # number of processes is applied there only.
            stacked /= num_procs
        averaged = dict(zip(ordered_names, stacked))
    return averaged
def do_train(
    model,
    data_loader,
    optimizer,
    scheduler,
    checkpointer,
    device,
    checkpoint_period,
    arguments,
):
    """
    Run the training loop over ``data_loader``.

    Each iteration moves the batch to ``device``, computes the loss dict
    with ``model(images, targets)``, back-propagates the summed loss
    through ``amp.scale_loss``, then steps the optimizer followed by the
    learning-rate scheduler. Progress is logged every 20 iterations and on
    the last one; checkpoints are written every ``checkpoint_period``
    iterations and once more as ``"model_final"`` when the last iteration
    is reached.

    Args:
        model: callable as ``model(images, targets)`` returning a dict of
            losses; it is put into training mode with ``model.train()``.
        data_loader: sized iterable yielding ``(images, targets, _)``
            triples; ``len(data_loader)`` is taken as the final iteration
            number.
        optimizer: optimizer that is zeroed and stepped every iteration.
        scheduler: learning-rate scheduler, stepped once per iteration
            after the optimizer.
        checkpointer: object whose ``save(name, **arguments)`` persists
            the training state.
        device: target passed to ``.to(device)`` for images and targets.
        checkpoint_period: number of iterations between checkpoints.
        arguments: dict holding at least ``"iteration"`` (the iteration to
            resume from). It is updated in place every iteration and
            forwarded to ``checkpointer.save``.
    """
    logger = logging.getLogger("maskrcnn_benchmark.trainer")
    logger.info("Start training")
    meters = MetricLogger(delimiter=" ")
    max_iter = len(data_loader)
    start_iter = arguments["iteration"]
    model.train()
    start_training_time = time.time()
    end = time.time()
    for iteration, (images, targets, _) in enumerate(data_loader, start_iter):
        if any(len(target) < 1 for target in targets):
            # Surface batches that contain an empty target for diagnosis.
            # The batch is still trained on; control flow is unchanged.
            logger.debug(
                "Iteration=%d || targets Length=%s",
                iteration + 1,
                [len(target) for target in targets],
            )
        data_time = time.time() - end
        iteration = iteration + 1
        arguments["iteration"] = iteration

        images = images.to(device)
        targets = [target.to(device) for target in targets]

        loss_dict = model(images, targets)
        losses = sum(loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_loss_dict(loss_dict)
        losses_reduced = sum(loss_dict_reduced.values())
        meters.update(loss=losses_reduced, **loss_dict_reduced)

        optimizer.zero_grad()
        # Note: If mixed precision is not used, this ends up doing nothing
        # Otherwise apply loss scaling for mixed-precision recipe
        with amp.scale_loss(losses, optimizer) as scaled_losses:
            scaled_losses.backward()
        optimizer.step()
        # Step the scheduler only after the optimizer update: since
        # PyTorch 1.1 stepping it first skips the first value of the
        # learning-rate schedule.
        scheduler.step()

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if iteration % 20 == 0 or iteration == max_iter:
            logger.info(
                meters.delimiter.join(
                    [
                        "eta: {eta}",
                        "iter: {iter}",
                        "{meters}",
                        "lr: {lr:.6f}",
                        "max mem: {memory:.0f}",
                    ]
                ).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                )
            )
        if iteration % checkpoint_period == 0:
            checkpointer.save("model_{:07d}".format(iteration), **arguments)
        if iteration == max_iter:
            checkpointer.save("model_final", **arguments)

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    # Average over the iterations actually executed in this call: a resumed
    # run performs fewer than max_iter iterations, and an empty loader
    # performs none (the max() guards against dividing by zero).
    iterations_run = arguments["iteration"] - start_iter
    logger.info(
        "Total training time: {} ({:.4f} s / it)".format(
            total_time_str, total_training_time / max(iterations_run, 1)
        )
    )