Skip to content

Commit

Permalink
RSDK-2380 Check if labels file is split by spaces or commas (#2178)
Browse files Browse the repository at this point in the history
  • Loading branch information
kim-mishra authored Apr 10, 2023
1 parent b6a992c commit c9ff68f
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 0 deletions.
4 changes: 4 additions & 0 deletions .artifact/tree.json
Original file line number Diff line number Diff line change
Expand Up @@ -51473,6 +51473,10 @@
"hash": "90790cb342b79585a67dc2e4ad5d8be6",
"size": 88663
},
"lorem.txt": {
"hash": "f668c7c4651f39df473a8355fef9fa38",
"size": 1162
},
"object_classifier.tflite": {
"hash": "947c25596dab54a519e08a3d7dfaf6fd",
"size": 4276352
Expand Down
7 changes: 7 additions & 0 deletions services/vision/builtin/tflite_classifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
"go.viam.com/rdk/vision/classification"
)

// LABEL_OUTPUT_MISMATCH is the error returned by unpackClassificationTensor when a
// non-empty label list does not have the same length as the model's output confidences.
// NOTE(review): Go convention would name this ErrLabelOutputMismatch and use a lowercase
// message without trailing punctuation; left unchanged because other code references it.
var LABEL_OUTPUT_MISMATCH = errors.New("Invalid Label File: Number of labels does not match number of model outputs. Labels must be separated by a newline, comma or space.")

// TFLiteClassifierConfig specifies the fields necessary for creating a TFLite classifier.
type TFLiteClassifierConfig struct {
// this should come from the attributes part of the classifier config
Expand Down Expand Up @@ -112,8 +114,13 @@ func unpackClassificationTensor(ctx context.Context, tensor []interface{},
default:
return nil, errors.New("output type not valid. try uint8 or float32")
}

out := make(classification.Classifications, 0, len(outConf))
if len(labels) > 0 {
if len(labels) != len(outConf) {
return nil, LABEL_OUTPUT_MISMATCH
}

for i, c := range outConf {
out = append(out, classification.NewClassification(c, labels[i]))
}
Expand Down
10 changes: 10 additions & 0 deletions services/vision/builtin/tflite_detector.go
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,16 @@ func loadLabels(filename string) ([]string, error) {
for scanner.Scan() {
labels = append(labels, scanner.Text())
}

// if the labels come out as one line, try splitting that line by spaces or commas to extract labels
if len(labels) == 1 {
labels = strings.Split(labels[0], ",")
}

if len(labels) == 1 {
labels = strings.Split(labels[0], " ")
}

return labels, nil
}

Expand Down
29 changes: 29 additions & 0 deletions services/vision/builtin/tflite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"testing"

"github.com/edaniels/golog"
"github.com/nfnt/resize"
"go.viam.com/test"
"go.viam.com/utils/artifact"

Expand Down Expand Up @@ -254,3 +255,31 @@ func TestMoreClassifierModels(t *testing.T) {
test.That(t, bestClass[0].Label(), test.ShouldResemble, "292")
test.That(t, bestClass[0].Score(), test.ShouldBeGreaterThan, 0.93)
}

// TestInvalidLabels verifies that unpackClassificationTensor rejects a label list whose
// length differs from the number of model outputs, returning LABEL_OUTPUT_MISMATCH and
// no classifications.
func TestInvalidLabels(t *testing.T) {
	ctx := context.Background()

	pic, err := rimage.NewImageFromFile(artifact.MustPath("vision/tflite/redpanda.jpeg"))
	test.That(t, err, test.ShouldBeNil)

	modelLoc := artifact.MustPath("vision/tflite/mobilenetv2_class.tflite")
	labelPath := artifact.MustPath("vision/classification/object_labels.txt")
	numThreads := 2

	// Every setup step is checked so that a failure here is reported at its source
	// instead of surfacing later as an unrelated error or a panic on a nil value.
	labels, err := loadLabels(labelPath)
	test.That(t, err, test.ShouldBeNil)

	model, err := addTFLiteModel(ctx, modelLoc, &numThreads)
	test.That(t, err, test.ShouldBeNil)

	resizedImg := resize.Resize(100, 100, pic, resize.Bilinear)
	outTensor, err := tfliteInfer(ctx, model, resizedImg)
	test.That(t, err, test.ShouldBeNil)

	// The label file is paired with a model it was not written for, so the label count
	// is expected to disagree with the number of outputs.
	classifications, err := unpackClassificationTensor(ctx, outTensor, model, labels)
	test.That(t, err, test.ShouldResemble, LABEL_OUTPUT_MISMATCH)
	test.That(t, classifications, test.ShouldBeNil)
}

func TestSpaceDelineatedLabels(t *testing.T) {
labelPath := artifact.MustPath("vision/classification/lorem.txt")

labels, err := loadLabels(labelPath)
test.That(t, err, test.ShouldBeNil)
test.That(t, len(labels), test.ShouldEqual, 10)
}

0 comments on commit c9ff68f

Please sign in to comment.