Skip to content

Commit

Permalink
Update python inference with new example
Browse files Browse the repository at this point in the history
  • Loading branch information
ElliottKasoar committed Sep 28, 2023
1 parent 44d959b commit 5e33767
Showing 1 changed file with 38 additions and 8 deletions.
46 changes: 38 additions & 8 deletions examples/2_ResNet18/resnet_infer_python.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""Load ResNet-18 saved to TorchScript and run inference with ones."""
"""Load ResNet-18 saved to TorchScript and run inference with an example image."""

import numpy as np
import torch


def deploy(saved_model, device, batch_size=1):
def deploy(saved_model: str, device: str, batch_size: int = 1) -> torch.Tensor:
"""
Load TorchScript ResNet-18 and run inference with Tensor of ones.
Load TorchScript ResNet-18 and run inference with Tensor from example image.
Parameters
----------
Expand All @@ -21,8 +22,13 @@ def deploy(saved_model, device, batch_size=1):
output : torch.Tensor
result of running inference on model with Tensor of ones
"""
transposed_shape = [224, 224, 3, 1]
precision = np.float32

input_tensor = torch.ones(batch_size, 3, 224, 224)
np_data = np.fromfile("data/image_tensor.dat", dtype=precision)
np_data = np_data.reshape(transposed_shape)
np_data = np_data.transpose()
input_tensor = torch.from_numpy(np_data)

if device == "cpu":
# Load saved TorchScript model
Expand All @@ -42,14 +48,38 @@ def deploy(saved_model, device, batch_size=1):
return output


def print_top_results(output: torch.Tensor) -> None:
"""Prints top 5 results
Parameters
----------
output: torch.Tensor
Output from ResNet-18.
"""
# Run a softmax to get probabilities
probabilities = torch.nn.functional.softmax(output[0], dim=0)

# Read ImageNet labels from text file
cats_filename = "data/categories.txt"
categories = np.genfromtxt(cats_filename, dtype=str, delimiter="\n")

# Show top categories per image
top5_prob, top5_catid = torch.topk(probabilities, 5)
print("\nTop 5 results:\n")
for i in range(top5_prob.size(0)):
cat_id = top5_catid[i]
print(
f"{categories[cat_id]} (id={cat_id}): probability = {top5_prob[i].item()}"
)


if __name__ == "__main__":
saved_model_file = "saved_resnet18_model_cpu.pt"

device_to_run = "cpu"
# device = "cuda"
# device_to_run = "cuda"

batch_size_to_run = 1

result = deploy(saved_model_file, device_to_run, batch_size_to_run)

print(result[:, 0:5])
output = deploy(saved_model_file, device_to_run, batch_size_to_run)
print_top_results(output)

0 comments on commit 5e33767

Please sign in to comment.