Add looping example - how best to structure the code when calling inference many times. (#169)

* Add README with basic outline for the example.
* Add simplenet, pt2ts, and requirements file.
* Add instructions for running simplenet and saving to file.
* Add 'bad' Fortran code for looping example.
* Add 'good' Fortran code for looping example.
* Add CMakeLists for building Looping example and update readme and fortran with instructions on how to do this.
* Move autograd example to be number 6.
* Add information about exercise 5 Looping to the examples README.
* Update Documentation for Looping example to fix typos and clarify.
1 parent 1d4bb79 · commit e0d8269 · 18 changed files with 583 additions and 5 deletions.
**CMakeLists.txt**

```cmake
cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
# Policy CMP0076 - target_sources source files are relative to file where target_sources is run
cmake_policy(SET CMP0076 NEW)

set(PROJECT_NAME LoopingExample)

project(${PROJECT_NAME} LANGUAGES Fortran)

# Build in Debug mode if not specified
if(NOT CMAKE_BUILD_TYPE)
  set(CMAKE_BUILD_TYPE Debug CACHE STRING "" FORCE)
endif()

find_package(FTorch)
message(STATUS "Building with Fortran PyTorch coupling")

# Fortran example - bad
add_executable(simplenet_infer_fortran_bad bad/simplenet_infer_fortran.f90)
target_link_libraries(simplenet_infer_fortran_bad PRIVATE FTorch::ftorch)
target_sources(simplenet_infer_fortran_bad PRIVATE bad/fortran_ml_mod.f90)

# Fortran example - good
add_executable(simplenet_infer_fortran_good good/simplenet_infer_fortran.f90)
target_link_libraries(simplenet_infer_fortran_good PRIVATE FTorch::ftorch)
target_sources(simplenet_infer_fortran_good PRIVATE good/fortran_ml_mod.f90)
```
**README.md**

# Example 5 - Looping
So far many of the examples have been somewhat trivial, reading in a net and calling it
once to demonstrate the inference process.

In reality, most applications that we use Fortran for will be performing many iterations
of a process and calling the net multiple times.
Loading the net from file is computationally expensive, so we should do this only
once, and then call the forward method on the loaded net as part of the iterative
process.

This example demonstrates the naive 'bad' approach and then the more efficient 'good'
approach. It shows the suggested way to break down the FTorch code into initialisation,
forward, and finalisation subprocesses, and allows users to time the different
approaches to observe the significant performance difference.
## Description

We revisit SimpleNet from the first example, which takes an input tensor of length 5
and multiplies it by two.
This time we start by passing it the tensor `[0.0, 1.0, 2.0, 3.0, 4.0]`, but then iterate
10,000 times, each time incrementing each element by 1.0.
We sum the results of each forward pass and print the final result.
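The arithmetic that this loop performs can be sketched in plain Python (assuming, as in example 1, that SimpleNet simply doubles its input). This also shows the totals the Fortran programs should print, up to single-precision rounding:

```python
# Plain-Python sketch of the example's loop.
# Assumption: SimpleNet just multiplies its input by two, as in example 1.
def simplenet(x):
    return [2.0 * v for v in x]

in_data = [0.0, 1.0, 2.0, 3.0, 4.0]
sum_data = [0.0] * 5

for _ in range(10000):
    out_data = simplenet(in_data)                          # forward pass
    sum_data = [s + o for s, o in zip(sum_data, out_data)]  # accumulate results
    in_data = [v + 1.0 for v in in_data]                    # increment each element

print(sum_data)  # [99990000.0, 100010000.0, 100030000.0, 100050000.0, 100070000.0]
```

The Fortran executables print slightly different values (e.g. `99985792.0`) because they accumulate in 32-bit reals.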
There are two folders, `bad/` and `good/`, that show the two different approaches.

The same `pt2ts.py` tool as in the previous examples is used to save the
network to TorchScript. A `simplenet_infer_fortran.f90` file contains the main
program that runs over the loop. A `fortran_ml_mod.f90` file contains a module with
the FTorch code to load the TorchScript model, run it in inference mode, and clean up.
### Bad

We start with the 'bad' approach, which naively encloses the code from example 1
in a loop.

Examine the code in `bad/fortran_ml_mod.f90` to see how the subroutine `ml_routine()`
creates a `torch_model` and `torch_tensor`s and reads in the net on every call before
performing inference and then destroying them.
### Good

Now look at the 'good' approach.

Examining the code in `good/fortran_ml_mod.f90`, we see that there is an initialisation
subroutine `ml_init()` that reads in the net from file, holding it as a module variable.
There is then `ml_routine()`, which maps the input and output data to `torch_tensor`s
(also declared at module level) and performs the forward pass.
Finally we have `ml_final()`, which cleans up the net and tensors.

Looking next at `good/simplenet_infer_fortran.f90`, we see that the initialisation and
finalisation routines are called once, before and after the main loop respectively,
with only `ml_routine()`, which runs the forward pass, called from inside the loop.
The benefits of this approach can be seen by comparing the time taken to run each
version of the code, as detailed below.
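The structural difference between the two approaches can be illustrated with a small Python sketch (the `load_model` function here is a hypothetical stand-in for an expensive model load, not the FTorch API): the 'bad' pattern pays the load cost on every iteration, while the 'good' pattern pays it once.

```python
# Hypothetical stand-in for an expensive model load (not the real FTorch API):
# we count how many times the costly load is performed in each structure.
load_count = 0

def load_model():
    global load_count
    load_count += 1  # pay the "expensive" load cost
    return lambda x: [2.0 * v for v in x]

n_iter = 10000

# 'bad' structure: load (and discard) the net inside the loop
load_count = 0
for _ in range(n_iter):
    net = load_model()
    net([0.0, 1.0, 2.0, 3.0, 4.0])
bad_loads = load_count

# 'good' structure: initialise once, loop over forward passes, finalise once
load_count = 0
net = load_model()                    # analogous to ml_init()
for _ in range(n_iter):
    net([0.0, 1.0, 2.0, 3.0, 4.0])    # analogous to ml_routine()
good_loads = load_count               # ml_final() would clean up here

print(bad_loads, good_loads)  # 10000 1
```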
## Dependencies

To run this example requires:

- CMake
- FTorch (installed as described in the main package)
- Python 3
## Running

To run this example install FTorch as described in the main documentation. Then,
from this directory, create a virtual environment and install the necessary
Python modules:
```
python3 -m venv venv
source venv/bin/activate
pip install -r requirements.txt
```
You can check everything is working by running `simplenet.py`:
```
python3 simplenet.py
```
This defines the network and runs it with input tensor `[0.0, 1.0, 2.0, 3.0, 4.0]` to
produce the result:
```
(tensor([0., 2., 4., 6., 8.]))
```
To save the SimpleNet model to TorchScript, run the modified version of the
`pt2ts.py` tool:
```
python3 pt2ts.py
```
which will generate `saved_simplenet_model.pt`, the TorchScript instance of
the network, and perform a quick sanity check that it can be read.

At this point we no longer require Python, so we can deactivate the virtual
environment:
```
deactivate
```
Now we can build the Fortran codes.
This is done using CMake as follows:
```
mkdir build
cd build
cmake .. -DCMAKE_PREFIX_PATH=<path/to/your/installation/of/library/> -DCMAKE_BUILD_TYPE=Release
cmake --build .
```

(Note that the Fortran compiler can be chosen explicitly with the `-DCMAKE_Fortran_COMPILER` flag,
and should match the compiler that was used to locally build FTorch.)
This will generate two executables, `simplenet_infer_fortran_bad` and
`simplenet_infer_fortran_good`.

These can be run and timed using:
```
time ./simplenet_infer_fortran_bad
```
and
```
time ./simplenet_infer_fortran_good
```
which will produce output like:
```
99985792.0 100005792. 100025792. 100045792. 100065800.
./simplenet_infer_fortran_bad 13.64s user 1.10s system 94% cpu 15.551 total
```
and
```
99985792.0 100005792. 100025792. 100045792. 100065800.
./simplenet_infer_fortran_good 0.34s user 0.02s system 98% cpu 0.369 total
```

We see that the 'good' approach is of the order of 40 times faster.
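As a quick sanity check on that figure, dividing the two total wall-clock times quoted above gives the approximate speedup:

```python
# Speedup implied by the sample 'time' outputs quoted above.
bad_total = 15.551   # seconds, time ./simplenet_infer_fortran_bad
good_total = 0.369   # seconds, time ./simplenet_infer_fortran_good
speedup = bad_total / good_total
print(f"{speedup:.0f}x")  # roughly 42x
```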
**bad/fortran_ml_mod.f90**

```fortran
module ml_mod

  use, intrinsic :: iso_fortran_env, only : sp => real32

  ! Import our library for interfacing with PyTorch
  use ftorch

  implicit none

  private
  public ml_routine

  ! Set working precision for reals
  integer, parameter :: wp = sp

contains

  subroutine ml_routine(in_data, out_data)

    ! Set up Fortran data structures
    real(wp), dimension(5), target, intent(in) :: in_data
    real(wp), dimension(5), target, intent(out) :: out_data

    ! Set up Torch data structures:
    ! the net, a vector of input tensors, and a vector of output tensors
    integer, dimension(1) :: tensor_layout = [1]
    type(torch_tensor), dimension(1) :: input_tensors
    type(torch_tensor), dimension(1) :: output_tensors
    type(torch_model) :: torch_net

    ! TorchScript model file
    character(len=128) :: model_torchscript_file

    ! Create Torch input/output tensors from the above arrays
    call torch_tensor_from_array(input_tensors(1), in_data, tensor_layout, torch_kCPU)
    call torch_tensor_from_array(output_tensors(1), out_data, tensor_layout, torch_kCPU)

    ! Load ML model (on every call - this is the expensive step)
    model_torchscript_file = '../saved_simplenet_model.pt'
    call torch_model_load(torch_net, model_torchscript_file)

    ! Infer
    call torch_model_forward(torch_net, input_tensors, output_tensors)

    ! Cleanup
    call torch_delete(input_tensors)
    call torch_delete(output_tensors)
    call torch_delete(torch_net)

  end subroutine ml_routine

end module ml_mod
```
**bad/simplenet_infer_fortran.f90**

```fortran
program inference

  ! Import precision info from iso_fortran_env
  use, intrinsic :: iso_fortran_env, only : sp => real32

  ! Import the ml module
  use ml_mod, only : ml_routine

  implicit none

  ! Set working precision for reals
  integer, parameter :: wp = sp
  integer :: i

  ! Set up Fortran data structures
  real(wp), dimension(5), target :: in_data
  real(wp), dimension(5), target :: out_data
  real(wp), dimension(5), target :: sum_data

  ! Initialise data
  in_data = [0.0, 1.0, 2.0, 3.0, 4.0]
  sum_data(:) = 0.0

  ! Loop over the ml routine, accumulating results
  do i = 1, 10000
    call ml_routine(in_data, out_data)
    sum_data(:) = sum_data(:) + out_data(:)

    in_data(:) = in_data(:) + 1.0
  end do

  ! Write out the result of calling the net
  write (*,*) sum_data(:)

end program inference
```
**good/fortran_ml_mod.f90**

```fortran
module ml_mod

  use, intrinsic :: iso_fortran_env, only : sp => real32

  ! Import our library for interfacing with PyTorch
  use ftorch

  implicit none

  private
  public ml_init, ml_routine, ml_final

  ! Set working precision for reals
  integer, parameter :: wp = sp

  ! Set up Torch data structures at module level:
  ! the net, a vector of input tensors, and a vector of output tensors
  integer, dimension(1) :: tensor_layout = [1]
  type(torch_tensor), dimension(1) :: input_tensors
  type(torch_tensor), dimension(1) :: output_tensors
  type(torch_model) :: torch_net

  ! TorchScript model file
  character(len=128) :: model_torchscript_file

contains

  subroutine ml_init()

    ! Load ML model (once only)
    model_torchscript_file = '../saved_simplenet_model.pt'
    call torch_model_load(torch_net, model_torchscript_file)

  end subroutine ml_init

  subroutine ml_routine(in_data, out_data)

    ! Set up Fortran data structures
    real(wp), dimension(5), target, intent(in) :: in_data
    real(wp), dimension(5), target, intent(out) :: out_data

    ! Create Torch input/output tensors from the above arrays
    call torch_tensor_from_array(input_tensors(1), in_data, tensor_layout, torch_kCPU)
    call torch_tensor_from_array(output_tensors(1), out_data, tensor_layout, torch_kCPU)

    ! Infer
    call torch_model_forward(torch_net, input_tensors, output_tensors)

  end subroutine ml_routine

  subroutine ml_final()

    ! Cleanup
    call torch_delete(input_tensors)
    call torch_delete(output_tensors)
    call torch_delete(torch_net)

  end subroutine ml_final

end module ml_mod
```
**good/simplenet_infer_fortran.f90**

```fortran
program inference

  ! Import precision info from iso_fortran_env
  use, intrinsic :: iso_fortran_env, only : sp => real32

  ! Import the ml module
  use ml_mod, only : ml_init, ml_routine, ml_final

  implicit none

  ! Set working precision for reals
  integer, parameter :: wp = sp
  integer :: i

  ! Set up Fortran data structures
  real(wp), dimension(5), target :: in_data
  real(wp), dimension(5), target :: out_data
  real(wp), dimension(5), target :: sum_data

  ! Initialise data
  in_data = [0.0, 1.0, 2.0, 3.0, 4.0]
  sum_data(:) = 0.0

  ! Initialise the net once, before the loop
  call ml_init()

  ! Loop over the ml routine, accumulating results
  do i = 1, 10000
    call ml_routine(in_data, out_data)
    sum_data(:) = sum_data(:) + out_data(:)

    in_data(:) = in_data(:) + 1.0
  end do

  ! Clean up once, after the loop
  call ml_final()

  ! Write out the result of calling the net
  write (*,*) sum_data(:)

end program inference
```