Skip to content

Commit

Permalink
Fixe Bug for Loading Non Default Model
Browse files Browse the repository at this point in the history
Fixes issue alankbi#122
  • Loading branch information
RoyiAvital authored Dec 10, 2022
1 parent c57b5fe commit a291bed
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions detecto/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,7 @@ def save(self, file):
torch.save(self._model.state_dict(), file)

@staticmethod
def load(file, classes):
def load(file, classes, model_name=DEFAULT):
"""Loads a model from a .pth file containing the model weights.
:param file: The path to the .pth file containing the saved model.
Expand All @@ -613,7 +613,7 @@ def load(file, classes):
>>> model = Model.load('model_weights.pth', ['ant', 'bee'])
"""

model = Model(classes)
model = Model(classes, model_name=model_name)
model._model.load_state_dict(torch.load(file, map_location=model._device))
return model

Expand Down

0 comments on commit a291bed

Please sign in to comment.