feat: Support custom batch size for Bigframes Tensorflow
PiperOrigin-RevId: 589190954
matthew29tang authored and copybara-github committed Dec 8, 2023
1 parent 0cb1a7b commit 7dc8771
Showing 3 changed files with 16 additions and 8 deletions.
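
For context, the commit exposes the new knob through a remote method's `serializer_args`, keyed by the Bigframes DataFrame being passed in. A minimal usage sketch, mirroring the system test below; `model` (a Keras model prepared for Vertex remote execution) and `train` (a bigframes.dataframe.DataFrame) are assumed from the surrounding test setup:

    # Assumed: `model` is a Keras model running with Vertex remote execution
    # and `train` is a bigframes.dataframe.DataFrame.
    model.fit.vertex.remote_config.serializer_args[train] = {"batch_size": 10}

    # The Bigframes -> tf.data.Dataset deserializer now batches with 10 rows
    # per batch instead of the previous hardcoded 32.
    model.fit(train, epochs=10)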
tests/system/vertexai/test_bigframes_tensorflow.py (2 changes: 1 addition & 1 deletion)
@@ -62,7 +62,6 @@
     "prepare_staging_bucket", "delete_staging_bucket", "tear_down_resources"
 )
 class TestRemoteExecutionBigframesTensorflow(e2e_base.TestEndToEnd):
-
     _temp_prefix = "temp-vertexai-remote-execution"

     def test_remote_execution_keras(self, shared_state):
@@ -97,6 +96,7 @@ def test_remote_execution_keras(self, shared_state):
             enable_cuda=True,
             display_name=self._make_display_name("bigframes-keras-training"),
         )
+        model.fit.vertex.remote_config.serializer_args[train] = {"batch_size": 10}

         # Train model on Vertex
         model.fit(train, epochs=10)
tests/unit/vertexai/test_any_serializer.py (1 change: 1 addition & 0 deletions)
@@ -1105,6 +1105,7 @@ def test_any_serializer_deserialize_bigframe_tensorflow(
         mock_bigframe_deserialize_tensorflow.assert_called_once_with(
             any_serializer_instance._instances[serializers.BigframeSerializer],
             serialized_gcs_path=fake_gcs_path,
+            batch_size=None,
         )

     def test_any_serializer_deserialize_tf_dataset(
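The new `batch_size=None` expectation pins down the default path: `deserialize` forwards `kwargs.get("batch_size")`, which is `None` whenever the caller supplied no serializer args. A quick standalone illustration of that fallback (stand-in function, not library code):

    def deserialize(serialized_gcs_path, **kwargs):
        # dict.get returns None for a missing key, so downstream
        # deserializers see batch_size=None unless a caller sets it.
        return kwargs.get("batch_size")

    assert deserialize("gs://bucket/path") is None
    assert deserialize("gs://bucket/path", batch_size=10) == 10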
vertexai/preview/_workflow/serialization_engine/serializers.py (21 changes: 14 additions & 7 deletions)
@@ -90,6 +90,7 @@
     "2.12": "0.32.0",
     "2.13": "0.34.0",  # TODO(b/295580335): Support TF 2.13
 }
+DEFAULT_TENSORFLOW_BATCHSIZE = 32


 def get_uri_prefix(gcs_uri: str) -> str:
@@ -1174,7 +1175,9 @@ def serialize(
         # Convert bigframes.dataframe.DataFrame to Parquet (GCS)
         parquet_gcs_path = gcs_path + "/*"  # path is required to contain '*'
         to_serialize.to_parquet(parquet_gcs_path, index=True)
-        return parquet_gcs_path
+
+        # Return the original gcs_path so the metadata can be retrieved later
+        return gcs_path

     def _get_tfio_verison(self):
         major, minor, _ = version.Version(tf.__version__).release
@@ -1190,15 +1193,15 @@ def _get_tfio_verison(self):
     def deserialize(
         self, serialized_gcs_path: str, **kwargs
     ) -> Union["pandas.DataFrame", "bigframes.dataframe.DataFrame"]:  # noqa: F821
-        del kwargs
-
         detected_framework = BigframeSerializer._metadata.framework
         if detected_framework == "sklearn":
             return self._deserialize_sklearn(serialized_gcs_path)
         elif detected_framework == "torch":
             return self._deserialize_torch(serialized_gcs_path)
         elif detected_framework == "tensorflow":
-            return self._deserialize_tensorflow(serialized_gcs_path)
+            return self._deserialize_tensorflow(
+                serialized_gcs_path, kwargs.get("batch_size")
+            )
         else:
             raise ValueError(f"Unsupported framework: {detected_framework}")
@@ -1269,11 +1272,16 @@ def reduce_tensors(a, b):

         return functools.reduce(reduce_tensors, list(parquet_df_dp))

-    def _deserialize_tensorflow(self, serialized_gcs_path: str) -> TFDataset:
+    def _deserialize_tensorflow(
+        self, serialized_gcs_path: str, batch_size: Optional[int] = None
+    ) -> TFDataset:
         """Tensorflow deserializes parquet (GCS) --> tf.data.Dataset
         serialized_gcs_path is a folder containing one or more parquet files.
         """
+        # Set default batch_size
+        batch_size = batch_size or DEFAULT_TENSORFLOW_BATCHSIZE
+
         # Deserialization at remote environment
         try:
             import tensorflow_io as tfio
@@ -1307,8 +1315,7 @@ def reduce_fn(a, b):

             return functools.reduce(reduce_fn, row.values()), target

-        # TODO(b/295535730): Remove hardcoded batch_size of 32
-        return ds.map(map_fn).batch(32)
+        return ds.map(map_fn).batch(batch_size)


 class CloudPickleSerializer(serializers_base.Serializer):
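
End to end, a caller-supplied batch size now flows from `serializer_args` through `deserialize(**kwargs)` into `_deserialize_tensorflow`, with `DEFAULT_TENSORFLOW_BATCHSIZE = 32` as the fallback. A self-contained sketch of just that batching behavior with plain `tf.data` (synthetic dataset, not the library's parquet pipeline):

    import tensorflow as tf

    DEFAULT_TENSORFLOW_BATCHSIZE = 32

    def batch_dataset(ds: tf.data.Dataset, batch_size=None) -> tf.data.Dataset:
        # Mirrors the commit's fallback: `batch_size or DEFAULT_...`,
        # so both None and 0 fall back to the default of 32.
        batch_size = batch_size or DEFAULT_TENSORFLOW_BATCHSIZE
        return ds.batch(batch_size)

    ds = tf.data.Dataset.range(64)
    assert next(iter(batch_dataset(ds))).shape[0] == 32      # default
    assert next(iter(batch_dataset(ds, 10))).shape[0] == 10  # custom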