Skip to content

Commit

Permalink
Fix different devices bug when moving model from GPU to CPU (#5110)
Browse files Browse the repository at this point in the history
* fix different devices bug

* extend _apply() instead of to() for a general fix

* Only apply if Detect() is last layer

Co-authored-by: Jebastin Nadar <njebastin10@gmail.com>

* Indent fix

* Add comment to yolo.py

* Add comment to common.py

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
  • Loading branch information
jebastin-nadar and glenn-jocher committed Oct 10, 2021
1 parent 4a6dfff commit a0e1504
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
8 changes: 8 additions & 0 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,14 @@ def autoshape(self):
LOGGER.info('AutoShape already enabled, skipping... ') # model already converted to model.autoshape()
return self

def _apply(self, fn):
# Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
self = super()._apply(fn)
m = self.model.model[-1] # Detect()
m.stride = fn(m.stride)
m.grid = list(map(fn, m.grid))
return self

@torch.no_grad()
def forward(self, imgs, size=640, augment=False, profile=False):
# Inference from various sources. For height=640, width=1280, RGB images example inputs are:
Expand Down
9 changes: 9 additions & 0 deletions models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,15 @@ def autoshape(self): # add AutoShape module
def info(self, verbose=False, img_size=640): # print model information
model_info(self, verbose, img_size)

def _apply(self, fn):
# Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
self = super()._apply(fn)
m = self.model[-1] # Detect()
if isinstance(m, Detect):
m.stride = fn(m.stride)
m.grid = list(map(fn, m.grid))
return self


def parse_model(d, ch): # model_dict, input_channels(3)
LOGGER.info('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
Expand Down

0 comments on commit a0e1504

Please sign in to comment.