Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RSDK-2380 Check if labels file is split by spaces or commas #2178

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion services/vision/builtin/tflite_classifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package builtin

import (
"context"
"fmt"
"image"
"runtime"
"strconv"
Expand Down Expand Up @@ -82,9 +83,9 @@ func NewTFLiteClassifier(ctx context.Context, conf *vision.VisModelConfig,
if err != nil {
return nil, err
}

classifications, err := unpackClassificationTensor(ctx, outTensor, model, labels)
if err != nil {
logger.Error(err)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the best way to surface this error to the user? Logging it here logs it over and over and over again, but just returning the error doesn't seem to surface it

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah this function that gets returned will run every time an image gets classified. My first thought is that if we hate that, maybe try a logger.Fatal() which will kill everything after printing it out once. If that seems like overkill, maybe only fatal depending on the type of error. Either way it seems like don't want to keep on trying to classify this over and over again if we're definitely not unpacking the tensors correctly.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm the linter complains about logger.Fatal, do you know of a workaround?

return nil, err
}
return classifications, nil
Expand Down Expand Up @@ -112,6 +113,11 @@ func unpackClassificationTensor(ctx context.Context, tensor []interface{},
default:
return nil, errors.New("output type not valid. try uint8 or float32")
}

if len(labels) != len(outConf) {
return nil, errors.New(fmt.Sprintf("Invalid Label File: Number of labels (%v) does not match number of model outputs (%v). Labels must be separated by a newline, comma or space.", len(labels), len(outConf)))
}

out := make(classification.Classifications, 0, len(outConf))
if len(labels) > 0 {
for i, c := range outConf {
Expand Down
10 changes: 10 additions & 0 deletions services/vision/builtin/tflite_detector.go
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,17 @@ func loadLabels(filename string) ([]string, error) {
for scanner.Scan() {
labels = append(labels, scanner.Text())
}

if len(labels) == 1 {
labels = strings.Split(labels[0], " ")
}

if len(labels) == 1 {
labels = strings.Split(labels[0], ",")
}

return labels, nil

}

// getIndex just returns the index of an int in an array of ints
Expand Down