Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix support for custom visualizations #1951

11 changes: 9 additions & 2 deletions backend/src/apiserver/server/visualization_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ func (s *VisualizationServer) CreateVisualization(ctx context.Context, request *
// It returns an error if a go_client.Visualization object does not have valid
// values.
func (s *VisualizationServer) validateCreateVisualizationRequest(request *go_client.CreateVisualizationRequest) error {
if len(request.Visualization.Source) == 0 {
// Only validate that a source is provided for non-custom visualizations.
if request.Visualization.Type != go_client.Visualization_CUSTOM && len(request.Visualization.Source) == 0 {
return util.NewInvalidInputError("A visualization requires a Source to be provided. Received %s", request.Visualization.Source)
}
// Manually set Arguments to empty JSON if nothing is provided. This is done
Expand Down Expand Up @@ -67,7 +68,13 @@ func (s *VisualizationServer) generateVisualizationFromRequest(request *go_clien
)
}
visualizationType := strings.ToLower(go_client.Visualization_Type_name[int32(request.Visualization.Type)])
arguments := fmt.Sprintf("--type %s --source %s --arguments '%s'", visualizationType, request.Visualization.Source, request.Visualization.Arguments)
arguments := fmt.Sprintf("--type %s --arguments '''%s'''", visualizationType, request.Visualization.Arguments)
ajchili marked this conversation as resolved.
Show resolved Hide resolved
if !(request.Visualization.Type == go_client.Visualization_CUSTOM && len(request.Visualization.Source) == 0) {
ajchili marked this conversation as resolved.
Show resolved Hide resolved
// Only add the source argument if a visualization is one of the following:
// - Not a custom visualization
// - A custom visualization that has a source specified
arguments += fmt.Sprintf(" --source %s", request.Visualization.Source)
}
resp, err := http.PostForm(s.serviceURL, url.Values{"arguments": {arguments}})
if err != nil {
return nil, util.Wrap(err, "Unable to initialize visualization request.")
Expand Down
18 changes: 18 additions & 0 deletions backend/src/apiserver/server/visualization_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,24 @@ func TestValidateCreateVisualizationRequest_SourceIsEmpty(t *testing.T) {
assert.Contains(t, err.Error(), "A visualization requires a Source to be provided. Received")
}

func TestValidateCreateVisualizationRequest_SourceIsEmptyAndTypeIsCustom(t *testing.T) {
clients, manager, _ := initWithExperiment(t)
defer clients.Close()
server := &VisualizationServer{
resourceManager: manager,
isServiceAvailable: false,
}
visualization := &go_client.Visualization{
Type: go_client.Visualization_CUSTOM,
Arguments: "{}",
}
request := &go_client.CreateVisualizationRequest{
Visualization: visualization,
}
err := server.validateCreateVisualizationRequest(request)
assert.Nil(t, err)
}

func TestValidateCreateVisualizationRequest_ArgumentsNotValidJSON(t *testing.T) {
clients, manager, _ := initWithExperiment(t)
defer clients.Close()
Expand Down
13 changes: 13 additions & 0 deletions backend/src/apiserver/visualization/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,19 @@ def create_cell_from_file(filepath: Text) -> NotebookNode:

return new_code_cell(code)

@staticmethod
def create_cell_from_custom_code(code: list) -> NotebookNode:
ajchili marked this conversation as resolved.
Show resolved Hide resolved
"""Creates a NotebookNode object with provided list as code in node.

Args:
code: list representing lines of code to be run.

Returns:
NotebookNode with specified file as code within node.

"""
return new_code_cell("\n".join(code))

def generate_html_from_notebook(self, nb: NotebookNode) -> Text:
"""Converts a provided NotebookNode to HTML.

Expand Down
11 changes: 8 additions & 3 deletions backend/src/apiserver/visualization/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def is_valid_request_arguments(self, arguments: Namespace):
"""
if arguments.type is None:
raise Exception("No type specified.")
if arguments.source is None:
if arguments.type != "custom" and arguments.source is None:
raise Exception("No source specified.")
try:
json.loads(arguments.arguments)
Expand Down Expand Up @@ -126,8 +126,13 @@ def generate_notebook_from_arguments(
nb = new_notebook()
nb.cells.append(_exporter.create_cell_from_args(arguments))
nb.cells.append(new_code_cell('source = "{}"'.format(source)))
visualization_file = str(Path.cwd() / "types/{}.py".format(visualization_type))
nb.cells.append(_exporter.create_cell_from_file(visualization_file))
if visualization_type == "custom":
code = arguments.get("code", [])
nb.cells.append(_exporter.create_cell_from_custom_code(code))
else:
visualization_file = str(Path.cwd() / "types/{}.py".format(visualization_type))
ajchili marked this conversation as resolved.
Show resolved Hide resolved
nb.cells.append(_exporter.create_cell_from_file(visualization_file))

return nb

def get(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,6 @@


'''

snapshots['TestExporterMethods::test_create_cell_from_custom_code 1'] = '''x = 2
print(x)'''
9 changes: 9 additions & 0 deletions backend/src/apiserver/visualization/test_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,15 @@ def test_create_cell_from_file(self):
cell = self.exporter.create_cell_from_file("types/test.py")
self.assertMatchSnapshot(cell.source)

def test_create_cell_from_custom_code(self):
self.maxDiff = None
ajchili marked this conversation as resolved.
Show resolved Hide resolved
code = [
"x = 2",
"print(x)"
]
cell = self.exporter.create_cell_from_custom_code(code)
self.assertMatchSnapshot(cell.source)

def test_generate_html_from_notebook(self):
self.maxDiff = None
nb = new_notebook()
Expand Down
7 changes: 7 additions & 0 deletions backend/src/apiserver/visualization/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,13 @@ def test_create_visualization_fails_when_missing_input_path(self):
response.body
)

def test_create_visualization_passes_when_missing_input_path_and_type_is_custom(self):
response = self.fetch(
"/",
method="POST",
body='arguments=--type custom')
self.assertEqual(200, response.code)

def test_create_visualization_fails_when_invalid_json_is_provided(self):
response = self.fetch(
"/",
Expand Down