Skip to content

Commit

Permalink
Feat/get input tensors (#172)
Browse files Browse the repository at this point in the history
* feat: flag for console output

* feat: GetInputTensors

* fix: test and bug fix

* feat: bump version

* feat: more tests
  • Loading branch information
owulveryck authored Dec 30, 2019
1 parent c21037c commit 78b014b
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 2 deletions.
10 changes: 8 additions & 2 deletions doc/introduction/utils/draw.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"encoding/base64"
"flag"
"fmt"
"image"
"image/color"
Expand All @@ -20,8 +21,13 @@ var (
func main() {
reader := base64.NewDecoder(base64.StdEncoding, strings.NewReader(img8))
im, _, _ = image.Decode(reader)
//outputConsole()
outputValues()
console := flag.Bool("c", false, "console output")
flag.Parse()
if *console {
outputConsole()
} else {
outputValues()
}
}

func outputConsole() {
Expand Down
3 changes: 3 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ github.com/xtgo/set v1.0.0/go.mod h1:d3NHzGzSa0NmB2NhFyECA+QdRp29oEn2xbT+TpeFoM8
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2 h1:y102fOLFqhV41b+4GPiJoa0k/x+pJcEi2/HB1Y5T6fU=
golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81 h1:00VmoueYNlNz/aHIilyyQz/MHSqGoWJzpFv/HW8xpzI=
golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
Expand All @@ -81,9 +82,11 @@ gonum.org/v1/gonum v0.0.0-20190226202314-149afe6ec0b6/go.mod h1:jevfED4GnIEnJrWW
gonum.org/v1/gonum v0.0.0-20190902003836-43865b531bee h1:4pVWuAEGpaPZ7dPfd6aA8LyDNzMA2RKCxAS/XNCLZUM=
gonum.org/v1/gonum v0.0.0-20190902003836-43865b531bee/go.mod h1:9mxDZsDKxgMAuccQkewq682L+0eCu4dCN2yonUJTCLU=
gonum.org/v1/netlib v0.0.0-20190221094214-0632e2ebbd2d/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw=
gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0 h1:OE9mWmgKkjJyEmDAAtGMPjXu+YNeGvK9VTSHY6+Qihc=
gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw=
gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc=
google.golang.org/genproto v0.0.0-20180831171423-11092d34479b/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/cheggaaa/pb.v1 v1.0.27/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
Expand Down
13 changes: 13 additions & 0 deletions io.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,16 @@ func (m *Model) GetOutputTensors() ([]tensor.Tensor, error) {
}
return output, nil
}

// GetInpuTensors from the graph. This function is useful to get informations if the tensor is a placeholder
// and does not contain any data yet.
func (m *Model) GetInputTensors() []tensor.Tensor {
output := make([]tensor.Tensor, len(m.Input))
for i := range m.Input {
n := m.backend.Node(int64(m.Input[i]))
if n != nil {
output[i] = n.(DataCarrier).GetTensor()
}
}
return output
}
20 changes: 20 additions & 0 deletions io_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,23 @@ func TestSetInput_nil_model(t *testing.T) {
err := m.SetInput(0, tens)
t.Fatal("should have paniced but have passed with error", err)
}

func TestGetInputTensors(t *testing.T) {
backend := newTestBackend()
n1 := backend.NewNode()
backend.AddNode(n1)
n2 := backend.NewNode()
backend.AddNode(n2)
n2.(*nodeTest).SetTensor(tensor.NewDense(tensor.Float32, []int{1, 1}))
model := &Model{
Input: []int64{n1.ID(), n2.ID()},
backend: backend,
}
input := model.GetInputTensors()
if len(input) != 2 {
t.FailNow()
}
if input[0] != nil || input[1] == nil {
t.Fail()
}
}

0 comments on commit 78b014b

Please sign in to comment.