Skip to content

Commit

Permalink
Update to PyTorch 1.0. Fixes #5
Browse files Browse the repository at this point in the history
  • Loading branch information
rwightman committed Jan 31, 2019
1 parent e3b1e8e commit a2051f0
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 45 deletions.
4 changes: 3 additions & 1 deletion convert_from_mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def _convert_bn(k):
aux = True
add = 'moving_var'
else:
assert False
assert False, 'Unknown key: %s' % k
return aux, add


Expand All @@ -38,6 +38,8 @@ def convert_from_mxnet(model, checkpoint_prefix, debug=False):
k = state_key.split('.')
aux = False
mxnet_key = ''
if k[-1] == 'num_batches_tracked':
continue
if k[0] == 'features':
if k[1] == 'conv1_1':
# input block
Expand Down
30 changes: 15 additions & 15 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,21 +89,21 @@ def main():
batch_time = AverageMeter()
end = time.time()
top5_ids = []
for batch_idx, (input, _) in enumerate(loader):
input = input.cuda()
input_var = autograd.Variable(input, volatile=True)
labels = model(input_var)
top5 = labels.topk(5)[1]
top5_ids.append(top5.data.cpu().numpy())

# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()

if batch_idx % args.print_freq == 0:
print('Predict: [{0}/{1}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
batch_idx, len(loader), batch_time=batch_time))
with torch.no_grad():
for batch_idx, (input, _) in enumerate(loader):
input = input.cuda()
labels = model(input)
top5 = labels.topk(5)[1]
top5_ids.append(top5.cpu().numpy())

# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()

if batch_idx % args.print_freq == 0:
print('Predict: [{0}/{1}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
batch_idx, len(loader), batch_time=batch_time))

top5_ids = np.concatenate(top5_ids, axis=0).squeeze()

Expand Down
2 changes: 1 addition & 1 deletion model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def get_transforms_eval(model_name, img_size=224, crop_pct=None):
std=[0.229, 0.224, 0.225])

return transforms.Compose([
transforms.Scale(scale_size, Image.BICUBIC),
transforms.Resize(scale_size, Image.BICUBIC),
transforms.CenterCrop(img_size),
transforms.ToTensor(),
normalize])
55 changes: 27 additions & 28 deletions validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,35 +100,34 @@ def main():

# switch to evaluate mode
model.eval()

end = time.time()
for i, (input, target) in enumerate(loader):
target = target.cuda(async=True)
input_var = torch.autograd.Variable(input, volatile=True).cuda()
target_var = torch.autograd.Variable(target, volatile=True).cuda()

# compute output
output = model(input_var)
loss = criterion(output, target_var)

# measure accuracy and record loss
prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
losses.update(loss.data[0], input.size(0))
top1.update(prec1[0], input.size(0))
top5.update(prec5[0], input.size(0))

# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()

if i % args.print_freq == 0:
print('Test: [{0}/{1}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
i, len(loader), batch_time=batch_time, loss=losses,
top1=top1, top5=top5))
with torch.no_grad():
for i, (input, target) in enumerate(loader):
target = target.cuda()
input = input.cuda()

# compute output
output = model(input)
loss = criterion(output, target)

# measure accuracy and record loss
prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
losses.update(loss.item(), input.size(0))
top1.update(prec1.item(), input.size(0))
top5.update(prec5.item(), input.size(0))

# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()

if i % args.print_freq == 0:
print('Test: [{0}/{1}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
i, len(loader), batch_time=batch_time, loss=losses,
top1=top1, top5=top5))

print(' * Prec@1 {top1.avg:.3f} ({top1a:.3f}) Prec@5 {top5.avg:.3f} ({top5a:.3f})'.format(
top1=top1, top1a=100-top1.avg, top5=top5, top5a=100.-top5.avg))
Expand Down

1 comment on commit a2051f0

@gaoyarui
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

too hard,my cousins

Please sign in to comment.