-
-
Notifications
You must be signed in to change notification settings - Fork 72
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: mnist example (onnx version 1.3) can successfully run with Gorg…
…onia * chore: ignore binary * feat: create the skeleton of a utility to run a model from the zoo * chore: prepare the implementation of the auto-padding * feat: the utility is working * feat: add the same_upper auto-paddding * chore: new informations * chore: add some doc
- Loading branch information
1 parent
49423a4
commit f37e05b
Showing
6 changed files
with
147 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,3 +13,4 @@ | |
doc/doc | ||
example/gorgonia/numpy | ||
example/gorgonia/gorgonia | ||
examples/model_zoo_executor/model_zoo_executor |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# About | ||
|
||
This is a simple utility that runs a model from the model zoo thanks to the Gorgonia backend | ||
|
||
## Example | ||
|
||
Download a pre-trained [model from the zoo](https://github.com/onnx/models) (for now, only [MNIST](https://github.com/onnx/models/tree/master/mnist) is known to work) | ||
|
||
then smply run: | ||
|
||
`go run main.go -model /tmp/mnist/model.onnx -input /tmp/mnist/test_data_set_0/input_0.pb -output /tmp/mnist/test_data_set_0/output_0.pb` | ||
|
||
The utility evaluates the model and check if the computed output is equal to the expected output (within a delta of 5e-3). | ||
If the result is ok, it displays the result: | ||
|
||
`[975.67035 -618.7244 6574.5684 668.0278 -917.27057 -1671.6357 -1952.7606 -61.54949 -777.17645 -1439.5311]` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
package main | ||
|
||
import ( | ||
"flag" | ||
"fmt" | ||
"io/ioutil" | ||
"log" | ||
"os" | ||
|
||
"github.com/owulveryck/onnx-go" | ||
"github.com/owulveryck/onnx-go/backend/x/gorgonnx" | ||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func main() { | ||
model := flag.String("model", "model.onnx", "path to the model file") | ||
input := flag.String("input", "test_data_set_0/input_0.pb", "path to the input file") | ||
output := flag.String("output", "test_data_set_0/output_0.pb", "path to the output file") | ||
h := flag.Bool("h", false, "help") | ||
flag.Parse() | ||
if *h { | ||
flag.Usage() | ||
os.Exit(0) | ||
} | ||
for _, f := range []string{*model, *input, *output} { | ||
if _, err := os.Stat(f); err != nil && os.IsNotExist(err) { | ||
log.Fatalf("%v does not exist", f) | ||
} | ||
} | ||
// Create a backend receiver | ||
backend := gorgonnx.NewGraph() | ||
// Create a model and set the execution backend | ||
m := onnx.NewModel(backend) | ||
|
||
// read the onnx model | ||
b, err := ioutil.ReadFile(*model) | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
// Decode it into the model | ||
err = m.UnmarshalBinary(b) | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
// Set the first input, the number depends of the model | ||
// TODO | ||
b, err = ioutil.ReadFile(*input) | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
inputT, err := onnx.NewTensor(b) | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
m.SetInput(0, inputT) | ||
err = backend.Run() | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
b, err = ioutil.ReadFile(*output) | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
outputT, err := onnx.NewTensor(b) | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
computedOutputT, err := m.GetOutputTensors() | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
assert.InDeltaSlice(&testingT{}, outputT.Data(), computedOutputT[0].Data(), 5e-3, "the two tensors should be equal.") | ||
fmt.Println(computedOutputT[0].Data()) | ||
} | ||
|
||
type testingT struct{} | ||
|
||
func (t *testingT) Errorf(format string, args ...interface{}) { | ||
log.Fatalf(format, args...) | ||
} |