Skip to content

Commit

Permalink
Fix train loop in trainingyt.py (#2372)
Browse files Browse the repository at this point in the history
* refactored train loop in trainingyt.py, resolves issue #2230

* Simplified numpy function call, resolves issue #1038
  • Loading branch information
JoseLuisC99 authored Jun 1, 2023
1 parent 4673b14 commit d686b66
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 9 deletions.
20 changes: 12 additions & 8 deletions beginner_source/introyt/trainingyt.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,15 +290,19 @@ def train_one_epoch(epoch_index, tb_writer):
model.train(True)
avg_loss = train_one_epoch(epoch_number, writer)

# We don't need gradients on to do reporting
model.train(False)


running_vloss = 0.0
for i, vdata in enumerate(validation_loader):
vinputs, vlabels = vdata
voutputs = model(vinputs)
vloss = loss_fn(voutputs, vlabels)
running_vloss += vloss
# Set the model to evaluation mode, disabling dropout and using population
# statistics for batch normalization.
model.eval()

# Disable gradient computation and reduce memory consumption.
with torch.no_grad():
for i, vdata in enumerate(validation_loader):
vinputs, vlabels = vdata
voutputs = model(vinputs)
vloss = loss_fn(voutputs, vlabels)
running_vloss += vloss

avg_vloss = running_vloss / (i + 1)
print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))
Expand Down
2 changes: 1 addition & 1 deletion intermediate_source/torchvision_tutorial.rst
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ Let’s write a ``torch.utils.data.Dataset`` class for this dataset.
num_objs = len(obj_ids)
boxes = []
for i in range(num_objs):
pos = np.where(masks[i])
pos = np.nonzero(masks[i])
xmin = np.min(pos[1])
xmax = np.max(pos[1])
ymin = np.min(pos[0])
Expand Down

0 comments on commit d686b66

Please sign in to comment.