Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add visualization server and unit tests for visualization server #1647

Merged
merged 9 commits into from
Jul 25, 2019
2 changes: 2 additions & 0 deletions backend/src/apiserver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ func startRpcServer(resourceManager *resource.ResourceManager) {
api.RegisterRunServiceServer(s, server.NewRunServer(resourceManager))
api.RegisterJobServiceServer(s, server.NewJobServer(resourceManager))
api.RegisterReportServiceServer(s, server.NewReportServer(resourceManager))
api.RegisterVisualizationServiceServer(s, server.NewVisualizationServer(resourceManager))

// Register reflection service on gRPC server.
reflection.Register(s)
Expand All @@ -106,6 +107,7 @@ func startHttpProxy(resourceManager *resource.ResourceManager) {
registerHttpHandlerFromEndpoint(api.RegisterJobServiceHandlerFromEndpoint, "JobService", ctx, mux)
registerHttpHandlerFromEndpoint(api.RegisterRunServiceHandlerFromEndpoint, "RunService", ctx, mux)
registerHttpHandlerFromEndpoint(api.RegisterReportServiceHandlerFromEndpoint, "ReportService", ctx, mux)
registerHttpHandlerFromEndpoint(api.RegisterVisualizationServiceHandlerFromEndpoint, "Visualization", ctx, mux)

// Create a top level mux to include both pipeline upload server and gRPC servers.
topMux := http.NewServeMux()
Expand Down
2 changes: 2 additions & 0 deletions backend/src/apiserver/server/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ go_library(
"run_server.go",
"test_util.go",
"util.go",
"visualization.go",
ajchili marked this conversation as resolved.
Show resolved Hide resolved
],
importpath = "github.com/kubeflow/pipelines/backend/src/apiserver/server",
visibility = ["//visibility:public"],
Expand Down Expand Up @@ -50,6 +51,7 @@ go_test(
"run_metric_util_test.go",
"run_server_test.go",
"util_test.go",
"visualization_test.go",
],
data = glob(["test/**/*"]), # keep
embed = [":go_default_library"],
Expand Down
108 changes: 108 additions & 0 deletions backend/src/apiserver/server/visualization.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package server

import (
"context"
"encoding/json"
"fmt"
"github.com/kubeflow/pipelines/backend/api/go_client"
"github.com/kubeflow/pipelines/backend/src/apiserver/resource"
"github.com/kubeflow/pipelines/backend/src/common/util"
"io/ioutil"
"net/http"
"net/url"
"strings"
)

type VisualizationServer struct {
resourceManager *resource.ResourceManager
serviceURL string
}

func (s *VisualizationServer) CreateVisualization(ctx context.Context, request *go_client.CreateVisualizationRequest) (*go_client.Visualization, error) {
if err := s.ValidateCreateVisualizationRequest(request.Visualization); err != nil {
return nil, err
}
arguments, err := s.GetArgumentsAsJSONFromVisualization(request.Visualization)
if err != nil {
return nil, err
}
pythonArguments := s.CreatePythonArgumentsFromTypeAndJSON(request.Visualization.Type, arguments)
body, err := s.GenerateVisualization(pythonArguments)
if err != nil {
return nil, err
}
request.Visualization.Html = string(body)
return request.Visualization, nil
}

// ValidateCreateVisualizationRequest ensures that a go_client.Visualization
// object has valid values.
// It returns an error if a go_client.Visualization object does not have valid
// values.
func (s *VisualizationServer) ValidateCreateVisualizationRequest(visualization *go_client.Visualization) error {
ajchili marked this conversation as resolved.
Show resolved Hide resolved
ajchili marked this conversation as resolved.
Show resolved Hide resolved
if len(visualization.InputPath) == 0 {
return util.NewInvalidInputError("A visualization requires an InputPath to be provided. Received %s", visualization.InputPath)
}
// Manually set Arguments to empty JSON if nothing is provided. This is done
// because visualizations such as TFDV and TFMA only require an InputPath to
// provided for a visualization to be generated. If no JSON is provided
// json.Valid will fail without this check as an empty string is provided for
// those visualizations.
if len(visualization.Arguments) == 0 {
visualization.Arguments = "{}"
}
if !json.Valid([]byte(visualization.Arguments)) {
return util.NewInvalidInputError("A visualization requires valid JSON to be provided as Arguments. Received %s", visualization.Arguments)
}
return nil
}

// GetArgumentsAsJSONFromVisualization will convert the values within a
// go_client.Visualization object to valid JSON that can be used to pass
// arguments to the python visualization service.
// It returns the generated JSON as an array of bytes and any error that is
// encountered.
func (s *VisualizationServer) GetArgumentsAsJSONFromVisualization(visualization *go_client.Visualization) ([]byte, error) {
var arguments map[string]interface{}
if err := json.Unmarshal([]byte(visualization.Arguments), &arguments); err != nil {
return nil, util.Wrap(err, "Unable to parse provided JSON.")
}
arguments["input_path"] = visualization.InputPath
ajchili marked this conversation as resolved.
Show resolved Hide resolved
args, err := json.Marshal(arguments)
if err != nil {
return nil, util.Wrap(err, "Unable to compose provided JSON as string.")
}
return args, nil
}

// CreatePythonArgumentsFromTypeAndJSON converts the values within a
// go_client.Visualization object to those expected by the python visualization
// service.
// It returns the converted values as a string.
func (s *VisualizationServer) CreatePythonArgumentsFromTypeAndJSON(visualizationType go_client.Visualization_Type, arguments []byte) string {
var _visualizationType = strings.ToLower(go_client.Visualization_Type_name[int32(visualizationType)])
return fmt.Sprintf("--type %s --arguments '%s'", _visualizationType, arguments)
ajchili marked this conversation as resolved.
Show resolved Hide resolved
}

// GenerateVisualization communicates with the python visualization service to
// generate HTML visualizations from specified arguments.
// It returns the generated HTML as a string and any error that is encountered.
func (s *VisualizationServer) GenerateVisualization(arguments string) ([]byte, error) {
resp, err := http.PostForm(s.serviceURL, url.Values{"arguments": {arguments}})
if err != nil {
return nil, util.Wrap(err, "Unable to initialize visualization request.")
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf(resp.Status)
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, util.Wrap(err, "Unable to parse visualization response.")
}
return body, nil
}

func NewVisualizationServer(resourceManager *resource.ResourceManager) *VisualizationServer {
return &VisualizationServer{resourceManager: resourceManager, serviceURL: "http://visualization-service.kubeflow"}
}
131 changes: 131 additions & 0 deletions backend/src/apiserver/server/visualization_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
package server

import (
"github.com/kubeflow/pipelines/backend/api/go_client"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"testing"
)

func TestValidateCreateVisualizationRequest(t *testing.T) {
clients, manager, _ := initWithExperiment(t)
defer clients.Close()
server := NewVisualizationServer(manager)
apiVisualization := &go_client.Visualization{
Type: go_client.Visualization_ROC_CURVE,
InputPath: "gs://ml-pipeline/roc/data.csv",
Arguments: "{}",
}
err := server.ValidateCreateVisualizationRequest(apiVisualization)
assert.Nil(t, err)
}

func TestValidateCreateVisualizationRequest_ArgumentsAreEmpty(t *testing.T) {
clients, manager, _ := initWithExperiment(t)
defer clients.Close()
server := NewVisualizationServer(manager)
apiVisualization := &go_client.Visualization{
Type: go_client.Visualization_ROC_CURVE,
InputPath: "gs://ml-pipeline/roc/data.csv",
Arguments: "",
}
err := server.ValidateCreateVisualizationRequest(apiVisualization)
assert.Nil(t, err)
}

func TestValidateCreateVisualizationRequest_InputPathIsEmpty(t *testing.T) {
clients, manager, _ := initWithExperiment(t)
defer clients.Close()
server := NewVisualizationServer(manager)
apiVisualization := &go_client.Visualization{
Type: go_client.Visualization_ROC_CURVE,
InputPath: "",
Arguments: "{}",
}
err := server.ValidateCreateVisualizationRequest(apiVisualization)
assert.Contains(t, err.Error(), "A visualization requires an InputPath to be provided. Received")
}

func TestValidateCreateVisualizationRequest_ArgumentsNotValidJSON(t *testing.T) {
clients, manager, _ := initWithExperiment(t)
defer clients.Close()
server := NewVisualizationServer(manager)
apiVisualization := &go_client.Visualization{
Type: go_client.Visualization_ROC_CURVE,
InputPath: "gs://ml-pipeline/roc/data.csv",
Arguments: "{",
}
err := server.ValidateCreateVisualizationRequest(apiVisualization)
assert.Contains(t, err.Error(), "A visualization requires valid JSON to be provided as Arguments. Received {")
}

func TestGetArgumentsAsJSONFromVisualization(t *testing.T) {
clients, manager, _ := initWithExperiment(t)
defer clients.Close()
server := NewVisualizationServer(manager)
apiVisualization := &go_client.Visualization{
Type: go_client.Visualization_ROC_CURVE,
InputPath: "gs://ml-pipeline/roc/data.csv",
Arguments: "{}",
}
arguments, err := server.GetArgumentsAsJSONFromVisualization(apiVisualization)
assert.Equal(t, []byte("{\"input_path\":\"gs://ml-pipeline/roc/data.csv\"}"), arguments)
assert.Nil(t, err)
}

func TestGetArgumentsAsJSONFromVisualization_ArgumentsNotValidJSON(t *testing.T) {
clients, manager, _ := initWithExperiment(t)
defer clients.Close()
server := NewVisualizationServer(manager)
apiVisualization := &go_client.Visualization{
Type: go_client.Visualization_ROC_CURVE,
InputPath: "gs://ml-pipeline/roc/data.csv",
Arguments: "{",
}
arguments, err := server.GetArgumentsAsJSONFromVisualization(apiVisualization)
assert.Nil(t, arguments)
assert.Contains(t, err.Error(), "Unable to parse provided JSON.")
}

func TestCreatePythonArgumentsFromTypeAndJSON(t *testing.T) {
clients, manager, _ := initWithExperiment(t)
defer clients.Close()
server := NewVisualizationServer(manager)
apiVisualization := &go_client.Visualization{
Type: go_client.Visualization_ROC_CURVE,
InputPath: "gs://ml-pipeline/roc/data.csv",
Arguments: "{}",
}
arguments := []byte("{\"input_path\": \"gs://ml-pipeline/roc/data.csv\"}")
pythonArguments := server.CreatePythonArgumentsFromTypeAndJSON(apiVisualization.Type, arguments)
assert.Equal(t, "--type roc_curve --arguments '{\"input_path\": \"gs://ml-pipeline/roc/data.csv\"}'", pythonArguments)
}

func TestGenerateVisualization(t *testing.T) {
clients, manager, _ := initWithExperiment(t)
defer clients.Close()
httpServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
assert.Equal(t, "/", req.URL.String())
rw.Write([]byte("roc_curve"))
}))
defer httpServer.Close()
server := &VisualizationServer{resourceManager: manager, serviceURL: httpServer.URL}
body, err := server.GenerateVisualization("--type roc_curve --arguments '{\"input_path\": \"gs://ml-pipeline/roc/data.csv\"}'")
assert.Equal(t, []byte("roc_curve"), body)
assert.Nil(t, err)
}

func TestGenerateVisualization_ServerError(t *testing.T) {
clients, manager, _ := initWithExperiment(t)
defer clients.Close()
httpServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
assert.Equal(t, "/", req.URL.String())
rw.WriteHeader(500)
}))
defer httpServer.Close()
server := &VisualizationServer{resourceManager: manager, serviceURL: httpServer.URL}
body, err := server.GenerateVisualization("--type roc_curve --arguments '{\"input_path\": \"gs://ml-pipeline/roc/data.csv\"}'")
assert.Nil(t, body)
assert.Equal(t, "500 Internal Server Error", err.Error())
}