From 5e3376711bfa1803c1d7c46930bfb2f1d5a24851 Mon Sep 17 00:00:00 2001 From: ElliottKasoar Date: Thu, 28 Sep 2023 13:05:10 +0100 Subject: [PATCH] Update python inference with new example --- examples/2_ResNet18/resnet_infer_python.py | 46 ++++++++++++++++++---- 1 file changed, 38 insertions(+), 8 deletions(-) diff --git a/examples/2_ResNet18/resnet_infer_python.py b/examples/2_ResNet18/resnet_infer_python.py index a1dd4993..b773e901 100644 --- a/examples/2_ResNet18/resnet_infer_python.py +++ b/examples/2_ResNet18/resnet_infer_python.py @@ -1,11 +1,12 @@ -"""Load ResNet-18 saved to TorchScript and run inference with ones.""" +"""Load ResNet-18 saved to TorchScript and run inference with an example image.""" +import numpy as np import torch -def deploy(saved_model, device, batch_size=1): +def deploy(saved_model: str, device: str, batch_size: int = 1) -> torch.Tensor: """ - Load TorchScript ResNet-18 and run inference with Tensor of ones. + Load TorchScript ResNet-18 and run inference with Tensor from example image. Parameters ---------- @@ -21,8 +22,13 @@ def deploy(saved_model, device, batch_size=1): output : torch.Tensor result of running inference on model with Tensor of ones """ + transposed_shape = [224, 224, 3, 1] + precision = np.float32 - input_tensor = torch.ones(batch_size, 3, 224, 224) + np_data = np.fromfile("data/image_tensor.dat", dtype=precision) + np_data = np_data.reshape(transposed_shape) + np_data = np_data.transpose() + input_tensor = torch.from_numpy(np_data) if device == "cpu": # Load saved TorchScript model @@ -42,14 +48,38 @@ def deploy(saved_model, device, batch_size=1): return output +def print_top_results(output: torch.Tensor) -> None: + """Prints top 5 results + + Parameters + ---------- + output: torch.Tensor + Output from ResNet-18. + """ + # Run a softmax to get probabilities + probabilities = torch.nn.functional.softmax(output[0], dim=0) + + # Read ImageNet labels from text file + cats_filename = "data/categories.txt" + categories = np.genfromtxt(cats_filename, dtype=str, delimiter="\n") + + # Show top categories per image + top5_prob, top5_catid = torch.topk(probabilities, 5) + print("\nTop 5 results:\n") + for i in range(top5_prob.size(0)): + cat_id = top5_catid[i] + print( + f"{categories[cat_id]} (id={cat_id}): probability = {top5_prob[i].item()}" + ) + + if __name__ == "__main__": saved_model_file = "saved_resnet18_model_cpu.pt" device_to_run = "cpu" - # device = "cuda" + # device_to_run = "cuda" batch_size_to_run = 1 - result = deploy(saved_model_file, device_to_run, batch_size_to_run) - - print(result[:, 0:5]) + output = deploy(saved_model_file, device_to_run, batch_size_to_run) + print_top_results(output)