PyTorch Model Inference in Golang
- model/: example TorchScript module, a sample image, and the class-to-label mapping file
- golibtorch/: C++ header and source files, plus the CGO wrapper
- main.go: driver program that calls the CGO wrapper with the model and the sample input
Run the following command to download and set up LibTorch (CPU version):
make deps
- Refer to model/torchscript.py for how a torchvision model is converted to a TorchScript module.
- The TorchScript module is structured as follows:
import torch.nn as nn

class YourModule(nn.Module):
    def __init__(self):
        super().__init__()
        # initialize the quantized model with pretrained weights
        # load the class-to-label dictionary

    def forward(self, input):
        # run the forward pass and compute classes with their probabilities
        # map classes to labels
        # return the result
        ...
- The example uses the ImageNet classes to map each ImageNet class index to its human-readable label.
- Run the following commands to create the TorchScript module:
cd model
python torchscript.py save
- Run the following commands to load the TorchScript module and run inference on the sample image:
cd model
python torchscript.py run
Run the program with the sample input:
make run
This project is licensed under the MIT License.