Skip to content

Commit

Permalink
Rename InputPath -> Source for Visualization API definition (#1717)
Browse files Browse the repository at this point in the history
* InputPath -> Source

* Changed name of data path/pattern variable from InputPath to Source to improve consistency with current visualization method
* Updated unit tests to reflect name change
* Regenerated swagger definitions to reflect name change

* Re-added test that was removed in the previous commit

It was deleted by mistake
  • Loading branch information
ajchili authored and k8s-ci-robot committed Aug 5, 2019
1 parent 44f8198 commit fa1abde
Show file tree
Hide file tree
Showing 15 changed files with 118 additions and 106 deletions.
76 changes: 38 additions & 38 deletions backend/api/go_client/visualization.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion backend/api/swagger/visualization.swagger.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion backend/api/visualization.proto
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ message Visualization {
// Path pattern of input data to be used during generation of visualizations.
// This is required when creating a visualization through CreateVisualization
// API.
string inputPath = 2;
string source = 2;

// Variables to be used during generation of a visualization.
// This should be provided as a JSON string.
Expand Down
6 changes: 3 additions & 3 deletions backend/src/apiserver/server/visualization_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ func (s *VisualizationServer) CreateVisualization(ctx context.Context, request *
// It returns an error if a go_client.Visualization object does not have valid
// values.
func (s *VisualizationServer) validateCreateVisualizationRequest(request *go_client.CreateVisualizationRequest) error {
if len(request.Visualization.InputPath) == 0 {
return util.NewInvalidInputError("A visualization requires an InputPath to be provided. Received %s", request.Visualization.InputPath)
if len(request.Visualization.Source) == 0 {
return util.NewInvalidInputError("A visualization requires a Source to be provided. Received %s", request.Visualization.Source)
}
// Manually set Arguments to empty JSON if nothing is provided. This is done
// because visualizations such as TFDV and TFMA only require a Source to
Expand All @@ -57,7 +57,7 @@ func (s *VisualizationServer) validateCreateVisualizationRequest(request *go_cli
// It returns the generated HTML as a string and any error that is encountered.
func (s *VisualizationServer) generateVisualizationFromRequest(request *go_client.CreateVisualizationRequest) ([]byte, error) {
visualizationType := strings.ToLower(go_client.Visualization_Type_name[int32(request.Visualization.Type)])
arguments := fmt.Sprintf("--type %s --input_path %s --arguments '%s'", visualizationType, request.Visualization.InputPath, request.Visualization.Arguments)
arguments := fmt.Sprintf("--type %s --source %s --arguments '%s'", visualizationType, request.Visualization.Source, request.Visualization.Arguments)
resp, err := http.PostForm(s.serviceURL, url.Values{"arguments": {arguments}})
if err != nil {
return nil, util.Wrap(err, "Unable to initialize visualization request.")
Expand Down
16 changes: 8 additions & 8 deletions backend/src/apiserver/server/visualization_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ func TestValidateCreateVisualizationRequest(t *testing.T) {
server := NewVisualizationServer(manager)
visualization := &go_client.Visualization{
Type: go_client.Visualization_ROC_CURVE,
InputPath: "gs://ml-pipeline/roc/data.csv",
Source: "gs://ml-pipeline/roc/data.csv",
Arguments: "{}",
}
request := &go_client.CreateVisualizationRequest{
Expand All @@ -30,7 +30,7 @@ func TestValidateCreateVisualizationRequest_ArgumentsAreEmpty(t *testing.T) {
server := NewVisualizationServer(manager)
visualization := &go_client.Visualization{
Type: go_client.Visualization_ROC_CURVE,
InputPath: "gs://ml-pipeline/roc/data.csv",
Source: "gs://ml-pipeline/roc/data.csv",
Arguments: "",
}
request := &go_client.CreateVisualizationRequest{
Expand All @@ -40,20 +40,20 @@ func TestValidateCreateVisualizationRequest_ArgumentsAreEmpty(t *testing.T) {
assert.Nil(t, err)
}

func TestValidateCreateVisualizationRequest_InputPathIsEmpty(t *testing.T) {
func TestValidateCreateVisualizationRequest_SourceIsEmpty(t *testing.T) {
clients, manager, _ := initWithExperiment(t)
defer clients.Close()
server := NewVisualizationServer(manager)
visualization := &go_client.Visualization{
Type: go_client.Visualization_ROC_CURVE,
InputPath: "",
Source: "",
Arguments: "{}",
}
request := &go_client.CreateVisualizationRequest{
Visualization: visualization,
}
err := server.validateCreateVisualizationRequest(request)
assert.Contains(t, err.Error(), "A visualization requires an InputPath to be provided. Received")
assert.Contains(t, err.Error(), "A visualization requires a Source to be provided. Received")
}

func TestValidateCreateVisualizationRequest_ArgumentsNotValidJSON(t *testing.T) {
Expand All @@ -62,7 +62,7 @@ func TestValidateCreateVisualizationRequest_ArgumentsNotValidJSON(t *testing.T)
server := NewVisualizationServer(manager)
visualization := &go_client.Visualization{
Type: go_client.Visualization_ROC_CURVE,
InputPath: "gs://ml-pipeline/roc/data.csv",
Source: "gs://ml-pipeline/roc/data.csv",
Arguments: "{",
}
request := &go_client.CreateVisualizationRequest{
Expand All @@ -83,7 +83,7 @@ func TestGenerateVisualization(t *testing.T) {
server := &VisualizationServer{resourceManager: manager, serviceURL: httpServer.URL}
visualization := &go_client.Visualization{
Type: go_client.Visualization_ROC_CURVE,
InputPath: "gs://ml-pipeline/roc/data.csv",
Source: "gs://ml-pipeline/roc/data.csv",
Arguments: "{}",
}
request := &go_client.CreateVisualizationRequest{
Expand All @@ -105,7 +105,7 @@ func TestGenerateVisualization_ServerError(t *testing.T) {
server := &VisualizationServer{resourceManager: manager, serviceURL: httpServer.URL}
visualization := &go_client.Visualization{
Type: go_client.Visualization_ROC_CURVE,
InputPath: "gs://ml-pipeline/roc/data.csv",
Source: "gs://ml-pipeline/roc/data.csv",
Arguments: "{}",
}
request := &go_client.CreateVisualizationRequest{
Expand Down
8 changes: 4 additions & 4 deletions backend/src/apiserver/visualization/roc_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
# API post request.
#
# is_generated
# input_path
# source
# target_lambda
# trueclass
# true_score_column
Expand All @@ -38,12 +38,12 @@
# Create data from specified csv file(s).
# The schema file provides column names for the csv file that will be used
# to generate the roc curve.
schema_file = Path(input_path) / 'schema.json'
schema_file = Path(source) / 'schema.json'
schema = json.loads(file_io.read_file_to_string(schema_file))
names = [x['name'] for x in schema]

dfs = []
files = file_io.get_matching_files(input_path)
files = file_io.get_matching_files(source)
for f in files:
dfs.append(pd.read_csv(f, names=names))

Expand All @@ -57,7 +57,7 @@
else:
# Load data from generated csv file.
source = pd.read_csv(
input_path,
source,
header=None,
names=['fpr', 'tpr', 'thresholds']
)
Expand Down
12 changes: 6 additions & 6 deletions backend/src/apiserver/visualization/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def initialize(self):
)
# Path of data to be used to generate visualization.
self.requestParser.add_argument(
"--input_path",
"--source",
type=str,
help="Path of data to be used for generating visualization."
)
Expand Down Expand Up @@ -98,16 +98,16 @@ def is_valid_request_arguments(self, arguments: argparse.Namespace) -> bool:
if arguments.type is None:
self.send_error(400, reason="No type specified.")
return False
if arguments.input_path is None:
self.send_error(400, reason="No input_path specified.")
if arguments.source is None:
self.send_error(400, reason="No source specified.")
return False

return True

def generate_notebook_from_arguments(
self,
arguments: argparse.Namespace,
input_path: Text,
source: Text,
visualization_type: Text
) -> NotebookNode:
"""Generates a NotebookNode from provided arguments.
Expand All @@ -123,7 +123,7 @@ def generate_notebook_from_arguments(
"""
nb = new_notebook()
nb.cells.append(_exporter.create_cell_from_args(arguments))
nb.cells.append(new_code_cell('input_path = "{}"'.format(input_path)))
nb.cells.append(new_code_cell('source = "{}"'.format(source)))
visualization_file = str(Path.cwd() / "{}.py".format(visualization_type))
nb.cells.append(_exporter.create_cell_from_file(visualization_file))
return nb
Expand All @@ -143,7 +143,7 @@ def post(self):
# Create notebook with arguments from request.
nb = self.generate_notebook_from_arguments(
request_arguments.arguments,
request_arguments.input_path,
request_arguments.source,
request_arguments.type
)
# Generate visualization (output for notebook).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@

snapshots = Snapshot()

snapshots['TestExporterMethods::test_create_cell_from_args_with_multiple_args 1'] = '''input_path = "gs://ml-pipeline/data.csv"
snapshots['TestExporterMethods::test_create_cell_from_args_with_multiple_args 1'] = '''source = "gs://ml-pipeline/data.csv"
target_lambda = "lambda x: (x[\'target\'] > x[\'fare\'] * 0.2)"
'''

snapshots['TestExporterMethods::test_create_cell_from_args_with_no_args 1'] = ''

snapshots['TestExporterMethods::test_create_cell_from_args_with_one_arg 1'] = '''input_path = "gs://ml-pipeline/data.csv"
snapshots['TestExporterMethods::test_create_cell_from_args_with_one_arg 1'] = '''source = "gs://ml-pipeline/data.csv"
'''

snapshots['TestExporterMethods::test_create_cell_from_file 1'] = '''# Copyright 2019 Google LLC
Expand All @@ -36,9 +36,9 @@
# variables come from the specified input path and arguments provided by the
# API post request.
#
# input_path
# source
train_stats = tfdv.generate_statistics_from_csv(data_location=input_path)
train_stats = tfdv.generate_statistics_from_csv(data_location=source)
tfdv.visualize_statistics(train_stats)
'''
Expand Down
4 changes: 2 additions & 2 deletions backend/src/apiserver/visualization/test_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@ def test_create_cell_from_args_with_no_args(self):

def test_create_cell_from_args_with_one_arg(self):
self.maxDiff = None
args = '{"input_path": "gs://ml-pipeline/data.csv"}'
args = '{"source": "gs://ml-pipeline/data.csv"}'
cell = self.exporter.create_cell_from_args(args)
self.assertMatchSnapshot(cell.source)

def test_create_cell_from_args_with_multiple_args(self):
self.maxDiff = None
args = (
'{"input_path": "gs://ml-pipeline/data.csv", '
'{"source": "gs://ml-pipeline/data.csv", '
"\"target_lambda\": \"lambda x: (x['target'] > x['fare'] * 0.2)\"}"
)
cell = self.exporter.create_cell_from_args(args)
Expand Down
15 changes: 9 additions & 6 deletions backend/src/apiserver/visualization/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,19 @@ def test_create_visualization_fails_when_nothing_is_provided(self):
self.assertEqual(400, response.code)
self.assertEqual(
wrap_error_in_html("400: Bad Request"),
response.body)
response.body
)

def test_create_visualization_fails_when_missing_type(self):
response = self.fetch(
"/",
method="POST",
body="arguments=--input_path gs://ml-pipeline/data.csv")
body="arguments=--source gs://ml-pipeline/data.csv")
self.assertEqual(400, response.code)
self.assertEqual(
wrap_error_in_html("400: No type specified."),
response.body)
response.body
)

def test_create_visualization_fails_when_missing_input_path(self):
response = self.fetch(
Expand All @@ -64,14 +66,15 @@ def test_create_visualization_fails_when_missing_input_path(self):
body='arguments=--type test')
self.assertEqual(400, response.code)
self.assertEqual(
wrap_error_in_html("400: No input_path specified."),
response.body)
wrap_error_in_html("400: No source specified."),
response.body
)

def test_create_visualization(self):
response = self.fetch(
"/",
method="POST",
body='arguments=--type test --input_path gs://ml-pipeline/data.csv')
body='arguments=--type test --source gs://ml-pipeline/data.csv')
self.assertEqual(200, response.code)


Expand Down
Loading

0 comments on commit fa1abde

Please sign in to comment.