From 756f41a0de84cd8d9a1c097cc1f62f4345e4f742 Mon Sep 17 00:00:00 2001 From: Pavel <60391448+pdumin@users.noreply.github.com> Date: Mon, 15 Jul 2024 03:31:52 +0400 Subject: [PATCH] [fix] Use torch.inference_mode inplace of torch.no_grad (#3188) --- examples/pytorch_track.py | 7 +++---- examples/pytorch_track_images.py | 7 +++---- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/examples/pytorch_track.py b/examples/pytorch_track.py index 7464295c8..3927356cb 100644 --- a/examples/pytorch_track.py +++ b/examples/pytorch_track.py @@ -12,8 +12,8 @@ # Initialize a new Run aim_run = Run() -# Device configuration -device = torch.device('cpu') +# moving model to gpu if available +device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') # Hyper parameters num_epochs = 5 @@ -118,8 +118,7 @@ def forward(self, x): # Test the model -model.eval() -with torch.no_grad(): +with torch.inference_mode(): correct = 0 total = 0 for images, labels in test_loader: diff --git a/examples/pytorch_track_images.py b/examples/pytorch_track_images.py index 1c216c9ef..adb693a2f 100644 --- a/examples/pytorch_track_images.py +++ b/examples/pytorch_track_images.py @@ -13,8 +13,8 @@ # Initialize a new Run aim_run = Run() -# Device configuration -device = torch.device('cpu') +# moving model to gpu if available +device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') # Hyper parameters num_epochs = 5 @@ -122,8 +122,7 @@ def forward(self, x): # Test the model -model.eval() -with torch.no_grad(): +with torch.inference_mode(): correct = 0 total = 0 for images, labels in tqdm(test_loader, total=len(test_loader)):