From 7dc8771715387e4fb4536aa7080b795bfba8039c Mon Sep 17 00:00:00 2001
From: Matthew Tang
Date: Fri, 8 Dec 2023 11:19:34 -0800
Subject: [PATCH] feat: Support custom batch size for Bigframes Tensorflow

PiperOrigin-RevId: 589190954
---
 .../vertexai/test_bigframes_tensorflow.py       |  2 +-
 tests/unit/vertexai/test_any_serializer.py      |  1 +
 .../serialization_engine/serializers.py         | 21 ++++++++++++-------
 3 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/tests/system/vertexai/test_bigframes_tensorflow.py b/tests/system/vertexai/test_bigframes_tensorflow.py
index da64e6abab..22ecd068ec 100644
--- a/tests/system/vertexai/test_bigframes_tensorflow.py
+++ b/tests/system/vertexai/test_bigframes_tensorflow.py
@@ -62,7 +62,6 @@
     "prepare_staging_bucket", "delete_staging_bucket", "tear_down_resources"
 )
 class TestRemoteExecutionBigframesTensorflow(e2e_base.TestEndToEnd):
-
     _temp_prefix = "temp-vertexai-remote-execution"
 
     def test_remote_execution_keras(self, shared_state):
@@ -97,6 +96,7 @@ def test_remote_execution_keras(self, shared_state):
             enable_cuda=True,
             display_name=self._make_display_name("bigframes-keras-training"),
         )
+        model.fit.vertex.remote_config.serializer_args[train] = {"batch_size": 10}
 
         # Train model on Vertex
         model.fit(train, epochs=10)
diff --git a/tests/unit/vertexai/test_any_serializer.py b/tests/unit/vertexai/test_any_serializer.py
index d675634438..b46444f44c 100644
--- a/tests/unit/vertexai/test_any_serializer.py
+++ b/tests/unit/vertexai/test_any_serializer.py
@@ -1105,6 +1105,7 @@ def test_any_serializer_deserialize_bigframe_tensorflow(
         mock_bigframe_deserialize_tensorflow.assert_called_once_with(
             any_serializer_instance._instances[serializers.BigframeSerializer],
             serialized_gcs_path=fake_gcs_path,
+            batch_size=None,
         )
 
     def test_any_serializer_deserialize_tf_dataset(
diff --git a/vertexai/preview/_workflow/serialization_engine/serializers.py b/vertexai/preview/_workflow/serialization_engine/serializers.py
index b3148a64ff..5161248e07 100644
--- a/vertexai/preview/_workflow/serialization_engine/serializers.py
+++ b/vertexai/preview/_workflow/serialization_engine/serializers.py
@@ -90,6 +90,7 @@
     "2.12": "0.32.0",
     "2.13": "0.34.0",  # TODO(b/295580335): Support TF 2.13
 }
+DEFAULT_TENSORFLOW_BATCHSIZE = 32
 
 
 def get_uri_prefix(gcs_uri: str) -> str:
@@ -1174,7 +1175,9 @@ def serialize(
         # Convert bigframes.dataframe.DataFrame to Parquet (GCS)
         parquet_gcs_path = gcs_path + "/*"  # path is required to contain '*'
         to_serialize.to_parquet(parquet_gcs_path, index=True)
-        return parquet_gcs_path
+
+        # Return the original gcs_path so the metadata can be retrieved later
+        return gcs_path
 
     def _get_tfio_verison(self):
         major, minor, _ = version.Version(tf.__version__).release
@@ -1190,15 +1193,15 @@
     def deserialize(
         self, serialized_gcs_path: str, **kwargs
     ) -> Union["pandas.DataFrame", "bigframes.dataframe.DataFrame"]:  # noqa: F821
-        del kwargs
-
         detected_framework = BigframeSerializer._metadata.framework
         if detected_framework == "sklearn":
             return self._deserialize_sklearn(serialized_gcs_path)
         elif detected_framework == "torch":
             return self._deserialize_torch(serialized_gcs_path)
         elif detected_framework == "tensorflow":
-            return self._deserialize_tensorflow(serialized_gcs_path)
+            return self._deserialize_tensorflow(
+                serialized_gcs_path, kwargs.get("batch_size")
+            )
         else:
             raise ValueError(f"Unsupported framework: {detected_framework}")
 
@@ -1269,11 +1272,16 @@ def reduce_tensors(a, b):
 
             return functools.reduce(reduce_tensors, list(parquet_df_dp))
 
-    def _deserialize_tensorflow(self, serialized_gcs_path: str) -> TFDataset:
+    def _deserialize_tensorflow(
+        self, serialized_gcs_path: str, batch_size: Optional[int] = None
+    ) -> TFDataset:
         """Tensorflow deserializes parquet (GCS) --> tf.data.Dataset
 
         serialized_gcs_path is a folder containing one or more parquet files.
         """
+        # Fall back to the default batch size when none is provided
+        batch_size = batch_size or DEFAULT_TENSORFLOW_BATCHSIZE
+
         # Deserialization at remote environment
         try:
             import tensorflow_io as tfio
@@ -1307,8 +1315,7 @@ def reduce_fn(a, b):
 
             return functools.reduce(reduce_fn, row.values()), target
 
-        # TODO(b/295535730): Remove hardcoded batch_size of 32
-        return ds.map(map_fn).batch(32)
+        return ds.map(map_fn).batch(batch_size)
 
 
 class CloudPickleSerializer(serializers_base.Serializer):
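
Usage note (reviewer sketch, not part of the patch): the snippet below shows how
the new knob reaches _deserialize_tensorflow end to end. Only the
serializer_args line is taken verbatim from the system test above; the project,
bucket, BigQuery table, and Keras model are hypothetical placeholders, and the
remote-setup calls follow the Vertex SDK preview's usual pattern.

    import bigframes.pandas as bf
    import vertexai
    from tensorflow import keras

    # Hypothetical project/bucket values, for illustration only.
    vertexai.init(
        project="my-project",
        location="us-central1",
        staging_bucket="gs://my-staging-bucket",
    )
    vertexai.preview.init(remote=True)

    # Hypothetical training table; any bigframes.dataframe.DataFrame works.
    train = bf.read_gbq("my-project.my_dataset.my_table")

    # Wrap the Keras class for remote execution on Vertex; layer shapes are
    # illustrative.
    Sequential = vertexai.preview.remote(keras.Sequential)
    model = Sequential(
        [keras.layers.Dense(5, input_shape=(4,)), keras.layers.Softmax()]
    )
    model.compile(optimizer="adam", loss="mse")

    # New in this patch: per-object serializer args. Without this entry the
    # remote worker rebuilds the tf.data.Dataset with
    # DEFAULT_TENSORFLOW_BATCHSIZE (32).
    model.fit.vertex.remote_config.serializer_args[train] = {"batch_size": 10}

    model.fit(train, epochs=10)

Omitting the serializer_args entry (or passing batch_size=None) preserves the
previous behavior, since _deserialize_tensorflow falls back to
DEFAULT_TENSORFLOW_BATCHSIZE.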