feat: Upgrade BigQuery Datasource to use write() interface

PiperOrigin-RevId: 573042238

matthew29tang authored and copybara-github committed Oct 12, 2023
1 parent 1ce9928 commit 7944348
Showing 2 changed files with 53 additions and 72 deletions.
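For context, here is a minimal usage sketch of the upgraded datasource, assuming Ray 2.x's `Dataset.write_datasource()` API and an initialized Vertex AI SDK; the project, location, dataset, and table names are illustrative and not part of this commit:

```python
import ray
from google.cloud import aiplatform
from google.cloud.aiplatform.preview.vertex_ray import bigquery_datasource

# Hypothetical example: write a small in-memory Ray Dataset to BigQuery.
aiplatform.init(project="my-project", location="us-central1")  # assumed project/region

ds = ray.data.from_items([{"data": i} for i in range(4)])
ds.write_datasource(
    bigquery_datasource.BigQueryDatasource(),
    dataset="my_dataset.my_table",  # "dataset.table"; the table is overwritten if it exists
)
```

With the new `write()` interface, Ray calls `BigQueryDatasource.write(blocks, ctx, ...)` from inside its own write tasks instead of the datasource scheduling remote `do_write` tasks itself, which is why `cached_remote_fn`, `ray_remote_args`, and the per-block metadata plumbing are removed in the diff below.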
81 changes: 31 additions & 50 deletions google/cloud/aiplatform/preview/vertex_ray/bigquery_datasource.py
@@ -21,16 +21,16 @@
import time
from typing import Any, Dict, List, Optional
import uuid
import pyarrow.parquet as pq

from google.api_core import client_info
from google.api_core import exceptions
from google.api_core.gapic_v1 import client_info as v1_client_info
from google.cloud import bigquery
from google.cloud import bigquery_storage
from google.cloud.aiplatform import initializer
from google.cloud.bigquery_storage import types
import pyarrow.parquet as pq
from ray.data._internal.remote_fn import cached_remote_fn

from ray.data._internal.execution.interfaces import TaskContext
from ray.data.block import Block
from ray.data.block import BlockAccessor
from ray.data.block import BlockMetadata
@@ -50,6 +50,9 @@
gapic_version=_BQS_GAPIC_VERSION, user_agent=f"ray-on-vertex/{_BQS_GAPIC_VERSION}"
)

MAX_RETRY_CNT = 10
RATE_LIMIT_EXCEEDED_SLEEP_TIME = 11


class _BigQueryDatasourceReader(Reader):
def __init__(
@@ -67,12 +70,12 @@ def __init__(

if query is not None and dataset is not None:
raise ValueError(
"[Ray on Vertex AI]: Query and dataset kwargs cannot both be provided (must be mutually exclusive)."
"[Ray on Vertex AI]: Query and dataset kwargs cannot both "
+ "be provided (must be mutually exclusive)."
)

def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
# Executed by a worker node
def _read_single_partition(stream, kwargs) -> Block:
def _read_single_partition(stream) -> Block:
client = bigquery_storage.BigQueryReadClient(client_info=bqstorage_info)
reader = client.read_rows(stream.name)
return reader.to_arrow()
@@ -96,9 +99,9 @@ def _read_single_partition(stream, kwargs) -> Block:

if parallelism == -1:
parallelism = None
requested_session = types.ReadSession(
requested_session = bigquery_storage.types.ReadSession(
table=table,
data_format=types.DataFormat.ARROW,
data_format=bigquery_storage.types.DataFormat.ARROW,
)
read_session = bqs_client.create_read_session(
parent=f"projects/{self._project_id}",
@@ -107,9 +110,9 @@ def _read_single_partition(stream, kwargs) -> Block:
)

read_tasks = []
print("[Ray on Vertex AI]: Created streams:", len(read_session.streams))
logging.info(f"Created streams: {len(read_session.streams)}")
if len(read_session.streams) < parallelism:
print(
logging.info(
"[Ray on Vertex AI]: The number of streams created by the "
+ "BigQuery Storage Read API is less than the requested "
+ "parallelism due to the size of the dataset."
@@ -125,15 +128,11 @@ def _read_single_partition(stream, kwargs) -> Block:
exec_stats=None,
)

# Create a no-arg wrapper read function which returns a block
read_single_partition = (
lambda stream=stream, kwargs=self._kwargs: [ # noqa: F731
_read_single_partition(stream, kwargs)
]
# Create the read task and pass the no-arg wrapper and metadata in
read_task = ReadTask(
lambda stream=stream: [_read_single_partition(stream)],
metadata,
)

# Create the read task and pass the wrapper and metadata in
read_task = ReadTask(read_single_partition, metadata)
read_tasks.append(read_task)

return read_tasks
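For readers following the read path above, here is a standalone sketch of what each read task does outside of Ray, assuming the same google-cloud-bigquery-storage client used in this file; the project, dataset, and table names are placeholders:

```python
from google.cloud import bigquery_storage

# Hypothetical example: stream a BigQuery table as Arrow data, one stream per task.
client = bigquery_storage.BigQueryReadClient()
requested_session = bigquery_storage.types.ReadSession(
    table="projects/my-project/datasets/my_dataset/tables/my_table",
    data_format=bigquery_storage.types.DataFormat.ARROW,
)
session = client.create_read_session(
    parent="projects/my-project",
    read_session=requested_session,
    max_stream_count=4,  # analogous to the requested parallelism
)
for stream in session.streams:
    # Each stream corresponds to one ReadTask; reading it yields a pyarrow.Table block.
    block = client.read_rows(stream.name).to_arrow()
```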
@@ -168,18 +167,14 @@ class BigQueryDatasource(Datasource):
def create_reader(self, **kwargs) -> Reader:
return _BigQueryDatasourceReader(**kwargs)

def do_write(
def write(
self,
blocks: List[ObjectRef[Block]],
metadata: List[BlockMetadata],
ray_remote_args: Optional[Dict[str, Any]],
ctx: TaskContext,
project_id: Optional[str] = None,
dataset: Optional[str] = None,
) -> List[ObjectRef[WriteResult]]:
def _write_single_block(
block: Block, metadata: BlockMetadata, project_id: str, dataset: str
):
print("[Ray on Vertex AI]: Starting to write", metadata.num_rows, "rows")
) -> WriteResult:
def _write_single_block(block: Block, project_id: str, dataset: str):
block = BlockAccessor.for_block(block).to_arrow()

client = bigquery.Client(project=project_id, client_info=bq_info)
@@ -192,7 +187,7 @@ def _write_single_block(
pq.write_table(block, fp, compression="SNAPPY")

retry_cnt = 0
while retry_cnt < 10:
while retry_cnt < MAX_RETRY_CNT:
with open(fp, "rb") as source_file:
job = client.load_table_from_file(
source_file, dataset, job_config=job_config
@@ -202,12 +197,11 @@ def _write_single_block(
logging.info(job.result())
break
except exceptions.Forbidden as e:
print(
logging.info(
"[Ray on Vertex AI]: Rate limit exceeded... Sleeping to try again"
)
logging.debug(e)
time.sleep(11)
print("[Ray on Vertex AI]: Finished writing", metadata.num_rows, "rows")
time.sleep(RATE_LIMIT_EXCEEDED_SLEEP_TIME)

project_id = project_id or initializer.global_config.project

@@ -216,34 +210,21 @@ def _write_single_block(
"[Ray on Vertex AI]: Dataset is required when writing to BigQuery."
)

if ray_remote_args is None:
ray_remote_args = {}

_write_single_block = cached_remote_fn(_write_single_block).options(
**ray_remote_args
)
write_tasks = []

# Set up datasets to write
client = bigquery.Client(project=project_id, client_info=bq_info)
dataset_id = dataset.split(".", 1)[0]
try:
client.create_dataset(f"{project_id}.{dataset_id}", timeout=30)
print("[Ray on Vertex AI]: Created dataset", dataset_id)
logging.info(f"[Ray on Vertex AI]: Created dataset {dataset_id}.")
except exceptions.Conflict:
print(
"[Ray on Vertex AI]: Dataset",
dataset_id,
"already exists. The table will be overwritten if it already exists.",
logging.info(
f"[Ray on Vertex AI]: Dataset {dataset_id} already exists. "
+ "The table will be overwritten if it already exists."
)

# Delete table if it already exists
client.delete_table(f"{project_id}.{dataset}", not_found_ok=True)

print("[Ray on Vertex AI]: Writing", len(blocks), "blocks")
for i in range(len(blocks)):
write_task = _write_single_block.remote(
blocks[i], metadata[i], project_id, dataset
)
write_tasks.append(write_task)
return write_tasks
for block in blocks:
_write_single_block(block, project_id, dataset)
return "ok"
44 changes: 22 additions & 22 deletions tests/unit/vertex_ray/test_bigquery.py
@@ -27,6 +27,7 @@
from google.cloud.bigquery import job
from google.cloud.bigquery_storage_v1.types import stream as gcbqs_stream
import mock
import pyarrow as pa
import pytest
import ray

@@ -89,7 +90,6 @@ def bq_query_mock(query):
client_mock.query = bq_query_mock

monkeypatch.setattr(bigquery, "Client", client_mock)
client_mock.reset_mock()
return client_mock


@@ -108,7 +108,6 @@ def bqs_create_read_session(max_stream_count=0, **kwargs):
client_mock.create_read_session = bqs_create_read_session

monkeypatch.setattr(bigquery_storage, "BigQueryReadClient", client_mock)
client_mock.reset_mock()
return client_mock


@@ -259,16 +258,16 @@ def setup_method(self):
def teardown_method(self):
aiplatform.initializer.global_pool.shutdown(wait=True)

def test_do_write(self, ray_remote_function_mock):
def test_write(self):
bq_ds = bigquery_datasource.BigQueryDatasource()
write_tasks_list = bq_ds.do_write(
blocks=[1, 2, 3, 4],
metadata=[1, 2, 3, 4],
ray_remote_args={},
project_id=tc.ProjectConstants._TEST_GCP_PROJECT_ID,
arr = pa.array([2, 4, 5, 100])
block = pa.Table.from_arrays([arr], names=["data"])
status = bq_ds.write(
blocks=[block],
ctx=None,
dataset=_TEST_BQ_DATASET,
)
assert len(write_tasks_list) == 4
assert status == "ok"

def test_do_write_initialized(self, ray_remote_function_mock):
"""If initialized, do_write doesn't need to specify project_id."""
@@ -277,21 +276,22 @@ def test_do_write_initialized(self, ray_remote_function_mock):
staging_bucket=tc.ProjectConstants._TEST_ARTIFACT_URI,
)
bq_ds = bigquery_datasource.BigQueryDatasource()
write_tasks_list = bq_ds.do_write(
blocks=[1, 2, 3, 4],
metadata=[1, 2, 3, 4],
ray_remote_args={},
dataset=_TEST_BQ_DATASET,
arr = pa.array([2, 4, 5, 100])
block = pa.Table.from_arrays([arr], names=["data"])
status = bq_ds.write(
blocks=[block],
ctx=None,
dataset="existingdataset" + "." + _TEST_BQ_TABLE_ID,
)
assert len(write_tasks_list) == 4
assert status == "ok"

def test_do_write_dataset_exists(self, ray_remote_function_mock):
def test_write_dataset_exists(self, ray_remote_function_mock):
bq_ds = bigquery_datasource.BigQueryDatasource()
write_tasks_list = bq_ds.do_write(
blocks=[1, 2, 3, 4],
metadata=[1, 2, 3, 4],
ray_remote_args={},
project_id=tc.ProjectConstants._TEST_GCP_PROJECT_ID,
arr = pa.array([2, 4, 5, 100])
block = pa.Table.from_arrays([arr], names=["data"])
status = bq_ds.write(
blocks=[block],
ctx=None,
dataset="existingdataset" + "." + _TEST_BQ_TABLE_ID,
)
assert len(write_tasks_list) == 4
assert status == "ok"
