Add `xla.step` context manager #7068
Conversation
```python
# Create a DataLoader
dataset = TensorDataset(input_data, target_data)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
```
Doesn't `DataLoader` take `device` as an argument?
No. In normal PyTorch, you have to move the data yourself with `tensor.to`: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#training-on-gpu
```python
@contextlib.contextmanager
def step():
```
The reason I find `step` a bit confusing is that we don't call `mark_step` upon entering the context:

```python
with xla.step():
  y = x + z
y += 1
```

`step` as a context manager kind of suggests execution will only cover what happened inside the context manager, but that's actually not the case.
I agree. This should either print a warning if there are pending operations, or just call `mark_step` twice. What do you think is better?
Let's try calling `mark_step` twice and benchmark it with one of the examples, resnet50 with fake data.
I'm going to hold off on modifying the examples until we're running tests on them. Here's my patch:
```diff
--- a/examples/train_resnet_base.py
+++ b/examples/train_resnet_base.py
@@ -45,15 +45,16 @@ class TrainResNetBase():
     self.model.train()
     loader = itertools.islice(loader, self.num_steps)
     for step, (data, target) in enumerate(loader):
-      self.optimizer.zero_grad()
-      output = self.model(data)
-      loss = self.loss_fn(output, target)
-      loss.backward()
-      self.run_optimizer()
+      with torch_xla.step():
+        self.optimizer.zero_grad()
+        output = self.model(data)
+        loss = self.loss_fn(output, target)
+        loss.backward()
+        self.run_optimizer()
+
       tracker.add(self.batch_size)
       if step % 10 == 0:
-        xm.add_step_closure(
-            self._train_update, args=(step, loss, tracker, epoch))
+        self._train_update(step, loss, tracker, epoch)
```
Before:
epoch: 1, step: 290, loss: 6.608619213104248, rate: 1747.0911849087843
epoch: 1, step: 290, loss: 6.606635570526123, rate: 1747.0763868012214
epoch: 1, step: 290, loss: 6.618781566619873, rate: 1747.2648104487325
epoch: 1, step: 290, loss: 6.605813980102539, rate: 1746.9924093597208
After:
epoch: 1, step: 290, loss: 6.603261947631836, rate: 1752.4689284654187
epoch: 1, step: 290, loss: 6.607376575469971, rate: 1752.4377415557715
epoch: 1, step: 290, loss: 6.611710071563721, rate: 1752.2556378789855
epoch: 1, step: 290, loss: 6.638012886047363, rate: 1752.400066823619
See #6751. One advantage over `sync` is that exceptions are handled sanely in `xla.step`. Remove `ParallelLoader` because it mostly does not make a difference for MP, and we should keep our starting point as simple as possible.