diff --git a/examples/katib/mnist_with_summary.py b/examples/katib/mnist_with_summary.py
index 24a99ffce..b733dfba7 100644
--- a/examples/katib/mnist_with_summary.py
+++ b/examples/katib/mnist_with_summary.py
@@ -7,7 +7,6 @@ from torchvision import datasets, transforms
 from torch.autograd import Variable
 from tensorboardX import SummaryWriter
-writer = SummaryWriter('runs')
 
 # Training settings
 parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
 parser.add_argument('--batch-size', type=int, default=64, metavar='N',
@@ -26,6 +25,8 @@
                     help='random seed (default: 1)')
 parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                     help='how many batches to wait before logging training status')
+parser.add_argument('--dir', default='logs', metavar='L',
+                    help='directory where summary logs are stored')
 args = parser.parse_args()
 args.cuda = not args.no_cuda and torch.cuda.is_available()
 
@@ -72,7 +73,8 @@ def forward(self, x):
 if args.cuda:
     model.cuda()
 
-print('Learning rate: {} Momentum: {}'.format(args.lr, args.momentum))
+print('Learning rate: {} Momentum: {} Logs dir: {}'.format(args.lr, args.momentum, args.dir))
+writer = SummaryWriter(args.dir)
 optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
 
 def train(epoch):
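
For context, a minimal sketch of the pattern this diff adopts: construct the SummaryWriter only after argument parsing, so the log directory can be overridden per run (e.g. per Katib trial). The --dir flag and its 'logs' default mirror the diff; the scalar tag and value below are purely illustrative.

    import argparse
    from tensorboardX import SummaryWriter

    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--dir', default='logs', metavar='L',
                        help='directory where summary logs are stored')
    args = parser.parse_args()

    # Create the writer after parsing, not at import time, so each
    # invocation can direct its summaries to a different directory.
    writer = SummaryWriter(args.dir)
    writer.add_scalar('example/loss', 0.0, 0)  # illustrative scalar
    writer.close()

A run would then pass the target directory explicitly, e.g. python mnist_with_summary.py --dir /tmp/trial-0 (the path here is hypothetical); omitting --dir falls back to the 'logs' default.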