From abc8099d71b5edad9493c669b5f467e46013b204 Mon Sep 17 00:00:00 2001
From: Robert Bradshaw
Date: Mon, 11 Jul 2022 10:14:56 -0700
Subject: [PATCH] Allow one to bound the size of output shards when writing to
 files. (#22130)

This fixes #22129.
---
 sdks/python/apache_beam/io/filebasedsink.py | 39 ++++++++++++++++++-
 sdks/python/apache_beam/io/iobase.py        | 12 +++++-
 sdks/python/apache_beam/io/textio.py        | 28 +++++++++++++-
 sdks/python/apache_beam/io/textio_test.py   | 42 +++++++++++++++++++++
 4 files changed, 118 insertions(+), 3 deletions(-)

diff --git a/sdks/python/apache_beam/io/filebasedsink.py b/sdks/python/apache_beam/io/filebasedsink.py
index a75e2c774436..6d8c6f8846fe 100644
--- a/sdks/python/apache_beam/io/filebasedsink.py
+++ b/sdks/python/apache_beam/io/filebasedsink.py
@@ -68,6 +68,9 @@ def __init__(
       shard_name_template=None,
       mime_type='application/octet-stream',
       compression_type=CompressionTypes.AUTO,
+      *,
+      max_records_per_shard=None,
+      max_bytes_per_shard=None,
       skip_if_empty=False):
     """
      Raises:
@@ -108,6 +111,8 @@ def __init__(
         shard_name_template)
     self.compression_type = compression_type
     self.mime_type = mime_type
+    self.max_records_per_shard = max_records_per_shard
+    self.max_bytes_per_shard = max_bytes_per_shard
     self.skip_if_empty = skip_if_empty
 
   def display_data(self):
@@ -130,7 +135,13 @@ def open(self, temp_path):
     The returned file handle is passed to ``write_[encoded_]record``
     and ``close``.
     """
-    return FileSystems.create(temp_path, self.mime_type, self.compression_type)
+    writer = FileSystems.create(
+        temp_path, self.mime_type, self.compression_type)
+    if self.max_bytes_per_shard:
+      self.byte_counter = _ByteCountingWriter(writer)
+      return self.byte_counter
+    else:
+      return writer
 
   def write_record(self, file_handle, value):
     """Writes a single record go the file handle returned by ``open()``.
@@ -406,10 +417,36 @@ def __init__(self, sink, temp_shard_path):
     self.sink = sink
     self.temp_shard_path = temp_shard_path
     self.temp_handle = self.sink.open(temp_shard_path)
+    self.num_records_written = 0
 
   def write(self, value):
+    self.num_records_written += 1
     self.sink.write_record(self.temp_handle, value)
 
+  def at_capacity(self):
+    return (
+        self.sink.max_records_per_shard and
+        self.num_records_written >= self.sink.max_records_per_shard
+    ) or (
+        self.sink.max_bytes_per_shard and
+        self.sink.byte_counter.bytes_written >= self.sink.max_bytes_per_shard)
+
   def close(self):
     self.sink.close(self.temp_handle)
     return self.temp_shard_path
+
+
+class _ByteCountingWriter:
+  def __init__(self, writer):
+    self.writer = writer
+    self.bytes_written = 0
+
+  def write(self, bs):
+    self.bytes_written += len(bs)
+    self.writer.write(bs)
+
+  def flush(self):
+    self.writer.flush()
+
+  def close(self):
+    self.writer.close()
diff --git a/sdks/python/apache_beam/io/iobase.py b/sdks/python/apache_beam/io/iobase.py
index fe46671aaa8e..6d75d520af55 100644
--- a/sdks/python/apache_beam/io/iobase.py
+++ b/sdks/python/apache_beam/io/iobase.py
@@ -849,7 +849,8 @@ class Writer(object):
   writing to a sink.
   """
   def write(self, value):
-    """Writes a value to the sink using the current writer."""
+    """Writes a value to the sink using the current writer.
+    """
     raise NotImplementedError
 
   def close(self):
@@ -863,6 +864,12 @@ def close(self):
     """
     raise NotImplementedError
 
+  def at_capacity(self) -> bool:
+    """Returns whether this writer should be considered at capacity
+    and a new one should be created.
+ """ + return False + class Read(ptransform.PTransform): """A transform that reads a PCollection.""" @@ -1185,6 +1192,9 @@ def process(self, element, init_result): # We ignore UUID collisions here since they are extremely rare. self.writer = self.sink.open_writer(init_result, str(uuid.uuid4())) self.writer.write(element) + if self.writer.at_capacity(): + yield self.writer.close() + self.writer = None def finish_bundle(self): if self.writer is not None: diff --git a/sdks/python/apache_beam/io/textio.py b/sdks/python/apache_beam/io/textio.py index 81d75bbe66ff..289c91e23b0a 100644 --- a/sdks/python/apache_beam/io/textio.py +++ b/sdks/python/apache_beam/io/textio.py @@ -435,6 +435,9 @@ def __init__(self, compression_type=CompressionTypes.AUTO, header=None, footer=None, + *, + max_records_per_shard=None, + max_bytes_per_shard=None, skip_if_empty=False): """Initialize a _TextSink. @@ -469,6 +472,14 @@ def __init__(self, append_trailing_newlines is set, '\n' will be added. footer: String to write at the end of file as a footer. If not None and append_trailing_newlines is set, '\n' will be added. + max_records_per_shard: Maximum number of records to write to any + individual shard. + max_bytes_per_shard: Target maximum number of bytes to write to any + individual shard. This may be exceeded slightly, as a new shard is + created once this limit is hit, but the remainder of a given record, a + subsequent newline, and a footer may cause the actual shard size + to exceed this value. This also tracks the uncompressed, + not compressed, size of the shard. skip_if_empty: Don't write any shards if the PCollection is empty. Returns: @@ -482,6 +493,8 @@ def __init__(self, coder=coder, mime_type='text/plain', compression_type=compression_type, + max_records_per_shard=max_records_per_shard, + max_bytes_per_shard=max_bytes_per_shard, skip_if_empty=skip_if_empty) self._append_trailing_newlines = append_trailing_newlines self._header = header @@ -791,6 +804,9 @@ def __init__( compression_type=CompressionTypes.AUTO, header=None, footer=None, + *, + max_records_per_shard=None, + max_bytes_per_shard=None, skip_if_empty=False): r"""Initialize a :class:`WriteToText` transform. @@ -830,6 +846,14 @@ def __init__( footer (str): String to write at the end of file as a footer. If not :data:`None` and **append_trailing_newlines** is set, ``\n`` will be added. + max_records_per_shard: Maximum number of records to write to any + individual shard. + max_bytes_per_shard: Target maximum number of bytes to write to any + individual shard. This may be exceeded slightly, as a new shard is + created once this limit is hit, but the remainder of a given record, a + subsequent newline, and a footer may cause the actual shard size + to exceed this value. This also tracks the uncompressed, + not compressed, size of the shard. skip_if_empty: Don't write any shards if the PCollection is empty. 
""" @@ -843,7 +867,9 @@ def __init__( compression_type, header, footer, - skip_if_empty) + max_records_per_shard=max_records_per_shard, + max_bytes_per_shard=max_bytes_per_shard, + skip_if_empty=skip_if_empty) def expand(self, pcoll): return pcoll | Write(self._sink) diff --git a/sdks/python/apache_beam/io/textio_test.py b/sdks/python/apache_beam/io/textio_test.py index 6b4d6d2bb7e8..6fb8d6ccb362 100644 --- a/sdks/python/apache_beam/io/textio_test.py +++ b/sdks/python/apache_beam/io/textio_test.py @@ -1668,6 +1668,48 @@ def test_write_empty_skipped(self): outputs = list(glob.glob(self.path + '*')) self.assertEqual(outputs, []) + def test_write_max_records_per_shard(self): + records_per_shard = 13 + lines = [str(i).encode('utf-8') for i in range(100)] + with TestPipeline() as p: + # pylint: disable=expression-not-assigned + p | beam.core.Create(lines) | WriteToText( + self.path, max_records_per_shard=records_per_shard) + + read_result = [] + for file_name in glob.glob(self.path + '*'): + with open(file_name, 'rb') as f: + shard_lines = list(f.read().splitlines()) + self.assertLessEqual(len(shard_lines), records_per_shard) + read_result.extend(shard_lines) + self.assertEqual(sorted(read_result), sorted(lines)) + + def test_write_max_bytes_per_shard(self): + bytes_per_shard = 300 + max_len = 100 + lines = [b'x' * i for i in range(max_len)] + header = b'a' * 20 + footer = b'b' * 30 + with TestPipeline() as p: + # pylint: disable=expression-not-assigned + p | beam.core.Create(lines) | WriteToText( + self.path, + header=header, + footer=footer, + max_bytes_per_shard=bytes_per_shard) + + read_result = [] + for file_name in glob.glob(self.path + '*'): + with open(file_name, 'rb') as f: + contents = f.read() + self.assertLessEqual( + len(contents), bytes_per_shard + max_len + len(footer) + 2) + shard_lines = list(contents.splitlines()) + self.assertEqual(shard_lines[0], header) + self.assertEqual(shard_lines[-1], footer) + read_result.extend(shard_lines[1:-1]) + self.assertEqual(sorted(read_result), sorted(lines)) + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO)