From 78b014b2d185ae7073cbe27b18a24310cc987adf Mon Sep 17 00:00:00 2001 From: Olivier Wulveryck Date: Mon, 30 Dec 2019 14:35:13 +0100 Subject: [PATCH] Feat/get input tensors (#172) * feat: flag for console output * feat: GetInputTensors * fix: test and bug fix * feat: bump version * feat: more tests --- doc/introduction/utils/draw.go | 10 ++++++++-- go.sum | 3 +++ io.go | 13 +++++++++++++ io_test.go | 20 ++++++++++++++++++++ 4 files changed, 44 insertions(+), 2 deletions(-) diff --git a/doc/introduction/utils/draw.go b/doc/introduction/utils/draw.go index 1ed90a47..0069d707 100644 --- a/doc/introduction/utils/draw.go +++ b/doc/introduction/utils/draw.go @@ -2,6 +2,7 @@ package main import ( "encoding/base64" + "flag" "fmt" "image" "image/color" @@ -20,8 +21,13 @@ var ( func main() { reader := base64.NewDecoder(base64.StdEncoding, strings.NewReader(img8)) im, _, _ = image.Decode(reader) - //outputConsole() - outputValues() + console := flag.Bool("c", false, "console output") + flag.Parse() + if *console { + outputConsole() + } else { + outputValues() + } } func outputConsole() { diff --git a/go.sum b/go.sum index c0846f3f..71d93168 100644 --- a/go.sum +++ b/go.sum @@ -58,6 +58,7 @@ github.com/xtgo/set v1.0.0/go.mod h1:d3NHzGzSa0NmB2NhFyECA+QdRp29oEn2xbT+TpeFoM8 golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2 h1:y102fOLFqhV41b+4GPiJoa0k/x+pJcEi2/HB1Y5T6fU= golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81 h1:00VmoueYNlNz/aHIilyyQz/MHSqGoWJzpFv/HW8xpzI= golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs= @@ -81,9 +82,11 @@ gonum.org/v1/gonum v0.0.0-20190226202314-149afe6ec0b6/go.mod h1:jevfED4GnIEnJrWW gonum.org/v1/gonum v0.0.0-20190902003836-43865b531bee h1:4pVWuAEGpaPZ7dPfd6aA8LyDNzMA2RKCxAS/XNCLZUM= gonum.org/v1/gonum v0.0.0-20190902003836-43865b531bee/go.mod h1:9mxDZsDKxgMAuccQkewq682L+0eCu4dCN2yonUJTCLU= gonum.org/v1/netlib v0.0.0-20190221094214-0632e2ebbd2d/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw= +gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0 h1:OE9mWmgKkjJyEmDAAtGMPjXu+YNeGvK9VTSHY6+Qihc= gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw= gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc= google.golang.org/genproto v0.0.0-20180831171423-11092d34479b/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/cheggaaa/pb.v1 v1.0.27/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= diff --git a/io.go b/io.go index 2e0bcb6d..9309ae25 100644 --- a/io.go +++ b/io.go @@ -38,3 +38,16 @@ func (m *Model) GetOutputTensors() ([]tensor.Tensor, error) { } return output, nil } + +// GetInpuTensors from the graph. This function is useful to get informations if the tensor is a placeholder +// and does not contain any data yet. +func (m *Model) GetInputTensors() []tensor.Tensor { + output := make([]tensor.Tensor, len(m.Input)) + for i := range m.Input { + n := m.backend.Node(int64(m.Input[i])) + if n != nil { + output[i] = n.(DataCarrier).GetTensor() + } + } + return output +} diff --git a/io_test.go b/io_test.go index ddfb8252..6a13279e 100644 --- a/io_test.go +++ b/io_test.go @@ -17,3 +17,23 @@ func TestSetInput_nil_model(t *testing.T) { err := m.SetInput(0, tens) t.Fatal("should have paniced but have passed with error", err) } + +func TestGetInputTensors(t *testing.T) { + backend := newTestBackend() + n1 := backend.NewNode() + backend.AddNode(n1) + n2 := backend.NewNode() + backend.AddNode(n2) + n2.(*nodeTest).SetTensor(tensor.NewDense(tensor.Float32, []int{1, 1})) + model := &Model{ + Input: []int64{n1.ID(), n2.ID()}, + backend: backend, + } + input := model.GetInputTensors() + if len(input) != 2 { + t.FailNow() + } + if input[0] != nil || input[1] == nil { + t.Fail() + } +}