
Add xla.step context manager #7068

Merged · 7 commits into master · May 17, 2024

Conversation

will-cromar (Collaborator) commented May 15, 2024

See #6751

  • This implementation is intentionally minimal to start with. The main improvement compared to sync is that exceptions are handled sanely (a sketch of this behavior follows after this list).
  • Update the README example to use xla.step. Remove ParallelLoader because it mostly does not make a difference for MP, and we should keep our starting point as simple as possible.
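For reference, here is a minimal sketch of what a context manager along these lines could look like. This is an illustration based on the description above, not the merged implementation; it assumes xm.mark_step is the point where traced operations get dispatched:

import contextlib

import torch_xla.core.xla_model as xm

@contextlib.contextmanager
def step():
  # Trace the body lazily, then cut the graph so the traced operations
  # are compiled and dispatched to the device. try/finally flushes the
  # graph even if the body raises, which is one plausible reading of
  # "exceptions are handled sanely" above.
  try:
    yield
  finally:
    xm.mark_step()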

will-cromar changed the title from [WIP] Add xla.step context manager to Add xla.step context manager May 16, 2024
will-cromar requested a review from JackCaoG May 16, 2024 20:09
will-cromar marked this pull request as ready for review May 16, 2024 20:09

from torch.utils.data import DataLoader, TensorDataset

# Create a DataLoader
dataset = TensorDataset(input_data, target_data)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
Collaborator:

Doesn't DataLoader take a device argument?

will-cromar (Author):
No. In normal PyTorch, you have to move the data with tensor.to: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#training-on-gpu
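For illustration, the pattern from that tutorial looks roughly like this (device and loader here are stand-ins for the tutorial's variables):

import torch

# DataLoader yields CPU tensors; each batch must be moved explicitly.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for data, target in loader:
  data, target = data.to(device), target.to(device)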



@contextlib.contextmanager
def step():
Collaborator:

The reason I find step a bit confusing is that we don't call mark_step upon entering the context:

with xla.step():
  y = x + z

y += 1

step as a context manager kind of suggests that execution will only cover what happened inside the context manager, but that's actually not the case.

will-cromar (Author):
I agree. This should either print a warning if there are pending operations, or just mark_step twice. What do you think is better?

Collaborator:
Let's try mark_step twice and benchmark it with one of the examples: ResNet50 with fake data.
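Under that suggestion, the sketch above would also flush pending work on entry, so operations traced outside the context do not get folded into the step's graph (again a sketch, not the merged code):

@contextlib.contextmanager
def step():
  xm.mark_step()  # flush operations traced before entering the context
  try:
    yield
  finally:
    xm.mark_step()  # compile and dispatch the graph traced inside the body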

will-cromar (Author):
I'm going to hold off on modifying the examples until we're running tests on them. Here's my patch:

--- a/examples/train_resnet_base.py
+++ b/examples/train_resnet_base.py
@@ -45,15 +45,16 @@ class TrainResNetBase():
     self.model.train()
     loader = itertools.islice(loader, self.num_steps)
     for step, (data, target) in enumerate(loader):
-      self.optimizer.zero_grad()
-      output = self.model(data)
-      loss = self.loss_fn(output, target)
-      loss.backward()
-      self.run_optimizer()
+      with torch_xla.step():
+        self.optimizer.zero_grad()
+        output = self.model(data)
+        loss = self.loss_fn(output, target)
+        loss.backward()
+        self.run_optimizer()
+
       tracker.add(self.batch_size)
       if step % 10 == 0:
-        xm.add_step_closure(
-            self._train_update, args=(step, loss, tracker, epoch))
+        self._train_update(step, loss, tracker, epoch)

Before:

epoch: 1, step: 290, loss: 6.608619213104248, rate: 1747.0911849087843
epoch: 1, step: 290, loss: 6.606635570526123, rate: 1747.0763868012214
epoch: 1, step: 290, loss: 6.618781566619873, rate: 1747.2648104487325
epoch: 1, step: 290, loss: 6.605813980102539, rate: 1746.9924093597208

After:

epoch: 1, step: 290, loss: 6.603261947631836, rate: 1752.4689284654187
epoch: 1, step: 290, loss: 6.607376575469971, rate: 1752.4377415557715
epoch: 1, step: 290, loss: 6.611710071563721, rate: 1752.2556378789855
epoch: 1, step: 290, loss: 6.638012886047363, rate: 1752.400066823619

will-cromar merged commit 3c59087 into master May 17, 2024
19 of 20 checks passed
zpcore pushed a commit that referenced this pull request May 20, 2024