Add bigquery docstring and dump output path. (#885)
* Add bigquery docstring and dump output path.

* Auto-create the dataset if it does not exist and dump results to local files

* Make dataset location configurable

* Add TODO to make KFP output path configurable.

* Fix comment
hongye-sun authored Mar 6, 2019
1 parent 4e936f3 commit 2b07bb1
Showing 2 changed files with 96 additions and 15 deletions.
77 changes: 67 additions & 10 deletions component_sdk/python/kfp_component/google/bigquery/_query.py
@@ -13,15 +13,35 @@
 # limitations under the License.
 
 import json
+import logging
 
 from google.cloud import bigquery
 from google.api_core import exceptions
 
 from kfp_component.core import KfpExecutionContext, display
 from .. import common as gcp_common
 
-def query(query, project_id, dataset_id, table_id=None,
-    output_gcs_path=None, job_config=None):
+# TODO(hongyes): make this path configurable as an environment variable
+KFP_OUTPUT_PATH = '/tmp/kfp/output/'
+
+def query(query, project_id, dataset_id=None, table_id=None,
+    output_gcs_path=None, dataset_location='US', job_config=None):
+    """Submit a query to the BigQuery service and dump outputs to a GCS blob.
+    Args:
+        query (str): The query used by the BigQuery service to fetch the results.
+        project_id (str): The project to execute the query job.
+        dataset_id (str): The ID of the persistent dataset to keep the results
+            of the query. If the dataset does not exist, the operation will
+            create a new one.
+        table_id (str): The ID of the table to keep the results of the query. If
+            absent, the operation will generate a random ID for the table.
+        output_gcs_path (str): The GCS blob path to dump the query results to.
+        dataset_location (str): The location to create the dataset. Defaults to `US`.
+        job_config (dict): The full config spec for the query job.
+    Returns:
+        The API representation of the completed query job.
+    """
     client = bigquery.Client(project=project_id)
     if not job_config:
         job_config = bigquery.QueryJobConfig()
@@ -33,21 +53,30 @@ def cancel():
         client.cancel_job(job_id)
     with KfpExecutionContext(on_cancel=cancel) as ctx:
         job_id = 'query_' + ctx.context_id()
-        if not table_id:
-            table_id = 'table_' + ctx.context_id()
-        table_ref = client.dataset(dataset_id).table(table_id)
         query_job = _get_job(client, job_id)
+        table_ref = None
         if not query_job:
-            job_config.destination = table_ref
+            dataset_ref = _prepare_dataset_ref(client, dataset_id, output_gcs_path,
+                dataset_location)
+            if dataset_ref:
+                if not table_id:
+                    table_id = job_id
+                table_ref = dataset_ref.table(table_id)
+                job_config.destination = table_ref
             query_job = client.query(query, job_config, job_id=job_id)
         _display_job_link(project_id, job_id)
-        query_job.result()
+        query_result = query_job.result()
        if output_gcs_path:
            job_id = 'extract_' + ctx.context_id()
            extract_job = _get_job(client, job_id)
            if not extract_job:
                extract_job = client.extract_table(table_ref, output_gcs_path)
            extract_job.result()  # Wait for export to finish
+        else:
+            # Download results to local disk if no GCS output path is given.
+            gcp_common.dump_file(KFP_OUTPUT_PATH + 'bigquery/query_output.csv',
+                query_result.to_dataframe().to_csv())
+        _dump_outputs(query_job, output_gcs_path)
        return query_job.to_api_repr()
 
 def _get_job(client, job_id):
@@ -56,13 +85,41 @@ def _get_job(client, job_id):
     except exceptions.NotFound:
         return None
 
+def _prepare_dataset_ref(client, dataset_id, output_gcs_path, dataset_location):
+    if not output_gcs_path and not dataset_id:
+        return None
+
+    if not dataset_id:
+        dataset_id = 'kfp_tmp_dataset'
+    dataset_ref = client.dataset(dataset_id)
+    dataset = _get_dataset(client, dataset_ref)
+    if not dataset:
+        logging.info('Creating dataset {}'.format(dataset_id))
+        dataset = _create_dataset(client, dataset_ref, dataset_location)
+    return dataset_ref
+
+def _get_dataset(client, dataset_ref):
+    try:
+        return client.get_dataset(dataset_ref)
+    except exceptions.NotFound:
+        return None
+
+def _create_dataset(client, dataset_ref, location):
+    dataset = bigquery.Dataset(dataset_ref)
+    dataset.location = location
+    return client.create_dataset(dataset)
+
 def _display_job_link(project_id, job_id):
     display.display(display.Link(
         href='https://console.cloud.google.com/bigquery?project={}'
             '&j={}&page=queryresults'.format(project_id, job_id),
         text='Query Details'
     ))
 
-def _dump_outputs(job):
-    gcp_common.dump_file('/tmp/outputs/bigquery-job.json',
-        json.dumps(job.to_api_repr()))
+def _dump_outputs(job, output_path):
+    gcp_common.dump_file(KFP_OUTPUT_PATH + 'bigquery/query-job.json',
+        json.dumps(job.to_api_repr()))
+    if not output_path:
+        output_path = ''
+    gcp_common.dump_file(KFP_OUTPUT_PATH + 'bigquery/query-output-path.txt',
+        output_path)
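
For reference, a minimal usage sketch of the updated component follows. It is not part of the commit: it assumes the kfp_component package is installed and re-exports query from _query, that application-default GCP credentials are available, and that the project, dataset, and bucket names are placeholders.

from kfp_component.google.bigquery import query

# Dump results to a GCS blob; the component auto-creates the dataset
# in the `US` location if it does not exist yet.
job = query(
    'SELECT * FROM `bigquery-public-data.samples.shakespeare` LIMIT 10',
    project_id='my-project',     # placeholder project
    dataset_id='my_dataset',     # created on first run if missing
    output_gcs_path='gs://my-bucket/bq/output.csv')

# With no output_gcs_path and no dataset_id, no destination table is set
# and the results are written locally to
# /tmp/kfp/output/bigquery/query_output.csv instead.
job = query('SELECT 1 AS x', project_id='my-project')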
34 changes: 29 additions & 5 deletions component_sdk/python/tests/google/bigquery/test__query.py
@@ -25,10 +25,13 @@
 @mock.patch(CREATE_JOB_MODULE + '.bigquery.Client')
 class TestQuery(unittest.TestCase):
 
-    def test_create_job_succeed(self, mock_client,
+    def test_query_succeed(self, mock_client,
         mock_kfp_context, mock_dump_json, mock_display):
         mock_kfp_context().__enter__().context_id.return_value = 'ctx1'
         mock_client().get_job.side_effect = exceptions.NotFound('not found')
+        mock_dataset = bigquery.DatasetReference('project-1', 'dataset-1')
+        mock_client().dataset.return_value = mock_dataset
+        mock_client().get_dataset.side_effect = exceptions.NotFound('not found')
         mock_response = {
             'configuration': {
                 'query': {
@@ -37,17 +40,16 @@ def test_create_job_succeed(self, mock_client,
             }
         }
         mock_client().query.return_value.to_api_repr.return_value = mock_response
-        mock_dataset = bigquery.DatasetReference('project-1', 'dataset-1')
-        mock_client().dataset.return_value = mock_dataset
 
         result = query('SELECT * FROM table_1', 'project-1', 'dataset-1',
             output_gcs_path='gs://output/path')
 
         self.assertEqual(mock_response, result)
+        mock_client().create_dataset.assert_called()
         expected_job_config = bigquery.QueryJobConfig()
         expected_job_config.create_disposition = bigquery.job.CreateDisposition.CREATE_IF_NEEDED
         expected_job_config.write_disposition = bigquery.job.WriteDisposition.WRITE_TRUNCATE
-        expected_job_config.destination = mock_dataset.table('table_ctx1')
+        expected_job_config.destination = mock_dataset.table('query_ctx1')
         mock_client().query.assert_called_with('SELECT * FROM table_1', mock.ANY,
             job_id='query_ctx1')
         actual_job_config = mock_client().query.call_args_list[0][0][1]
@@ -56,6 +58,28 @@ def test_create_job_succeed(self, mock_client,
             actual_job_config.to_api_repr()
         )
         mock_client().extract_table.assert_called_with(
-            mock_dataset.table('table_ctx1'),
+            mock_dataset.table('query_ctx1'),
             'gs://output/path')
         self.assertEqual(2, mock_dump_json.call_count)
+
+    def test_query_dump_locally(self, mock_client,
+        mock_kfp_context, mock_dump_json, mock_display):
+        mock_kfp_context().__enter__().context_id.return_value = 'ctx1'
+        mock_client().get_job.side_effect = exceptions.NotFound('not found')
+        mock_response = {
+            'configuration': {
+                'query': {
+                    'query': 'SELECT * FROM table_1'
+                }
+            }
+        }
+        mock_client().query.return_value.to_api_repr.return_value = mock_response
+
+        result = query('SELECT * FROM table_1', 'project-1')
+
+        self.assertEqual(mock_response, result)
+        mock_client().create_dataset.assert_not_called()
+        mock_client().query.assert_called()
+        mock_client().extract_table.assert_not_called()
+        self.assertEqual(3, mock_dump_json.call_count)
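
The dataset handling these tests exercise reduces to a small fallback chain: reuse the dataset if it exists, otherwise create it in the requested location. Below is a rough standalone sketch of that chain using only the google-cloud-bigquery client; the names are placeholders, and it mirrors the _prepare_dataset_ref, _get_dataset, and _create_dataset helpers above rather than reproducing them.

from google.cloud import bigquery
from google.api_core import exceptions

client = bigquery.Client(project='my-project')  # placeholder project

# When no dataset_id is supplied but a GCS dump is requested, the
# component falls back to a fixed temporary dataset name.
dataset_ref = client.dataset('kfp_tmp_dataset')
try:
    client.get_dataset(dataset_ref)        # reuse the dataset if it exists
except exceptions.NotFound:
    dataset = bigquery.Dataset(dataset_ref)
    dataset.location = 'US'                # the dataset_location argument
    client.create_dataset(dataset)         # auto-create on first use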
