Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add visualization server and unit tests for visualization server #1647

Merged
merged 9 commits into from
Jul 25, 2019
2 changes: 2 additions & 0 deletions backend/src/apiserver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ func startRpcServer(resourceManager *resource.ResourceManager) {
api.RegisterRunServiceServer(s, server.NewRunServer(resourceManager))
api.RegisterJobServiceServer(s, server.NewJobServer(resourceManager))
api.RegisterReportServiceServer(s, server.NewReportServer(resourceManager))
api.RegisterVisualizationServiceServer(s, server.NewVisualizationServer(resourceManager))

// Register reflection service on gRPC server.
reflection.Register(s)
Expand All @@ -106,6 +107,7 @@ func startHttpProxy(resourceManager *resource.ResourceManager) {
registerHttpHandlerFromEndpoint(api.RegisterJobServiceHandlerFromEndpoint, "JobService", ctx, mux)
registerHttpHandlerFromEndpoint(api.RegisterRunServiceHandlerFromEndpoint, "RunService", ctx, mux)
registerHttpHandlerFromEndpoint(api.RegisterReportServiceHandlerFromEndpoint, "ReportService", ctx, mux)
registerHttpHandlerFromEndpoint(api.RegisterVisualizationServiceHandlerFromEndpoint, "Visualization", ctx, mux)

// Create a top level mux to include both pipeline upload server and gRPC servers.
topMux := http.NewServeMux()
Expand Down
2 changes: 2 additions & 0 deletions backend/src/apiserver/server/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ go_library(
"run_server.go",
"test_util.go",
"util.go",
"visualization_server.go",
],
importpath = "github.com/kubeflow/pipelines/backend/src/apiserver/server",
visibility = ["//visibility:public"],
Expand Down Expand Up @@ -50,6 +51,7 @@ go_test(
"run_metric_util_test.go",
"run_server_test.go",
"util_test.go",
"visualization_server_test.go",
],
data = glob(["test/**/*"]), # keep
embed = [":go_default_library"],
Expand Down
78 changes: 78 additions & 0 deletions backend/src/apiserver/server/visualization_server.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package server

import (
"context"
"encoding/json"
"fmt"
"github.com/kubeflow/pipelines/backend/api/go_client"
"github.com/kubeflow/pipelines/backend/src/apiserver/resource"
"github.com/kubeflow/pipelines/backend/src/common/util"
"io/ioutil"
"net/http"
"net/url"
"strings"
)

type VisualizationServer struct {
resourceManager *resource.ResourceManager
serviceURL string
}

func (s *VisualizationServer) CreateVisualization(ctx context.Context, request *go_client.CreateVisualizationRequest) (*go_client.Visualization, error) {
if err := s.validateCreateVisualizationRequest(request); err != nil {
return nil, err
}
body, err := s.generateVisualizationFromRequest(request)
if err != nil {
return nil, err
}
request.Visualization.Html = string(body)
return request.Visualization, nil
}

// validateCreateVisualizationRequest ensures that a go_client.Visualization
// object has valid values.
// It returns an error if a go_client.Visualization object does not have valid
// values.
func (s *VisualizationServer) validateCreateVisualizationRequest(request *go_client.CreateVisualizationRequest) error {
if len(request.Visualization.InputPath) == 0 {
return util.NewInvalidInputError("A visualization requires an InputPath to be provided. Received %s", request.Visualization.InputPath)
}
// Manually set Arguments to empty JSON if nothing is provided. This is done
// because visualizations such as TFDV and TFMA only require an InputPath to
// be provided for a visualization to be generated. If no JSON is provided
// json.Valid will fail without this check as an empty string is provided for
// those visualizations.
if len(request.Visualization.Arguments) == 0 {
request.Visualization.Arguments = "{}"
}
if !json.Valid([]byte(request.Visualization.Arguments)) {
return util.NewInvalidInputError("A visualization requires valid JSON to be provided as Arguments. Received %s", request.Visualization.Arguments)
}
return nil
}

// generateVisualizationFromRequest communicates with the python visualization
// service to generate HTML visualizations from a request.
// It returns the generated HTML as a string and any error that is encountered.
func (s *VisualizationServer) generateVisualizationFromRequest(request *go_client.CreateVisualizationRequest) ([]byte, error) {
visualizationType := strings.ToLower(go_client.Visualization_Type_name[int32(request.Visualization.Type)])
arguments := fmt.Sprintf("--type %s --input_path %s --arguments '%s'", visualizationType, request.Visualization.InputPath, request.Visualization.Arguments)
resp, err := http.PostForm(s.serviceURL, url.Values{"arguments": {arguments}})
if err != nil {
return nil, util.Wrap(err, "Unable to initialize visualization request.")
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf(resp.Status)
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, util.Wrap(err, "Unable to parse visualization response.")
}
return body, nil
}

func NewVisualizationServer(resourceManager *resource.ResourceManager) *VisualizationServer {
return &VisualizationServer{resourceManager: resourceManager, serviceURL: "http://visualization-service.kubeflow"}
}
117 changes: 117 additions & 0 deletions backend/src/apiserver/server/visualization_server_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package server

import (
"github.com/kubeflow/pipelines/backend/api/go_client"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"testing"
)

func TestValidateCreateVisualizationRequest(t *testing.T) {
clients, manager, _ := initWithExperiment(t)
defer clients.Close()
server := NewVisualizationServer(manager)
visualization := &go_client.Visualization{
Type: go_client.Visualization_ROC_CURVE,
InputPath: "gs://ml-pipeline/roc/data.csv",
Arguments: "{}",
}
request := &go_client.CreateVisualizationRequest{
Visualization: visualization,
}
err := server.validateCreateVisualizationRequest(request)
assert.Nil(t, err)
}

func TestValidateCreateVisualizationRequest_ArgumentsAreEmpty(t *testing.T) {
clients, manager, _ := initWithExperiment(t)
defer clients.Close()
server := NewVisualizationServer(manager)
visualization := &go_client.Visualization{
Type: go_client.Visualization_ROC_CURVE,
InputPath: "gs://ml-pipeline/roc/data.csv",
Arguments: "",
}
request := &go_client.CreateVisualizationRequest{
Visualization: visualization,
}
err := server.validateCreateVisualizationRequest(request)
assert.Nil(t, err)
}

func TestValidateCreateVisualizationRequest_InputPathIsEmpty(t *testing.T) {
clients, manager, _ := initWithExperiment(t)
defer clients.Close()
server := NewVisualizationServer(manager)
visualization := &go_client.Visualization{
Type: go_client.Visualization_ROC_CURVE,
InputPath: "",
Arguments: "{}",
}
request := &go_client.CreateVisualizationRequest{
Visualization: visualization,
}
err := server.validateCreateVisualizationRequest(request)
assert.Contains(t, err.Error(), "A visualization requires an InputPath to be provided. Received")
}

func TestValidateCreateVisualizationRequest_ArgumentsNotValidJSON(t *testing.T) {
clients, manager, _ := initWithExperiment(t)
defer clients.Close()
server := NewVisualizationServer(manager)
visualization := &go_client.Visualization{
Type: go_client.Visualization_ROC_CURVE,
InputPath: "gs://ml-pipeline/roc/data.csv",
Arguments: "{",
}
request := &go_client.CreateVisualizationRequest{
Visualization: visualization,
}
err := server.validateCreateVisualizationRequest(request)
assert.Contains(t, err.Error(), "A visualization requires valid JSON to be provided as Arguments. Received {")
}

func TestGenerateVisualization(t *testing.T) {
clients, manager, _ := initWithExperiment(t)
defer clients.Close()
httpServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
assert.Equal(t, "/", req.URL.String())
rw.Write([]byte("roc_curve"))
}))
defer httpServer.Close()
server := &VisualizationServer{resourceManager: manager, serviceURL: httpServer.URL}
visualization := &go_client.Visualization{
Type: go_client.Visualization_ROC_CURVE,
InputPath: "gs://ml-pipeline/roc/data.csv",
Arguments: "{}",
}
request := &go_client.CreateVisualizationRequest{
Visualization: visualization,
}
body, err := server.generateVisualizationFromRequest(request)
assert.Equal(t, []byte("roc_curve"), body)
assert.Nil(t, err)
}

func TestGenerateVisualization_ServerError(t *testing.T) {
clients, manager, _ := initWithExperiment(t)
defer clients.Close()
httpServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
assert.Equal(t, "/", req.URL.String())
rw.WriteHeader(500)
}))
defer httpServer.Close()
server := &VisualizationServer{resourceManager: manager, serviceURL: httpServer.URL}
visualization := &go_client.Visualization{
Type: go_client.Visualization_ROC_CURVE,
InputPath: "gs://ml-pipeline/roc/data.csv",
Arguments: "{}",
}
request := &go_client.CreateVisualizationRequest{
Visualization: visualization,
}
body, err := server.generateVisualizationFromRequest(request)
assert.Nil(t, body)
assert.Equal(t, "500 Internal Server Error", err.Error())
}