Skip to content

Commit

Permalink
enhance(example): update example code with device converter (#1264)
Browse files Browse the repository at this point in the history
update example code with device converter
  • Loading branch information
tianweidut authored Sep 21, 2022
1 parent 8b6e394 commit c37fadd
Show file tree
Hide file tree
Showing 8 changed files with 9 additions and 57 deletions.
2 changes: 1 addition & 1 deletion example/PennFudanPed/pfp/ppl.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def cmp(self, ppl_result):

evaluator = make_coco_evaluator(annotations, iou_types=["bbox", "segm"])
for index, pred in pred_results:
evaluator.update({index: pred})
evaluator.update({index: {k: v.cpu() for k, v in pred.items()}})

evaluator.synchronize_between_processes()
evaluator.accumulate()
Expand Down
2 changes: 1 addition & 1 deletion example/cifar10/cifar/ppl.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _post(self, input):

def _load_model(self, device):
model = Net().to(device)
model.load_state_dict(torch.load(str(ROOTDIR / "models" / "cifar_net.pth")))
model.load_state_dict(torch.load(str(ROOTDIR / "models" / "cifar_net.pth"), map_location=device))
model.eval()
print("load cifar_net model, start to inference...")
return model
2 changes: 1 addition & 1 deletion example/mnist/mnist/ppl.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def _post(self, input):

def _load_model(self, device):
model = Net().to(device)
model.load_state_dict(torch.load(str(ROOTDIR / "models/mnist_cnn.pt")))
model.load_state_dict(torch.load(str(ROOTDIR / "models/mnist_cnn.pt"), map_location=device))
model.eval()
print("load mnist model, start to inference...")
return model
Expand Down
48 changes: 0 additions & 48 deletions example/mnist/mnist/test.py

This file was deleted.

4 changes: 2 additions & 2 deletions example/nmt/nmt/ppl.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,12 @@ def _load_vocab(self):
def _load_encoder_model(self, device):
model = EncoderRNN(self.vocab.vin.n_words, self.hidden_size, device).to(device)

param = torch.load(_ROOT_DIR + "/models/encoder.pth", device)
param = torch.load(_ROOT_DIR + "/models/encoder.pth", map_location=device)
model.load_state_dict(param)
return model

def _load_decoder_model(self, device):
model = DecoderRNN(self.vocab.vout.n_words, self.hidden_size, device).to(device)
param = torch.load(_ROOT_DIR + "/models/decoder.pth", device)
param = torch.load(_ROOT_DIR + "/models/decoder.pth", map_location=device)
model.load_state_dict(param)
return model
4 changes: 2 additions & 2 deletions example/runtime/pytorch/runtime.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ configs:
docker:
image: ghcr.io/star-whale/runtime/pytorch
pip:
extra_index_url: https://pypi.doubanio.com/simple
extra_index_url: https://mirrors.bfsu.edu.cn/pypi/web/simple
index_url: https://pypi.tuna.tsinghua.edu.cn/simple
trusted_host:
- pypi.tuna.tsinghua.edu.cn
- pypi.doubanio.com
- mirrors.bfsu.edu.cn
dependencies:
- pip:
- Pillow
Expand Down
2 changes: 1 addition & 1 deletion example/speech_command/sc/ppl.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def _post(self, input: torch.Tensor) -> t.Tuple[t.List[int], t.List[float]]:

def _load_model(self, device):
model = M5(n_input=1, n_output=len(ALL_LABELS))
model.load_state_dict(torch.load(str(ROOTDIR / "models/m5.pth")))
model.load_state_dict(torch.load(str(ROOTDIR / "models/m5.pth"), map_location=device))
model.to(device)
model.eval()
print("m5 model loaded, start to inference...")
Expand Down
2 changes: 1 addition & 1 deletion example/text_cls_AG_NEWS/tcan/ppl.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def cmp(self, ppl_result):
def _load_model(self, device):
model_path = _ROOT_DIR + "/models/model.i"
model = TextClassificationModel(1308713, 32, _NUM_CLASSES).to(device)
model.load_state_dict(torch.load(model_path))
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()
return model

Expand Down

0 comments on commit c37fadd

Please sign in to comment.