diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcSchemaIOProvider.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcSchemaIOProvider.java index 77ec4082f6f4b..b5969e318099d 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcSchemaIOProvider.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcSchemaIOProvider.java @@ -67,6 +67,10 @@ public Schema configurationSchema() { .addNullableField("fetchSize", FieldType.INT16) .addNullableField("outputParallelization", FieldType.BOOLEAN) .addNullableField("autosharding", FieldType.BOOLEAN) + // Partitioning support. If you specify a partition column we will use that instead of + // readQuery + .addNullableField("partitionColumn", FieldType.STRING) + .addNullableField("partitions", FieldType.INT16) .build(); } @@ -110,26 +114,49 @@ public PTransform> buildReader() { return new PTransform>() { @Override public PCollection expand(PBegin input) { - @Nullable String readQuery = config.getString("readQuery"); - if (readQuery == null) { - readQuery = String.format("SELECT * FROM %s", location); - } - - JdbcIO.ReadRows readRows = - JdbcIO.readRows() - .withDataSourceConfiguration(getDataSourceConfiguration()) - .withQuery(readQuery); - - @Nullable Short fetchSize = config.getInt16("fetchSize"); - if (fetchSize != null) { - readRows = readRows.withFetchSize(fetchSize); - } - @Nullable Boolean outputParallelization = config.getBoolean("outputParallelization"); - if (outputParallelization != null) { - readRows = readRows.withOutputParallelization(outputParallelization); + // If we define a partition column we need to go a different route + @Nullable + String partitionColumn = + config.getSchema().hasField("partitionColumn") + ? config.getString("partitionColumn") + : null; + if (partitionColumn != null) { + JdbcIO.ReadWithPartitions readRows = + JdbcIO.readWithPartitions() + .withDataSourceConfiguration(getDataSourceConfiguration()) + .withTable(location) + .withPartitionColumn(partitionColumn) + .withRowOutput(); + @Nullable Short partitions = config.getInt16("partitions"); + if (partitions != null) { + readRows = readRows.withNumPartitions(partitions); + } + return input.apply(readRows); + } else { + + @Nullable String readQuery = config.getString("readQuery"); + if (readQuery == null) { + readQuery = String.format("SELECT * FROM %s", location); + } + + JdbcIO.ReadRows readRows = + JdbcIO.readRows() + .withDataSourceConfiguration(getDataSourceConfiguration()) + .withQuery(readQuery); + + @Nullable Short fetchSize = config.getInt16("fetchSize"); + if (fetchSize != null) { + readRows = readRows.withFetchSize(fetchSize); + } + + @Nullable Boolean outputParallelization = config.getBoolean("outputParallelization"); + if (outputParallelization != null) { + readRows = readRows.withOutputParallelization(outputParallelization); + } + + return input.apply(readRows); } - return input.apply(readRows); } }; } diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcSchemaIOProviderTest.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcSchemaIOProviderTest.java new file mode 100644 index 0000000000000..d91eaaef6e627 --- /dev/null +++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcSchemaIOProviderTest.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.jdbc; + +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.SQLException; +import javax.sql.DataSource; +import org.apache.beam.sdk.io.common.DatabaseTestHelper; +import org.apache.beam.sdk.io.common.TestRow; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Count; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class JdbcSchemaIOProviderTest { + + private static final JdbcIO.DataSourceConfiguration DATA_SOURCE_CONFIGURATION = + JdbcIO.DataSourceConfiguration.create( + "org.apache.derby.jdbc.EmbeddedDriver", "jdbc:derby:memory:testDB;create=true"); + private static final int EXPECTED_ROW_COUNT = 1000; + + private static final DataSource DATA_SOURCE = DATA_SOURCE_CONFIGURATION.buildDatasource(); + private static final String READ_TABLE_NAME = DatabaseTestHelper.getTestTableName("UT_READ"); + + @Rule public final transient TestPipeline pipeline = TestPipeline.create(); + + @BeforeClass + public static void beforeClass() throws Exception { + // by default, derby uses a lock timeout of 60 seconds. 
In order to speed up the test + // and detect the lock faster, we decrease this timeout + System.setProperty("derby.locks.waitTimeout", "2"); + System.setProperty("derby.stream.error.file", "build/derby.log"); + + DatabaseTestHelper.createTable(DATA_SOURCE, READ_TABLE_NAME); + addInitialData(DATA_SOURCE, READ_TABLE_NAME); + } + + @Test + public void testPartitionedRead() { + JdbcSchemaIOProvider provider = new JdbcSchemaIOProvider(); + + Row config = + Row.withSchema(provider.configurationSchema()) + .withFieldValue("driverClassName", DATA_SOURCE_CONFIGURATION.getDriverClassName().get()) + .withFieldValue("jdbcUrl", DATA_SOURCE_CONFIGURATION.getUrl().get()) + .withFieldValue("username", "") + .withFieldValue("password", "") + .withFieldValue("partitionColumn", "id") + .withFieldValue("partitions", (short) 10) + .build(); + JdbcSchemaIOProvider.JdbcSchemaIO schemaIO = + provider.from(READ_TABLE_NAME, config, Schema.builder().build()); + PCollection output = pipeline.apply(schemaIO.buildReader()); + Long expected = Long.valueOf(EXPECTED_ROW_COUNT); + PAssert.that(output.apply(Count.globally())).containsInAnyOrder(expected); + pipeline.run(); + } + + // This test shouldn't work because we only support numeric and datetime columns and we are trying + // to use a string + // column as our partition source + @Test + public void testPartitionedReadThatShouldntWork() throws Exception { + JdbcSchemaIOProvider provider = new JdbcSchemaIOProvider(); + + Row config = + Row.withSchema(provider.configurationSchema()) + .withFieldValue("driverClassName", DATA_SOURCE_CONFIGURATION.getDriverClassName().get()) + .withFieldValue("jdbcUrl", DATA_SOURCE_CONFIGURATION.getUrl().get()) + .withFieldValue("username", "") + .withFieldValue("password", "") + .withFieldValue("partitionColumn", "name") + .withFieldValue("partitions", (short) 10) + .build(); + JdbcSchemaIOProvider.JdbcSchemaIO schemaIO = + provider.from(READ_TABLE_NAME, config, Schema.builder().build()); + PCollection output = pipeline.apply(schemaIO.buildReader()); + Long expected = Long.valueOf(EXPECTED_ROW_COUNT); + PAssert.that(output.apply(Count.globally())).containsInAnyOrder(expected); + try { + pipeline.run(); + } catch (Exception e) { + e.printStackTrace(); + return; + } + throw new Exception("Did not throw an exception"); + } + + /** Create test data that is consistent with that generated by TestRow. 
*/ + private static void addInitialData(DataSource dataSource, String tableName) throws SQLException { + try (Connection connection = dataSource.getConnection()) { + connection.setAutoCommit(false); + try (PreparedStatement preparedStatement = + connection.prepareStatement(String.format("insert into %s values (?,?)", tableName))) { + for (int i = 0; i < EXPECTED_ROW_COUNT; i++) { + preparedStatement.clearParameters(); + preparedStatement.setInt(1, i); + preparedStatement.setString(2, TestRow.getNameForSeed(i)); + preparedStatement.executeUpdate(); + } + } + connection.commit(); + } + } +} diff --git a/sdks/python/apache_beam/coders/coder_impl.pxd b/sdks/python/apache_beam/coders/coder_impl.pxd index 5714f8beeeece..0e6e31d0fc824 100644 --- a/sdks/python/apache_beam/coders/coder_impl.pxd +++ b/sdks/python/apache_beam/coders/coder_impl.pxd @@ -109,6 +109,10 @@ cdef class BooleanCoderImpl(CoderImpl): pass +cdef class BigEndianShortCoderImpl(StreamCoderImpl): + pass + + cdef class SinglePrecisionFloatCoderImpl(StreamCoderImpl): pass diff --git a/sdks/python/apache_beam/coders/coder_impl.py b/sdks/python/apache_beam/coders/coder_impl.py index 094687ce68d84..cccc73662ce85 100644 --- a/sdks/python/apache_beam/coders/coder_impl.py +++ b/sdks/python/apache_beam/coders/coder_impl.py @@ -758,6 +758,22 @@ def estimate_size(self, unused_value, nested=False): if unused_value is not None else 0) +class BigEndianShortCoderImpl(StreamCoderImpl): + """For internal use only; no backwards-compatibility guarantees.""" + def encode_to_stream(self, value, out, nested): + # type: (int, create_OutputStream, bool) -> None + out.write_bigendian_int16(value) + + def decode_from_stream(self, in_stream, nested): + # type: (create_InputStream, bool) -> float + return in_stream.read_bigendian_int16() + + def estimate_size(self, unused_value, nested=False): + # type: (Any, bool) -> int + # A short is encoded as 2 bytes, regardless of nesting. + return 2 + + class SinglePrecisionFloatCoderImpl(StreamCoderImpl): """For internal use only; no backwards-compatibility guarantees.""" def encode_to_stream(self, value, out, nested): @@ -770,7 +786,7 @@ def decode_from_stream(self, in_stream, nested): def estimate_size(self, unused_value, nested=False): # type: (Any, bool) -> int - # A double is encoded as 8 bytes, regardless of nesting. + # A float is encoded as 4 bytes, regardless of nesting. 
return 4 diff --git a/sdks/python/apache_beam/coders/coders.py b/sdks/python/apache_beam/coders/coders.py index 25fabc951c55f..d4ca99b80fb3e 100644 --- a/sdks/python/apache_beam/coders/coders.py +++ b/sdks/python/apache_beam/coders/coders.py @@ -682,6 +682,25 @@ def __hash__(self): Coder.register_structured_urn(common_urns.coders.VARINT.urn, VarIntCoder) +class BigEndianShortCoder(FastCoder): + """A coder used for big-endian int16 values.""" + def _create_impl(self): + return coder_impl.BigEndianShortCoderImpl() + + def is_deterministic(self): + # type: () -> bool + return True + + def to_type_hint(self): + return int + + def __eq__(self, other): + return type(self) == type(other) + + def __hash__(self): + return hash(type(self)) + + class SinglePrecisionFloatCoder(FastCoder): """A coder used for single-precision floating-point values.""" def _create_impl(self): diff --git a/sdks/python/apache_beam/coders/coders_test_common.py b/sdks/python/apache_beam/coders/coders_test_common.py index a0bec891bdf1e..7adb06cb28701 100644 --- a/sdks/python/apache_beam/coders/coders_test_common.py +++ b/sdks/python/apache_beam/coders/coders_test_common.py @@ -160,6 +160,7 @@ def tearDownClass(cls): coders.ListLikeCoder, coders.ProtoCoder, coders.ProtoPlusCoder, + coders.BigEndianShortCoder, coders.SinglePrecisionFloatCoder, coders.ToBytesCoder, coders.BigIntegerCoder, # tested in DecimalCoder diff --git a/sdks/python/apache_beam/coders/row_coder.py b/sdks/python/apache_beam/coders/row_coder.py index 9dd4dcd9f6357..19424fa1f12b8 100644 --- a/sdks/python/apache_beam/coders/row_coder.py +++ b/sdks/python/apache_beam/coders/row_coder.py @@ -22,6 +22,7 @@ from apache_beam.coders import typecoders from apache_beam.coders.coder_impl import LogicalTypeCoderImpl from apache_beam.coders.coder_impl import RowCoderImpl +from apache_beam.coders.coders import BigEndianShortCoder from apache_beam.coders.coders import BooleanCoder from apache_beam.coders.coders import BytesCoder from apache_beam.coders.coders import Coder @@ -153,6 +154,8 @@ def _nonnull_coder_from_type(field_type): if type_info == "atomic_type": if field_type.atomic_type in (schema_pb2.INT32, schema_pb2.INT64): return VarIntCoder() + if field_type.atomic_type == schema_pb2.INT16: + return BigEndianShortCoder() elif field_type.atomic_type == schema_pb2.FLOAT: return SinglePrecisionFloatCoder() elif field_type.atomic_type == schema_pb2.DOUBLE: diff --git a/sdks/python/apache_beam/coders/slow_stream.py b/sdks/python/apache_beam/coders/slow_stream.py index 11ccf7fd2e377..71a5b45d7691d 100644 --- a/sdks/python/apache_beam/coders/slow_stream.py +++ b/sdks/python/apache_beam/coders/slow_stream.py @@ -69,6 +69,9 @@ def write_bigendian_uint64(self, v): def write_bigendian_int32(self, v): self.write(struct.pack('>i', v)) + def write_bigendian_int16(self, v): + self.write(struct.pack('>h', v)) + def write_bigendian_double(self, v): self.write(struct.pack('>d', v)) @@ -172,6 +175,9 @@ def read_bigendian_uint64(self): def read_bigendian_int32(self): return struct.unpack('>i', self.read(4))[0] + def read_bigendian_int16(self): + return struct.unpack('>h', self.read(2))[0] + def read_bigendian_double(self): return struct.unpack('>d', self.read(8))[0] diff --git a/sdks/python/apache_beam/coders/stream.pxd b/sdks/python/apache_beam/coders/stream.pxd index fc179bb8c1b6c..97d66aa089a47 100644 --- a/sdks/python/apache_beam/coders/stream.pxd +++ b/sdks/python/apache_beam/coders/stream.pxd @@ -29,6 +29,7 @@ cdef class OutputStream(object): cpdef write_bigendian_int64(self, 
libc.stdint.int64_t signed_v) cpdef write_bigendian_uint64(self, libc.stdint.uint64_t signed_v) cpdef write_bigendian_int32(self, libc.stdint.int32_t signed_v) + cpdef write_bigendian_int16(self, libc.stdint.int16_t signed_v) cpdef write_bigendian_double(self, double d) cpdef write_bigendian_float(self, float d) @@ -46,6 +47,7 @@ cdef class ByteCountingOutputStream(OutputStream): cpdef write_bigendian_int64(self, libc.stdint.int64_t val) cpdef write_bigendian_uint64(self, libc.stdint.uint64_t val) cpdef write_bigendian_int32(self, libc.stdint.int32_t val) + cpdef write_bigendian_int16(self, libc.stdint.int16_t val) cpdef size_t get_count(self) cpdef bytes get(self) @@ -62,6 +64,7 @@ cdef class InputStream(object): cpdef libc.stdint.int64_t read_bigendian_int64(self) except? -1 cpdef libc.stdint.uint64_t read_bigendian_uint64(self) except? -1 cpdef libc.stdint.int32_t read_bigendian_int32(self) except? -1 + cpdef libc.stdint.int16_t read_bigendian_int16(self) except? -1 cpdef double read_bigendian_double(self) except? -1 cpdef float read_bigendian_float(self) except? -1 cpdef bytes read_all(self, bint nested=*) diff --git a/sdks/python/apache_beam/coders/stream.pyx b/sdks/python/apache_beam/coders/stream.pyx index 14536b007cc83..8f941c151bde7 100644 --- a/sdks/python/apache_beam/coders/stream.pyx +++ b/sdks/python/apache_beam/coders/stream.pyx @@ -101,6 +101,14 @@ cdef class OutputStream(object): self.data[self.pos + 3] = (v ) self.pos += 4 + cpdef write_bigendian_int16(self, libc.stdint.int16_t signed_v): + cdef libc.stdint.uint16_t v = signed_v + if self.buffer_size < self.pos + 2: + self.extend(2) + self.data[self.pos ] = (v >> 8) + self.data[self.pos + 1] = (v ) + self.pos += 2 + cpdef write_bigendian_double(self, double d): self.write_bigendian_int64((&d)[0]) @@ -157,6 +165,9 @@ cdef class ByteCountingOutputStream(OutputStream): cpdef write_bigendian_int32(self, libc.stdint.int32_t _): self.count += 4 + cpdef write_bigendian_int16(self, libc.stdint.int16_t _): + self.count += 2 + cpdef size_t get_count(self): return self.count @@ -237,6 +248,11 @@ cdef class InputStream(object): | self.allc[self.pos - 3] << 16 | self.allc[self.pos - 4] << 24) + cpdef libc.stdint.int16_t read_bigendian_int16(self) except? -1: + self.pos += 2 + return (self.allc[self.pos - 1] + | self.allc[self.pos - 2] << 8) + cpdef double read_bigendian_double(self) except? 
-1: cdef libc.stdint.int64_t as_long = self.read_bigendian_int64() return (&as_long)[0] diff --git a/sdks/python/apache_beam/coders/stream_test.py b/sdks/python/apache_beam/coders/stream_test.py index 35b64eb958138..57662056b2a02 100644 --- a/sdks/python/apache_beam/coders/stream_test.py +++ b/sdks/python/apache_beam/coders/stream_test.py @@ -139,6 +139,15 @@ def test_read_write_bigendian_int32(self): for v in values: self.assertEqual(v, in_s.read_bigendian_int32()) + def test_read_write_bigendian_int16(self): + values = 0, 1, -1, 2**15 - 1, -2**15, int(2**13 * math.pi) + out_s = self.OutputStream() + for v in values: + out_s.write_bigendian_int16(v) + in_s = self.InputStream(out_s.get()) + for v in values: + self.assertEqual(v, in_s.read_bigendian_int16()) + def test_byte_counting(self): bc_s = self.ByteCountingOutputStream() self.assertEqual(0, bc_s.get_count()) diff --git a/sdks/python/apache_beam/io/external/xlang_jdbcio_it_test.py b/sdks/python/apache_beam/io/external/xlang_jdbcio_it_test.py index 1dcb56c51ecaf..ed8745ec2ac10 100644 --- a/sdks/python/apache_beam/io/external/xlang_jdbcio_it_test.py +++ b/sdks/python/apache_beam/io/external/xlang_jdbcio_it_test.py @@ -201,6 +201,24 @@ def test_xlang_jdbc_write_read(self, database): assert_that(result, equal_to(expected_row)) + # Try the same read using the partitioned reader code path. + # Outputs should be the same. + with TestPipeline() as p: + p.not_use_test_runner_api = True + result = ( + p + | 'Partitioned read from jdbc' >> ReadFromJdbc( + table_name=table_name, + partition_column='f_id', + partitions=3, + driver_class_name=self.driver_class_name, + jdbc_url=self.jdbc_url, + username=self.username, + password=self.password, + classpath=classpath)) + + assert_that(result, equal_to(expected_row)) + # Creating a container with testcontainers sometimes raises ReadTimeout # error. In java there are 2 retries set by default. 
def start_db_container(self, retries, container_init): diff --git a/sdks/python/apache_beam/io/jdbc.py b/sdks/python/apache_beam/io/jdbc.py index 85b80fdea0e47..aa539871601d0 100644 --- a/sdks/python/apache_beam/io/jdbc.py +++ b/sdks/python/apache_beam/io/jdbc.py @@ -88,6 +88,8 @@ import typing +import numpy as np + from apache_beam.coders import RowCoder from apache_beam.transforms.external import BeamJarExpansionService from apache_beam.transforms.external import ExternalTransform @@ -113,19 +115,16 @@ def default_io_expansion_service(classpath=None): Config = typing.NamedTuple( 'Config', - [ - ('driver_class_name', str), - ('jdbc_url', str), - ('username', str), - ('password', str), - ('connection_properties', typing.Optional[str]), - ('connection_init_sqls', typing.Optional[typing.List[str]]), - ('read_query', typing.Optional[str]), - ('write_statement', typing.Optional[str]), - ('fetch_size', typing.Optional[int]), - ('output_parallelization', typing.Optional[bool]), - ('autosharding', typing.Optional[bool]), - ], + [('driver_class_name', str), ('jdbc_url', str), ('username', str), + ('password', str), ('connection_properties', typing.Optional[str]), + ('connection_init_sqls', typing.Optional[typing.List[str]]), + ('read_query', typing.Optional[str]), + ('write_statement', typing.Optional[str]), + ('fetch_size', typing.Optional[int]), + ('output_parallelization', typing.Optional[bool]), + ('autosharding', typing.Optional[bool]), + ('partition_column', typing.Optional[str]), + ('partitions', typing.Optional[np.int16])], ) DEFAULT_JDBC_CLASSPATH = ['org.postgresql:postgresql:42.2.16'] @@ -226,7 +225,8 @@ def __init__( fetch_size=None, output_parallelization=None, autosharding=autosharding, - ))), + partitions=None, + partition_column=None))), ), expansion_service or default_io_expansion_service(classpath), ) @@ -273,6 +273,8 @@ def __init__( query=None, output_parallelization=None, fetch_size=None, + partition_column=None, + partitions=None, connection_properties=None, connection_init_sqls=None, expansion_service=None, @@ -288,6 +290,10 @@ def __init__( :param query: sql query to be executed :param output_parallelization: is output parallelization on :param fetch_size: how many rows to fetch + :param partition_column: enable partitioned reads by splitting on this + column + :param partitions: override the default number of splits when using + partition_column :param connection_properties: properties of the jdbc connection passed as string with format [propertyName=property;]* @@ -324,7 +330,8 @@ def __init__( fetch_size=fetch_size, output_parallelization=output_parallelization, autosharding=None, - ))), + partition_column=partition_column, + partitions=partitions))), ), expansion_service or default_io_expansion_service(classpath), )
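
Taken together, the Python changes expose the new partitioned-read path through ReadFromJdbc. The sketch below mirrors the cross-language integration test above; the JDBC URL, table name, and credentials are placeholders, and running it requires the Java expansion service (started via default_io_expansion_service) plus a reachable database. As on the Java side, the partition column must be a numeric or datetime column.

import apache_beam as beam
from apache_beam.io.jdbc import ReadFromJdbc

# Placeholder connection details, for illustration only.
with beam.Pipeline() as p:
    rows = (
        p
        | 'Partitioned read' >> ReadFromJdbc(
            table_name='jdbc_external_test_read',
            partition_column='f_id',   # numeric or datetime column to split on
            partitions=3,              # overrides the default number of splits
            driver_class_name='org.postgresql.Driver',
            jdbc_url='jdbc:postgresql://localhost:5432/postgres',
            username='postgres',
            password='postgres'))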
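
On the coder side, BigEndianShortCoder is what RowCoder now selects for INT16 schema fields, including the partitions option in the cross-language payload. A minimal round trip, assuming nothing beyond what this patch adds:

from apache_beam.coders.coders import BigEndianShortCoder

coder = BigEndianShortCoder()
encoded = coder.encode(-12345)
assert len(encoded) == 2            # a short always occupies two big-endian bytes
assert coder.decode(encoded) == -12345
assert coder.is_deterministic()     # usable wherever a deterministic key coder is required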
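
The same two-byte layout is produced by both stream implementations; the pure-Python one in slow_stream.py can be exercised directly, and the compiled stream.pyx version exposes identical methods:

from apache_beam.coders.slow_stream import InputStream, OutputStream

out = OutputStream()
out.write_bigendian_int16(-32768)   # struct '>h': values must fit in [-2**15, 2**15 - 1]
out.write_bigendian_int16(32767)
data = out.get()
assert len(data) == 4

in_s = InputStream(data)
assert in_s.read_bigendian_int16() == -32768
assert in_s.read_bigendian_int16() == 32767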