Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix support for custom visualizations #1951

17 changes: 12 additions & 5 deletions backend/src/apiserver/server/visualization_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,14 @@ func (s *VisualizationServer) CreateVisualization(ctx context.Context, request *
// It returns an error if a go_client.Visualization object does not have valid
// values.
func (s *VisualizationServer) validateCreateVisualizationRequest(request *go_client.CreateVisualizationRequest) error {
if len(request.Visualization.Source) == 0 {
return util.NewInvalidInputError("A visualization requires a Source to be provided. Received %s", request.Visualization.Source)
// Only validate that a source is provided for non-custom visualizations.
if request.Visualization.Type != go_client.Visualization_CUSTOM {
if len(request.Visualization.Source) == 0 {
return util.NewInvalidInputError("A visualization requires a Source to be provided. Received %s", request.Visualization.Source)
}
}
// Manually set Arguments to empty JSON if nothing is provided. This is done
// because visualizations such as TFDV and TFMA only require an InputPath to
// because visualizations such as TFDV and TFMA only require a Source to
// be provided for a visualization to be generated. If no JSON is provided
// json.Valid will fail without this check as an empty string is provided for
// those visualizations.
Expand All @@ -67,8 +70,12 @@ func (s *VisualizationServer) generateVisualizationFromRequest(request *go_clien
)
}
visualizationType := strings.ToLower(go_client.Visualization_Type_name[int32(request.Visualization.Type)])
arguments := fmt.Sprintf("--type %s --source %s --arguments '%s'", visualizationType, request.Visualization.Source, request.Visualization.Arguments)
resp, err := http.PostForm(s.serviceURL, url.Values{"arguments": {arguments}})
urlValues := url.Values{
"arguments": {request.Visualization.Arguments},
"source": {request.Visualization.Source},
"type": {visualizationType},
}
resp, err := http.PostForm(s.serviceURL, urlValues)
if err != nil {
return nil, util.Wrap(err, "Unable to initialize visualization request.")
}
Expand Down
52 changes: 35 additions & 17 deletions backend/src/apiserver/server/visualization_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ func TestValidateCreateVisualizationRequest(t *testing.T) {
clients, manager, _ := initWithExperiment(t)
defer clients.Close()
server := &VisualizationServer{
resourceManager: manager,
resourceManager: manager,
isServiceAvailable: false,
}
visualization := &go_client.Visualization{
Type: go_client.Visualization_ROC_CURVE,
Source: "gs://ml-pipeline/roc/data.csv",
Source: "gs://ml-pipeline/roc/data.csv",
Arguments: "{}",
}
request := &go_client.CreateVisualizationRequest{
Expand All @@ -31,12 +31,12 @@ func TestValidateCreateVisualizationRequest_ArgumentsAreEmpty(t *testing.T) {
clients, manager, _ := initWithExperiment(t)
defer clients.Close()
server := &VisualizationServer{
resourceManager: manager,
resourceManager: manager,
isServiceAvailable: false,
}
visualization := &go_client.Visualization{
Type: go_client.Visualization_ROC_CURVE,
Source: "gs://ml-pipeline/roc/data.csv",
Source: "gs://ml-pipeline/roc/data.csv",
Arguments: "",
}
request := &go_client.CreateVisualizationRequest{
Expand All @@ -50,12 +50,12 @@ func TestValidateCreateVisualizationRequest_SourceIsEmpty(t *testing.T) {
clients, manager, _ := initWithExperiment(t)
defer clients.Close()
server := &VisualizationServer{
resourceManager: manager,
resourceManager: manager,
isServiceAvailable: false,
}
visualization := &go_client.Visualization{
Type: go_client.Visualization_ROC_CURVE,
Source: "",
Source: "",
Arguments: "{}",
}
request := &go_client.CreateVisualizationRequest{
Expand All @@ -65,16 +65,34 @@ func TestValidateCreateVisualizationRequest_SourceIsEmpty(t *testing.T) {
assert.Contains(t, err.Error(), "A visualization requires a Source to be provided. Received")
}

func TestValidateCreateVisualizationRequest_SourceIsEmptyAndTypeIsCustom(t *testing.T) {
clients, manager, _ := initWithExperiment(t)
defer clients.Close()
server := &VisualizationServer{
resourceManager: manager,
isServiceAvailable: false,
}
visualization := &go_client.Visualization{
Type: go_client.Visualization_CUSTOM,
Arguments: "{}",
}
request := &go_client.CreateVisualizationRequest{
Visualization: visualization,
}
err := server.validateCreateVisualizationRequest(request)
assert.Nil(t, err)
}

func TestValidateCreateVisualizationRequest_ArgumentsNotValidJSON(t *testing.T) {
clients, manager, _ := initWithExperiment(t)
defer clients.Close()
server := &VisualizationServer{
resourceManager: manager,
resourceManager: manager,
isServiceAvailable: false,
}
visualization := &go_client.Visualization{
Type: go_client.Visualization_ROC_CURVE,
Source: "gs://ml-pipeline/roc/data.csv",
Source: "gs://ml-pipeline/roc/data.csv",
Arguments: "{",
}
request := &go_client.CreateVisualizationRequest{
Expand All @@ -93,13 +111,13 @@ func TestGenerateVisualization(t *testing.T) {
}))
defer httpServer.Close()
server := &VisualizationServer{
resourceManager: manager,
serviceURL: httpServer.URL,
resourceManager: manager,
serviceURL: httpServer.URL,
isServiceAvailable: true,
}
visualization := &go_client.Visualization{
Type: go_client.Visualization_ROC_CURVE,
Source: "gs://ml-pipeline/roc/data.csv",
Source: "gs://ml-pipeline/roc/data.csv",
Arguments: "{}",
}
request := &go_client.CreateVisualizationRequest{
Expand All @@ -119,13 +137,13 @@ func TestGenerateVisualization_ServiceNotAvailableError(t *testing.T) {
}))
defer httpServer.Close()
server := &VisualizationServer{
resourceManager: manager,
serviceURL: httpServer.URL,
resourceManager: manager,
serviceURL: httpServer.URL,
isServiceAvailable: false,
}
visualization := &go_client.Visualization{
Type: go_client.Visualization_ROC_CURVE,
Source: "gs://ml-pipeline/roc/data.csv",
Source: "gs://ml-pipeline/roc/data.csv",
Arguments: "{}",
}
request := &go_client.CreateVisualizationRequest{
Expand All @@ -145,13 +163,13 @@ func TestGenerateVisualization_ServerError(t *testing.T) {
}))
defer httpServer.Close()
server := &VisualizationServer{
resourceManager: manager,
serviceURL: httpServer.URL,
resourceManager: manager,
serviceURL: httpServer.URL,
isServiceAvailable: true,
}
visualization := &go_client.Visualization{
Type: go_client.Visualization_ROC_CURVE,
Source: "gs://ml-pipeline/roc/data.csv",
Source: "gs://ml-pipeline/roc/data.csv",
Arguments: "{}",
}
request := &go_client.CreateVisualizationRequest{
Expand Down
72 changes: 42 additions & 30 deletions backend/src/apiserver/visualization/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
# limitations under the License.

from enum import Enum
import json
from pathlib import Path
from typing import Text
from jupyter_client import KernelManager
Expand All @@ -44,6 +43,48 @@ class TemplateType(Enum):
FULL = 'full'


def create_cell_from_args(variables: dict) -> NotebookNode:
"""Creates NotebookNode object containing dict of provided variables.

Args:
variables: Arguments that need to be injected into a NotebookNode.

Returns:
NotebookNode with provided arguments as variables.

"""
return new_code_cell("variables = {}".format(repr(variables)))


def create_cell_from_file(filepath: Text) -> NotebookNode:
"""Creates a NotebookNode object with provided file as code in node.

Args:
filepath: Path to file that should be used.

Returns:
NotebookNode with specified file as code within node.

"""
with open(filepath, 'r') as f:
code = f.read()

return new_code_cell(code)


def create_cell_from_custom_code(code: list) -> NotebookNode:
"""Creates a NotebookNode object with provided list as code in node.

Args:
code: list representing lines of code to be run.

Returns:
NotebookNode with specified file as code within node.

"""
return new_code_cell("\n".join(code))


class Exporter:
"""Handler for interaction with NotebookNodes, including output generation.

Expand Down Expand Up @@ -89,35 +130,6 @@ def __init__(
allow_errors=True
)

@staticmethod
def create_cell_from_args(variables: dict) -> NotebookNode:
"""Creates NotebookNode object containing dict of provided variables.

Args:
variables: Arguments that need to be injected into a NotebookNode.

Returns:
NotebookNode with provided arguments as variables.

"""
return new_code_cell("variables = {}".format(repr(variables)))

@staticmethod
def create_cell_from_file(filepath: Text) -> NotebookNode:
"""Creates a NotebookNode object with provided file as code in node.

Args:
filepath: Path to file that should be used.

Returns:
NotebookNode with specified file as code within node.

"""
with open(filepath, 'r') as f:
code = f.read()

return new_code_cell(code)

def generate_html_from_notebook(self, nb: NotebookNode) -> Text:
"""Converts a provided NotebookNode to HTML.

Expand Down
Loading