DATA-3467: Cloud Inference CLI (viamrobotics#4748)
vpandiarajan20 authored and vijayvuyyuru committed Feb 11, 2025
1 parent faee5ac commit ad9976f
Showing 4 changed files with 182 additions and 10 deletions.
46 changes: 46 additions & 0 deletions cli/app.go
@@ -2795,6 +2795,52 @@ This won't work unless you have an existing installation of our GitHub app on yo
},
},
},
{
Name: "infer",
Usage: "run cloud hosted inference on an image",
UsageText: createUsageText("inference infer", []string{
generalFlagOrgID, inferenceFlagFileOrgID, inferenceFlagFileID,
inferenceFlagFileLocationID, inferenceFlagModelOrgID, inferenceFlagModelName, inferenceFlagModelVersion,
}, true, false),
Flags: []cli.Flag{
&cli.StringFlag{
Name: generalFlagOrgID,
Usage: "organization ID that is executing the inference job",
Required: true,
},
&cli.StringFlag{
Name: inferenceFlagFileOrgID,
Usage: "organization ID that owns the file to run inference on",
Required: true,
},
&cli.StringFlag{
Name: inferenceFlagFileID,
Usage: "file ID of the file to run inference on",
Required: true,
},
&cli.StringFlag{
Name: inferenceFlagFileLocationID,
Usage: "location ID of the file to run inference on",
Required: true,
},
&cli.StringFlag{
Name: inferenceFlagModelOrgID,
Usage: "organization ID that hosts the model to use to run inference",
Required: true,
},
&cli.StringFlag{
Name: inferenceFlagModelName,
Usage: "name of the model to use to run inference",
Required: true,
},
&cli.StringFlag{
Name: inferenceFlagModelVersion,
Usage: "version of the model to use to run inference",
Required: true,
},
},
Action: createCommandWithT[mlInferenceInferArgs](MLInferenceInferAction),
},
{
Name: "version",
Usage: "print version info for this program",
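For reference, the new subcommand registered above would be invoked roughly as follows. This is a sketch: the top-level binary name ("viam") and the spelling of the general org flag ("--org-id") are assumptions, since only the file and model flag constants appear in this diff.

    viam inference infer --org-id=<org-id> \
        --file-org-id=<file-org-id> --file-id=<file-id> --file-location-id=<file-location-id> \
        --model-org-id=<model-org-id> --model-name=<model-name> --model-version=<model-version>

All seven flags are marked Required, so the command fails with a usage error if any are omitted.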
2 changes: 2 additions & 0 deletions cli/auth.go
@@ -21,6 +21,7 @@ import (
buildpb "go.viam.com/api/app/build/v1"
datapb "go.viam.com/api/app/data/v1"
datasetpb "go.viam.com/api/app/dataset/v1"
mlinferencepb "go.viam.com/api/app/mlinference/v1"
mltrainingpb "go.viam.com/api/app/mltraining/v1"
packagepb "go.viam.com/api/app/packages/v1"
apppb "go.viam.com/api/app/v1"
@@ -544,6 +545,7 @@ func (c *viamClient) ensureLoggedInInner() error {
c.packageClient = packagepb.NewPackageServiceClient(conn)
c.datasetClient = datasetpb.NewDatasetServiceClient(conn)
c.mlTrainingClient = mltrainingpb.NewMLTrainingServiceClient(conn)
c.mlInferenceClient = mlinferencepb.NewMLInferenceServiceClient(conn)
c.buildClient = buildpb.NewBuildServiceClient(conn)

return nil
22 changes: 12 additions & 10 deletions cli/client.go
@@ -32,6 +32,7 @@ import (
buildpb "go.viam.com/api/app/build/v1"
datapb "go.viam.com/api/app/data/v1"
datasetpb "go.viam.com/api/app/dataset/v1"
mlinferencepb "go.viam.com/api/app/mlinference/v1"
mltrainingpb "go.viam.com/api/app/mltraining/v1"
packagepb "go.viam.com/api/app/packages/v1"
apppb "go.viam.com/api/app/v1"
@@ -71,16 +72,17 @@ var errNoShellService = errors.New("shell service is not enabled on this machine
// viamClient wraps a cli.Context and provides all the CLI command functionality
// needed to talk to the app and data services but not directly to robot parts.
type viamClient struct {
c *cli.Context
conf *Config
client apppb.AppServiceClient
dataClient datapb.DataServiceClient
packageClient packagepb.PackageServiceClient
datasetClient datasetpb.DatasetServiceClient
mlTrainingClient mltrainingpb.MLTrainingServiceClient
buildClient buildpb.BuildServiceClient
baseURL *url.URL
authFlow *authFlow
c *cli.Context
conf *Config
client apppb.AppServiceClient
dataClient datapb.DataServiceClient
packageClient packagepb.PackageServiceClient
datasetClient datasetpb.DatasetServiceClient
mlTrainingClient mltrainingpb.MLTrainingServiceClient
mlInferenceClient mlinferencepb.MLInferenceServiceClient
buildClient buildpb.BuildServiceClient
baseURL *url.URL
authFlow *authFlow

selectedOrg *apppb.Organization
selectedLoc *apppb.Location
122 changes: 122 additions & 0 deletions cli/ml_inference.go
@@ -0,0 +1,122 @@
package cli

import (
"context"
"fmt"
"strings"

"github.com/pkg/errors"
"github.com/urfave/cli/v2"
v1 "go.viam.com/api/app/data/v1"
mlinferencepb "go.viam.com/api/app/mlinference/v1"
)

const (
inferenceFlagFileOrgID = "file-org-id"
inferenceFlagFileID = "file-id"
inferenceFlagFileLocationID = "file-location-id"
inferenceFlagModelOrgID = "model-org-id"
inferenceFlagModelName = "model-name"
inferenceFlagModelVersion = "model-version"
)

type mlInferenceInferArgs struct {
OrgID string
FileOrgID string
FileID string
FileLocationID string
ModelOrgID string
ModelName string
ModelVersion string
}

// MLInferenceInferAction is the corresponding action for 'inference infer'.
func MLInferenceInferAction(c *cli.Context, args mlInferenceInferArgs) error {
client, err := newViamClient(c)
if err != nil {
return err
}

_, err = client.mlRunInference(
args.OrgID, args.FileOrgID, args.FileID, args.FileLocationID,
args.ModelOrgID, args.ModelName, args.ModelVersion)
if err != nil {
return err
}
return nil
}

// mlRunInference runs inference on an image with the specified parameters.
func (c *viamClient) mlRunInference(orgID, fileOrgID, fileID, fileLocation, modelOrgID,
modelName, modelVersion string,
) (*mlinferencepb.GetInferenceResponse, error) {
if err := c.ensureLoggedIn(); err != nil {
return nil, err
}

req := &mlinferencepb.GetInferenceRequest{
OrganizationId: orgID,
BinaryId: &v1.BinaryID{
FileId: fileID,
OrganizationId: fileOrgID,
LocationId: fileLocation,
},
RegistryItemId: fmt.Sprintf("%s:%s", modelOrgID, modelName),
RegistryItemVersion: modelVersion,
}

resp, err := c.mlInferenceClient.GetInference(context.Background(), req)
if err != nil {
return nil, errors.Wrapf(err, "received error from server")
}
c.printInferenceResponse(resp)
return resp, nil
}

// printInferenceResponse prints a neat representation of the GetInferenceResponse.
func (c *viamClient) printInferenceResponse(resp *mlinferencepb.GetInferenceResponse) {
printf(c.c.App.Writer, "Inference Response:")
printf(c.c.App.Writer, "Output Tensors:")
if resp.OutputTensors != nil {
for name, tensor := range resp.OutputTensors.Tensors {
printf(c.c.App.Writer, " Tensor Name: %s", name)
printf(c.c.App.Writer, " Shape: %v", tensor.Shape)
if tensor.Tensor != nil {
var sb strings.Builder
for i, value := range tensor.GetDoubleTensor().GetData() {
if i > 0 {
sb.WriteString(", ")
}
sb.WriteString(fmt.Sprintf("%.4f", value))
}
printf(c.c.App.Writer, " Values: [%s]", sb.String())
} else {
printf(c.c.App.Writer, " No values available.")
}
}
} else {
printf(c.c.App.Writer, " No output tensors.")
}

printf(c.c.App.Writer, "Annotations:")
printf(c.c.App.Writer, "Bounding Box Format: [x_min, y_min, x_max, y_max]")
if resp.Annotations != nil {
for _, bbox := range resp.Annotations.Bboxes {
printf(c.c.App.Writer, " Bounding Box ID: %s, Label: %s",
bbox.Id, bbox.Label)
printf(c.c.App.Writer, " Coordinates: [%f, %f, %f, %f]",
bbox.XMinNormalized, bbox.YMinNormalized, bbox.XMaxNormalized, bbox.YMaxNormalized)
if bbox.Confidence != nil {
printf(c.c.App.Writer, " Confidence: %.4f", *bbox.Confidence)
}
}
for _, classification := range resp.Annotations.Classifications {
printf(c.c.App.Writer, " Classification Label: %s", classification.Label)
if classification.Confidence != nil {
printf(c.c.App.Writer, " Confidence: %.4f", *classification.Confidence)
}
}
} else {
printf(c.c.App.Writer, " No annotations.")
}
}

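As an illustration of the layout produced by printInferenceResponse above, a successful run would print output shaped roughly like this. The tensor name, shape, bounding-box fields, and all numeric values below are placeholders, not real inference results:

    Inference Response:
    Output Tensors:
      Tensor Name: <name>
        Shape: [1 4]
        Values: [0.0000, 0.0000, 0.0000, 0.0000]
    Annotations:
    Bounding Box Format: [x_min, y_min, x_max, y_max]
      Bounding Box ID: <id>, Label: <label>
        Coordinates: [0.000000, 0.000000, 0.000000, 0.000000]
        Confidence: 0.0000

Classifications, when present, are listed after the bounding boxes with their label and optional confidence.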