diff --git a/component_sdk/python/kfp_component/google/bigquery/_query.py b/component_sdk/python/kfp_component/google/bigquery/_query.py
index 536ca2d5be3..398c6c15bd3 100644
--- a/component_sdk/python/kfp_component/google/bigquery/_query.py
+++ b/component_sdk/python/kfp_component/google/bigquery/_query.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import json
+import logging
 
 from google.cloud import bigquery
 from google.api_core import exceptions
@@ -20,8 +21,27 @@ from kfp_component.core import KfpExecutionContext, display
 from .. import common as gcp_common
 
-def query(query, project_id, dataset_id, table_id=None,
-    output_gcs_path=None, job_config=None):
+# TODO(hongyes): make this path configurable as an environment variable
+KFP_OUTPUT_PATH = '/tmp/kfp/output/'
+
+def query(query, project_id, dataset_id=None, table_id=None,
+    output_gcs_path=None, dataset_location='US', job_config=None):
+    """Submit a query to the BigQuery service and dump the outputs to a GCS blob.
+
+    Args:
+        query (str): The query for the BigQuery service to execute.
+        project_id (str): The project in which to execute the query job.
+        dataset_id (str): The ID of the persistent dataset that keeps the results
+            of the query. If the dataset does not exist, the operation will
+            create a new one.
+        table_id (str): The ID of the table that keeps the results of the query. If
+            absent, the operation will use the query job ID as the table ID.
+        output_gcs_path (str): The GCS blob path to dump the query results to.
+        dataset_location (str): The location in which to create the dataset. Defaults to `US`.
+        job_config (dict): The full config spec for the query job.
+    Returns:
+        The API representation of the completed query job.
+    """
     client = bigquery.Client(project=project_id)
     if not job_config:
         job_config = bigquery.QueryJobConfig()
@@ -33,21 +53,30 @@ def cancel():
         client.cancel_job(job_id)
     with KfpExecutionContext(on_cancel=cancel) as ctx:
         job_id = 'query_' + ctx.context_id()
-        if not table_id:
-            table_id = 'table_' + ctx.context_id()
-        table_ref = client.dataset(dataset_id).table(table_id)
         query_job = _get_job(client, job_id)
+        table_ref = None
         if not query_job:
-            job_config.destination = table_ref
+            dataset_ref = _prepare_dataset_ref(client, dataset_id, output_gcs_path,
+                dataset_location)
+            if dataset_ref:
+                if not table_id:
+                    table_id = job_id
+                table_ref = dataset_ref.table(table_id)
+                job_config.destination = table_ref
             query_job = client.query(query, job_config, job_id=job_id)
         _display_job_link(project_id, job_id)
-        query_job.result()
+        query_result = query_job.result()
         if output_gcs_path:
             job_id = 'extract_' + ctx.context_id()
             extract_job = _get_job(client, job_id)
             if not extract_job:
                 extract_job = client.extract_table(table_ref, output_gcs_path)
             extract_job.result() # Wait for export to finish
+        else:
+            # Download results to local disk if no GCS output path was given.
+            gcp_common.dump_file(KFP_OUTPUT_PATH + 'bigquery/query_output.csv',
+                query_result.to_dataframe().to_csv())
+        _dump_outputs(query_job, output_gcs_path)
         return query_job.to_api_repr()
 
 def _get_job(client, job_id):
@@ -56,6 +85,30 @@ def _get_job(client, job_id):
     except exceptions.NotFound:
         return None
 
+def _prepare_dataset_ref(client, dataset_id, output_gcs_path, dataset_location):
+    if not output_gcs_path and not dataset_id:
+        return None
+
+    if not dataset_id:
+        dataset_id = 'kfp_tmp_dataset'
+    dataset_ref = client.dataset(dataset_id)
+    dataset = _get_dataset(client, dataset_ref)
+    if not dataset:
+        logging.info('Creating dataset {}'.format(dataset_id))
+        dataset = _create_dataset(client, dataset_ref, dataset_location)
+    return dataset_ref
+
+def _get_dataset(client, dataset_ref):
+    try:
+        return client.get_dataset(dataset_ref)
+    except exceptions.NotFound:
+        return None
+
+def _create_dataset(client, dataset_ref, location):
+    dataset = bigquery.Dataset(dataset_ref)
+    dataset.location = location
+    return client.create_dataset(dataset)
+
 def _display_job_link(project_id, job_id):
     display.display(display.Link(
         href= 'https://console.cloud.google.com/bigquery?project={}'
@@ -63,6 +116,10 @@ def _display_job_link(project_id, job_id):
         text='Query Details'
     ))
 
-def _dump_outputs(job):
-    gcp_common.dump_file('/tmp/outputs/bigquery-job.json',
-        json.dumps(job.to_api_repr()))
\ No newline at end of file
+def _dump_outputs(job, output_path):
+    gcp_common.dump_file(KFP_OUTPUT_PATH + 'bigquery/query-job.json',
+        json.dumps(job.to_api_repr()))
+    if not output_path:
+        output_path = ''
+    gcp_common.dump_file(KFP_OUTPUT_PATH + 'bigquery/query-output-path.txt',
+        output_path)
diff --git a/component_sdk/python/tests/google/bigquery/test__query.py b/component_sdk/python/tests/google/bigquery/test__query.py
index 06d91a42747..f663ae84861 100644
--- a/component_sdk/python/tests/google/bigquery/test__query.py
+++ b/component_sdk/python/tests/google/bigquery/test__query.py
@@ -25,10 +25,13 @@
 @mock.patch(CREATE_JOB_MODULE + '.bigquery.Client')
 class TestQuery(unittest.TestCase):
 
-    def test_create_job_succeed(self, mock_client,
+    def test_query_succeed(self, mock_client,
         mock_kfp_context, mock_dump_json, mock_display):
         mock_kfp_context().__enter__().context_id.return_value = 'ctx1'
         mock_client().get_job.side_effect = exceptions.NotFound('not found')
+        mock_dataset = bigquery.DatasetReference('project-1', 'dataset-1')
+        mock_client().dataset.return_value = mock_dataset
+        mock_client().get_dataset.side_effect = exceptions.NotFound('not found')
         mock_response = {
             'configuration': {
                 'query': {
@@ -37,17 +40,16 @@ def test_create_job_succeed(self, mock_client,
             }
         }
         mock_client().query.return_value.to_api_repr.return_value = mock_response
-        mock_dataset = bigquery.DatasetReference('project-1', 'dataset-1')
-        mock_client().dataset.return_value = mock_dataset
 
         result = query('SELECT * FROM table_1', 'project-1', 'dataset-1',
             output_gcs_path='gs://output/path')
 
         self.assertEqual(mock_response, result)
+        mock_client().create_dataset.assert_called()
         expected_job_config = bigquery.QueryJobConfig()
         expected_job_config.create_disposition = bigquery.job.CreateDisposition.CREATE_IF_NEEDED
         expected_job_config.write_disposition = bigquery.job.WriteDisposition.WRITE_TRUNCATE
-        expected_job_config.destination = mock_dataset.table('table_ctx1')
+        expected_job_config.destination = mock_dataset.table('query_ctx1')
         mock_client().query.assert_called_with('SELECT * FROM table_1',mock.ANY,
             job_id = 'query_ctx1')
         actual_job_config = mock_client().query.call_args_list[0][0][1]
@@ -56,6 +58,28 @@ def test_create_job_succeed(self, mock_client,
             actual_job_config.to_api_repr()
         )
         mock_client().extract_table.assert_called_with(
-            mock_dataset.table('table_ctx1'),
+            mock_dataset.table('query_ctx1'),
             'gs://output/path')
+        self.assertEqual(2, mock_dump_json.call_count)
+
+    def test_query_dump_locally(self, mock_client,
+        mock_kfp_context, mock_dump_json, mock_display):
+        mock_kfp_context().__enter__().context_id.return_value = 'ctx1'
+        mock_client().get_job.side_effect = exceptions.NotFound('not found')
+        mock_response = {
+            'configuration': {
+                'query': {
+                    'query': 'SELECT * FROM table_1'
+                }
+            }
+        }
+        mock_client().query.return_value.to_api_repr.return_value = mock_response
+
+        result = query('SELECT * FROM table_1', 'project-1')
+
+        self.assertEqual(mock_response, result)
+        mock_client().create_dataset.assert_not_called()
+        mock_client().query.assert_called()
+        mock_client().extract_table.assert_not_called()
+        self.assertEqual(3, mock_dump_json.call_count)
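
Reviewer note: a minimal usage sketch of the changed `query` component, showing the two output paths this diff introduces. The argument values are invented, and the import form assumes the `bigquery` package re-exports `query` from `_query.py`; this sketch is illustrative only and not part of the change.

```python
from kfp_component.google.bigquery import query

# With an output GCS path but no dataset_id: a temporary dataset
# ('kfp_tmp_dataset') is created on demand, the destination table
# defaults to the query job ID, and the table is extracted to the blob.
job = query(
    'SELECT * FROM `some-project.some_dataset.some_table`',  # assumed query
    project_id='some-project',
    output_gcs_path='gs://some-bucket/query/results.csv',
    dataset_location='US')

# With neither output_gcs_path nor dataset_id: no dataset or destination
# table is created, and the results are written as CSV under
# KFP_OUTPUT_PATH ('/tmp/kfp/output/bigquery/query_output.csv').
job = query(
    'SELECT * FROM `some-project.some_dataset.some_table`',
    project_id='some-project')
```

In both cases `_dump_outputs` also writes the job's API representation and the output path under `KFP_OUTPUT_PATH`, which is why the local-dump test expects three `dump_file` calls while the GCS test expects two.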