From f05a1d1b7657fb16980d320e45e59f01d078b5ac Mon Sep 17 00:00:00 2001 From: Jiaxin Shan Date: Sat, 4 May 2019 22:48:04 -0700 Subject: [PATCH] Make confusion_matrix and roc generic --- .../local/confusion_matrix/src/confusion_matrix.py | 7 ++++--- components/local/roc/src/roc.py | 9 +++++---- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/components/local/confusion_matrix/src/confusion_matrix.py b/components/local/confusion_matrix/src/confusion_matrix.py index 05382e3a8c9..636c1db1a91 100644 --- a/components/local/confusion_matrix/src/confusion_matrix.py +++ b/components/local/confusion_matrix/src/confusion_matrix.py @@ -25,6 +25,7 @@ import argparse import json import os +import urlparse import pandas as pd from sklearn.metrics import confusion_matrix, accuracy_score from tensorflow.python.lib.io import file_io @@ -40,7 +41,8 @@ def main(argv=None): 'If not set, the input must include a "target" column.') args = parser.parse_args() - on_cloud = args.output.startswith('gs://') + storage_service_scheme = urlparse.urlparse(args.output).scheme + on_cloud = True if storage_service_scheme else False if not on_cloud and not os.path.exists(args.output): os.makedirs(args.output) @@ -52,7 +54,7 @@ def main(argv=None): for file in files: with file_io.FileIO(file, 'r') as f: dfs.append(pd.read_csv(f, names=names)) - + df = pd.concat(dfs) if args.target_lambda: df['target'] = df.apply(eval(args.target_lambda), axis=1) @@ -72,7 +74,6 @@ def main(argv=None): metadata = { 'outputs' : [{ 'type': 'confusion_matrix', - 'storage': 'gcs', 'format': 'csv', 'schema': [ {'name': 'target', 'type': 'CATEGORY'}, diff --git a/components/local/roc/src/roc.py b/components/local/roc/src/roc.py index 5de330787ef..b67f25e5264 100644 --- a/components/local/roc/src/roc.py +++ b/components/local/roc/src/roc.py @@ -24,6 +24,7 @@ import argparse import json import os +import urlparse import pandas as pd from sklearn.metrics import roc_curve, roc_auc_score from tensorflow.python.lib.io import file_io @@ -45,7 +46,8 @@ def main(argv=None): parser.add_argument('--output', type=str, help='GCS path of the output directory.') args = parser.parse_args() - on_cloud = args.output.startswith('gs://') + storage_service_scheme = urlparse.urlparse(args.output).scheme + on_cloud = True if storage_service_scheme else False if not on_cloud and not os.path.exists(args.output): os.makedirs(args.output) @@ -64,7 +66,7 @@ def main(argv=None): for file in files: with file_io.FileIO(file, 'r') as f: dfs.append(pd.read_csv(f, names=names)) - + df = pd.concat(dfs) if args.target_lambda: df['target'] = df.apply(eval(args.target_lambda), axis=1) @@ -76,11 +78,10 @@ def main(argv=None): roc_file = os.path.join(args.output, 'roc.csv') with file_io.FileIO(roc_file, 'w') as f: df_roc.to_csv(f, columns=['fpr', 'tpr', 'thresholds'], header=False, index=False) - + metadata = { 'outputs': [{ 'type': 'roc', - 'storage': 'gcs', 'format': 'csv', 'schema': [ {'name': 'fpr', 'type': 'NUMBER'},