Skip to content

Commit

Permalink
Add customer encoder to support more Massage types
Browse files Browse the repository at this point in the history
Now supporting SequenceExamples
  • Loading branch information
Emil Larsson committed Dec 13, 2020
1 parent 21f4eec commit 761add4
Show file tree
Hide file tree
Showing 2 changed files with 256 additions and 55 deletions.
71 changes: 16 additions & 55 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,20 @@ package cmd

import (
"bufio"
"encoding/json"
"errors"
"fmt"
"github.com/spf13/cobra"
"google.golang.org/protobuf/proto"
"io"
"math"
"os"

"github.com/spf13/cobra"
"google.golang.org/protobuf/proto"

protobuf "github.com/emla2805/tfr/protobuf"
tfr "github.com/emla2805/tfr/protobuf"
"github.com/emla2805/tfr/utils"
)

var numberRecords int
var record string

var rootCmd = &cobra.Command{
Use: "tfr {file ... | -}",
Expand Down Expand Up @@ -85,19 +83,27 @@ reads serialized .tfrecord files and outputs results as JSON on standard output.
scanner := bufio.NewScanner(multi)
scanner.Split(utils.ScanTFRecord)

var example proto.Message
switch record {
case "example":
example = &protobuf.Example{}
case "sequence_example":
example = &protobuf.SequenceExample{}
default:
example = &protobuf.Example{}
}

count := 0
for scanner.Scan() {
if count >= numberRecords {
break
}
example := &protobuf.Example{}
err := proto.Unmarshal(scanner.Bytes(), example)
if err != nil {
return err
}
// convert to local example
ex := NewExample(example)
jsonBytes, err := json.Marshal(ex)

jsonBytes, err := utils.Marshal(example)
if err != nil {
return err
}
Expand All @@ -116,54 +122,9 @@ func Execute() {
}
}

type Example struct {
Features map[string]Feature `json:"features"`
}

func NewExample(example *tfr.Example) *Example {
feature := make(Feature)
ex := &Example{
Features: make(map[string]Feature),
}
for k, v := range example.Features.Feature {
feature[k] = v
}
ex.Features["feature"] = feature
return ex
}

type Feature map[string]*tfr.Feature

func (f Feature) MarshalJSON() ([]byte, error) {
output := make(map[string]interface{})
for k, v := range f {
var typ string
value := make(map[string]interface{})
typMap := make(map[string]interface{})
switch t := v.Kind.(type) {
case *tfr.Feature_BytesList:
typ = "bytesList"
stringList := []string{}
for _, byts := range t.BytesList.Value {
stringList = append(stringList, string(byts))
}
value["value"] = stringList
case *tfr.Feature_FloatList:
typ = "floatList"
value["value"] = t.FloatList.Value
case *tfr.Feature_Int64List:
typ = "int64List"
value["value"] = t.Int64List.Value
}
typMap[typ] = value
output[k] = typMap
}

return json.Marshal(output)
}

func init() {
rootCmd.Flags().IntVarP(&numberRecords, "number", "n", math.MaxInt32, "number of records to show")
rootCmd.Flags().StringVarP(&record, "record", "r", "example", "record type { example | sequence_example }")
}

func isInputFromPipe() bool {
Expand Down
240 changes: 240 additions & 0 deletions utils/encode.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
package utils

import (
"errors"
"fmt"
"google.golang.org/protobuf/proto"
pref "google.golang.org/protobuf/reflect/protoreflect"
"math/bits"
"sort"
"strconv"
"unicode/utf8"
)

type jsonWriter struct {
buf []byte
}

func (w *jsonWriter) write(s string) {
w.buf = append(w.buf, s...)
}

func (w *jsonWriter) writeString(s string) error {
var err error
if w.buf, err = appendString(w.buf, s); err != nil {
return err
}
return nil
}

func Marshal(m proto.Message) ([]byte, error) {
w := jsonWriter{}
err := w.marshalMessage(m.ProtoReflect())
return w.buf, err
}

// marshalMessage marshals the given protoreflect.Message.
func (w *jsonWriter) marshalMessage(m pref.Message) error {
if err := w.marshalFields(m); err != nil {
return err
}

return nil
}

// marshalFields marshals the fields in the given protoreflect.Message.
func (w *jsonWriter) marshalFields(m pref.Message) error {
messageDesc := m.Descriptor()

w.write("{")
defer w.write("}")
firstField := true

// Marshal out known fields.
fieldDescs := messageDesc.Fields()
for i := 0; i < fieldDescs.Len(); {
fd := fieldDescs.Get(i)
if od := fd.ContainingOneof(); od != nil {
fd = m.WhichOneof(od)
i += od.Fields().Len()
if fd == nil {
continue // unpopulated oneofs are not affected by EmitUnpopulated
}
} else {
i++
}

val := m.Get(fd)
if !firstField {
w.write(",")
}
if err := w.marshalField(val, fd); err != nil {
return err
}
firstField = false
}
return nil
}

func (w *jsonWriter) marshalField(val pref.Value, fd pref.FieldDescriptor) error {
w.write(`"` + fd.JSONName() + `":`)
return w.marshalValue(val, fd)
}

// marshalValue marshals the given protoreflect.Value.
func (w *jsonWriter) marshalValue(val pref.Value, fd pref.FieldDescriptor) error {
switch {
case fd.IsList():
return w.marshalList(val.List(), fd)
case fd.IsMap():
return w.marshalMap(val.Map(), fd)
default:
return w.marshalSingular(val, fd)
}
}

// marshalList marshals the given protoreflect.List.
func (w *jsonWriter) marshalList(list pref.List, fd pref.FieldDescriptor) error {
w.write("[")
defer w.write("]")

comma := ""
for i := 0; i < list.Len(); i++ {
w.write(comma)
item := list.Get(i)
if err := w.marshalSingular(item, fd); err != nil {
return err
}
comma = ","
}
return nil
}

type mapEntry struct {
key pref.MapKey
value pref.Value
}

// marshalMap marshals given protoreflect.Map.
func (w *jsonWriter) marshalMap(mmap pref.Map, fd pref.FieldDescriptor) error {
// Get a sorted list based on keyType first.
entries := make([]mapEntry, 0, mmap.Len())
mmap.Range(func(key pref.MapKey, val pref.Value) bool {
entries = append(entries, mapEntry{key: key, value: val})
return true
})
sortMap(fd.MapKey().Kind(), entries)

w.write(`{`)
defer w.write(`}`)
comma := ""

// Write out sorted list.
for _, entry := range entries {
w.write(comma)
w.write(`"` + entry.key.String() + `":`)
if err := w.marshalSingular(entry.value, fd.MapValue()); err != nil {
return err
}
comma = ","
}
return nil
}

// sortMap orders list based on value of key field for deterministic ordering.
func sortMap(keyKind pref.Kind, values []mapEntry) {
sort.Slice(values, func(i, j int) bool {
switch keyKind {
case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind,
pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
return values[i].key.Int() < values[j].key.Int()

case pref.Uint32Kind, pref.Fixed32Kind,
pref.Uint64Kind, pref.Fixed64Kind:
return values[i].key.Uint() < values[j].key.Uint()
}
return values[i].key.String() < values[j].key.String()
})
}

// marshalSingular marshals the given non-repeated field value. This includes
// all scalar types, enums, messages, and groups.
func (w *jsonWriter) marshalSingular(val pref.Value, fd pref.FieldDescriptor) error {
if !val.IsValid() {
return nil
}

switch kind := fd.Kind(); kind {
case pref.Int64Kind:
w.write(val.String())

case pref.FloatKind:
w.write(val.String())

case pref.BytesKind:
if err := w.writeString(string(val.Bytes())); err != nil {
return err
}

case pref.MessageKind, pref.GroupKind:
if err := w.marshalMessage(val.Message()); err != nil {
return err
}

default:
panic(fmt.Sprintf("%v has unknown kind: %v", fd.FullName(), kind))
}
return nil
}

// Sentinel error used for indicating invalid UTF-8.
var errInvalidUTF8 = errors.New("invalid UTF-8")

func appendString(out []byte, in string) ([]byte, error) {
out = append(out, '"')
i := indexNeedEscapeInString(in)
in, out = in[i:], append(out, in[:i]...)
for len(in) > 0 {
switch r, n := utf8.DecodeRuneInString(in); {
case r == utf8.RuneError && n == 1:
return out, errInvalidUTF8
case r < ' ' || r == '"' || r == '\\':
out = append(out, '\\')
switch r {
case '"', '\\':
out = append(out, byte(r))
case '\b':
out = append(out, 'b')
case '\f':
out = append(out, 'f')
case '\n':
out = append(out, 'n')
case '\r':
out = append(out, 'r')
case '\t':
out = append(out, 't')
default:
out = append(out, 'u')
out = append(out, "0000"[1+(bits.Len32(uint32(r))-1)/4:]...)
out = strconv.AppendUint(out, uint64(r), 16)
}
in = in[n:]
default:
i := indexNeedEscapeInString(in[n:])
in, out = in[n+i:], append(out, in[:n+i]...)
}
}
out = append(out, '"')
return out, nil
}

// indexNeedEscapeInString returns the index of the character that needs
// escaping. If no characters need escaping, this returns the input length.
func indexNeedEscapeInString(s string) int {
for i, r := range s {
if r < ' ' || r == '\\' || r == '"' || r == utf8.RuneError {
return i
}
}
return len(s)
}

0 comments on commit 761add4

Please sign in to comment.