diff --git a/plugin/trino-redshift/README.md b/plugin/trino-redshift/README.md new file mode 100644 index 000000000000..16229b145da1 --- /dev/null +++ b/plugin/trino-redshift/README.md @@ -0,0 +1,20 @@ +# Redshift Connector + +To run the Redshift tests you will need to provision a Redshift cluster. The +tests are designed to run on the smallest possible Redshift cluster, consisting +of a single dc2.large instance. Additionally, you will need an S3 bucket +containing TPCH tiny data in Parquet format. The files should be named: + +``` +s3://<bucket_name>/tpch/tiny/<table_name>.parquet +``` + +To run the tests, set the following system properties: + +``` +test.redshift.jdbc.endpoint=<cluster_id>.<region>.redshift.amazonaws.com:5439/ +test.redshift.jdbc.user=<username> +test.redshift.jdbc.password=<password> +test.redshift.s3.tpch.tables.root=<tpch_tables_s3_root> +test.redshift.iam.role=<iam_role_arn> +```
diff --git a/plugin/trino-redshift/pom.xml b/plugin/trino-redshift/pom.xml index a6270049b7c2..367ae423d4a4 100644 --- a/plugin/trino-redshift/pom.xml +++ b/plugin/trino-redshift/pom.xml @@ -23,12 +23,32 @@ trino-base-jdbc + + io.trino + trino-matching + + + + io.trino + trino-plugin-toolkit + + + + io.airlift + configuration + + com.amazon.redshift redshift-jdbc42 2.1.0.9 + + com.google.guava + guava + + com.google.inject guice @@ -39,10 +59,27 @@ javax.inject + + org.jdbi + jdbi3-core + + - com.google.guava - guava + io.airlift + log + runtime + + + + io.airlift + log-manager + runtime + + + + net.jodah + failsafe runtime @@ -72,16 +109,91 @@ + + io.trino + trino-base-jdbc + test-jar + test + + io.trino trino-main test + + io.trino + trino-main + test-jar + test + + + + io.trino + trino-testing + test + + + + io.trino + trino-testing-services + test + + + + io.trino + trino-tpch + test + + + + io.trino.tpch + tpch + test + + + + io.airlift + testing + test + + + + org.assertj + assertj-core + test + + org.testng testng test + + + + default + + true + + + + + org.apache.maven.plugins + maven-surefire-plugin + + + **/TestRedshiftAutomaticJoinPushdown.java + **/TestRedshiftConnectorTest.java + **/TestRedshiftTableStatisticsReader.java + **/TestRedshiftTypeMapping.java + + + + + + +
diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/ImplementRedshiftAvgBigint.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/ImplementRedshiftAvgBigint.java new file mode 100644 index 000000000000..f9c546105546 --- /dev/null +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/ImplementRedshiftAvgBigint.java @@ -0,0 +1,26 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package io.trino.plugin.redshift; + +import io.trino.plugin.jdbc.aggregation.BaseImplementAvgBigint; + +public class ImplementRedshiftAvgBigint + extends BaseImplementAvgBigint +{ + @Override + protected String getRewriteFormatExpression() + { + return "avg(CAST(%s AS double precision))"; + } +} diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/ImplementRedshiftAvgDecimal.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/ImplementRedshiftAvgDecimal.java new file mode 100644 index 000000000000..103258db12b7 --- /dev/null +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/ImplementRedshiftAvgDecimal.java @@ -0,0 +1,75 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.redshift; + +import io.trino.matching.Capture; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.plugin.base.aggregation.AggregateFunctionRule; +import io.trino.plugin.jdbc.JdbcColumnHandle; +import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.spi.connector.AggregateFunction; +import io.trino.spi.expression.Variable; +import io.trino.spi.type.DecimalType; + +import java.util.Optional; + +import static com.google.common.base.Verify.verify; +import static io.trino.matching.Capture.newCapture; +import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.basicAggregation; +import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.functionName; +import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.singleArgument; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.type; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.variable; +import static io.trino.plugin.redshift.RedshiftClient.REDSHIFT_MAX_DECIMAL_PRECISION; +import static java.lang.String.format; + +public class ImplementRedshiftAvgDecimal + implements AggregateFunctionRule +{ + private static final Capture INPUT = newCapture(); + + @Override + public Pattern getPattern() + { + return basicAggregation() + .with(functionName().equalTo("avg")) + .with(singleArgument().matching( + variable() + .with(type().matching(DecimalType.class::isInstance)) + .capturedAs(INPUT))); + } + + @Override + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + { + Variable input = captures.get(INPUT); + JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName()); + DecimalType type = (DecimalType) columnHandle.getColumnType(); + verify(aggregateFunction.getOutputType().equals(type)); + + // When decimal type has maximum precision we can get result that is not matching Presto avg semantics. 
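+        // At the Redshift maximum precision (38) there is no room to widen the decimal for the rounding done
+        // in the general case below, so the pushed-down aggregation keeps Redshift's truncating behavior.
+        // Illustrative example (assumed values): for a DECIMAL(38, 0) column holding 1 and 2, Redshift's avg
+        // truncates the exact result 1.5 down to 1, while Presto avg semantics would round it up to 2.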
+ if (type.getPrecision() == REDSHIFT_MAX_DECIMAL_PRECISION) { + return Optional.of(new JdbcExpression( + format("avg(CAST(%s AS decimal(%s, %s)))", context.rewriteExpression(input).orElseThrow(), type.getPrecision(), type.getScale()), + columnHandle.getJdbcTypeHandle())); + } + + // Redshift avg function rounds down resulting decimal. + // To match Presto avg semantics, we extend scale by 1 and round result to target scale. + return Optional.of(new JdbcExpression( + format("round(avg(CAST(%s AS decimal(%s, %s))), %s)", context.rewriteExpression(input).orElseThrow(), type.getPrecision() + 1, type.getScale() + 1, type.getScale()), + columnHandle.getJdbcTypeHandle())); + } +} diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java index af0078309e9f..21ee2eabcf02 100644 --- a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java @@ -13,39 +13,105 @@ */ package io.trino.plugin.redshift; +import com.amazon.redshift.jdbc.RedshiftPreparedStatement; +import com.amazon.redshift.util.RedshiftObject; +import com.google.common.base.CharMatcher; +import com.google.common.collect.ImmutableSet; +import io.airlift.slice.Slice; +import io.trino.plugin.base.aggregation.AggregateFunctionRewriter; +import io.trino.plugin.base.aggregation.AggregateFunctionRule; +import io.trino.plugin.base.expression.ConnectorExpressionRewriter; import io.trino.plugin.jdbc.BaseJdbcClient; import io.trino.plugin.jdbc.BaseJdbcConfig; import io.trino.plugin.jdbc.ColumnMapping; import io.trino.plugin.jdbc.ConnectionFactory; import io.trino.plugin.jdbc.JdbcColumnHandle; +import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.JdbcJoinCondition; +import io.trino.plugin.jdbc.JdbcSortItem; +import io.trino.plugin.jdbc.JdbcSplit; +import io.trino.plugin.jdbc.JdbcStatisticsConfig; import io.trino.plugin.jdbc.JdbcTableHandle; import io.trino.plugin.jdbc.JdbcTypeHandle; +import io.trino.plugin.jdbc.LongWriteFunction; +import io.trino.plugin.jdbc.ObjectReadFunction; +import io.trino.plugin.jdbc.ObjectWriteFunction; +import io.trino.plugin.jdbc.PreparedQuery; import io.trino.plugin.jdbc.QueryBuilder; +import io.trino.plugin.jdbc.SliceWriteFunction; +import io.trino.plugin.jdbc.StandardColumnMappings; import io.trino.plugin.jdbc.WriteMapping; +import io.trino.plugin.jdbc.aggregation.ImplementAvgFloatingPoint; +import io.trino.plugin.jdbc.aggregation.ImplementCount; +import io.trino.plugin.jdbc.aggregation.ImplementCountAll; +import io.trino.plugin.jdbc.aggregation.ImplementCountDistinct; +import io.trino.plugin.jdbc.aggregation.ImplementMinMax; +import io.trino.plugin.jdbc.aggregation.ImplementStddevPop; +import io.trino.plugin.jdbc.aggregation.ImplementStddevSamp; +import io.trino.plugin.jdbc.aggregation.ImplementSum; +import io.trino.plugin.jdbc.aggregation.ImplementVariancePop; +import io.trino.plugin.jdbc.aggregation.ImplementVarianceSamp; +import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder; import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.plugin.jdbc.mapping.IdentifierMapping; import io.trino.spi.TrinoException; +import io.trino.spi.connector.AggregateFunction; +import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.connector.JoinCondition; +import 
io.trino.spi.connector.JoinStatistics; +import io.trino.spi.connector.JoinType; +import io.trino.spi.predicate.TupleDomain; +import io.trino.spi.statistics.TableStatistics; import io.trino.spi.type.CharType; +import io.trino.spi.type.Chars; import io.trino.spi.type.DecimalType; import io.trino.spi.type.Decimals; +import io.trino.spi.type.Int128; +import io.trino.spi.type.LongTimestamp; +import io.trino.spi.type.LongTimestampWithTimeZone; +import io.trino.spi.type.TimeType; +import io.trino.spi.type.TimestampType; +import io.trino.spi.type.TimestampWithTimeZoneType; import io.trino.spi.type.Type; import io.trino.spi.type.VarcharType; import javax.inject.Inject; +import java.math.BigDecimal; +import java.math.BigInteger; +import java.math.MathContext; import java.sql.Connection; +import java.sql.DatabaseMetaData; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Types; +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.LocalTime; +import java.time.OffsetDateTime; +import java.time.ZoneOffset; +import java.time.format.DateTimeFormatter; +import java.time.format.DateTimeFormatterBuilder; +import java.util.List; +import java.util.Map; import java.util.Optional; +import java.util.OptionalLong; import java.util.function.BiFunction; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Throwables.throwIfInstanceOf; +import static com.google.common.base.Verify.verify; +import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_ERROR; +import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_NON_TRANSIENT_ERROR; +import static io.trino.plugin.jdbc.JdbcJoinPushdownUtil.implementJoinCostAware; import static io.trino.plugin.jdbc.StandardColumnMappings.bigintColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.bigintWriteFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.booleanColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.booleanWriteFunction; +import static io.trino.plugin.jdbc.StandardColumnMappings.charReadFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.charWriteFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.dateColumnMappingUsingSqlDate; import static io.trino.plugin.jdbc.StandardColumnMappings.dateWriteFunctionUsingSqlDate; @@ -56,6 +122,7 @@ import static io.trino.plugin.jdbc.StandardColumnMappings.doubleWriteFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.integerColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.integerWriteFunction; +import static io.trino.plugin.jdbc.StandardColumnMappings.longDecimalReadFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.longDecimalWriteFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.realColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.realWriteFunction; @@ -67,33 +134,168 @@ import static io.trino.plugin.jdbc.StandardColumnMappings.tinyintColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.tinyintWriteFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.varbinaryColumnMapping; -import static io.trino.plugin.jdbc.StandardColumnMappings.varbinaryWriteFunction; +import static io.trino.plugin.jdbc.StandardColumnMappings.varbinaryReadFunction; +import static io.trino.plugin.jdbc.StandardColumnMappings.varcharColumnMapping; import static 
io.trino.plugin.jdbc.StandardColumnMappings.varcharWriteFunction; import static io.trino.plugin.jdbc.TypeHandlingJdbcSessionProperties.getUnsupportedTypeHandling; import static io.trino.plugin.jdbc.UnsupportedTypeHandling.CONVERT_TO_VARCHAR; +import static io.trino.spi.StandardErrorCode.INVALID_ARGUMENTS; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.spi.type.CharType.createCharType; +import static io.trino.spi.type.DateTimeEncoding.unpackMillisUtc; import static io.trino.spi.type.DateType.DATE; import static io.trino.spi.type.DecimalType.createDecimalType; import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.spi.type.LongTimestampWithTimeZone.fromEpochSecondsAndFraction; import static io.trino.spi.type.RealType.REAL; import static io.trino.spi.type.SmallintType.SMALLINT; +import static io.trino.spi.type.TimeType.TIME_MICROS; +import static io.trino.spi.type.TimeZoneKey.UTC_KEY; +import static io.trino.spi.type.TimestampType.TIMESTAMP_MICROS; import static io.trino.spi.type.TimestampType.TIMESTAMP_MILLIS; +import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_MICROS; +import static io.trino.spi.type.Timestamps.MICROSECONDS_PER_DAY; +import static io.trino.spi.type.Timestamps.MILLISECONDS_PER_SECOND; +import static io.trino.spi.type.Timestamps.NANOSECONDS_PER_MICROSECOND; +import static io.trino.spi.type.Timestamps.NANOSECONDS_PER_MILLISECOND; +import static io.trino.spi.type.Timestamps.PICOSECONDS_PER_MICROSECOND; +import static io.trino.spi.type.Timestamps.PICOSECONDS_PER_NANOSECOND; +import static io.trino.spi.type.Timestamps.roundDiv; import static io.trino.spi.type.TinyintType.TINYINT; import static io.trino.spi.type.VarbinaryType.VARBINARY; +import static io.trino.spi.type.VarcharType.createUnboundedVarcharType; +import static io.trino.spi.type.VarcharType.createVarcharType; +import static java.lang.Math.floorDiv; +import static java.lang.Math.floorMod; import static java.lang.Math.max; +import static java.lang.Math.min; import static java.lang.String.format; +import static java.math.RoundingMode.UNNECESSARY; +import static java.time.temporal.ChronoField.NANO_OF_SECOND; import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.joining; public class RedshiftClient extends BaseJdbcClient { + /** + * Redshift does not handle values larger than 64 bits for + * {@code DECIMAL(19, s)}. It supports the full range of values for all + * other precisions. + * + * @see + * Redshift documentation + */ + private static final int REDSHIFT_DECIMAL_CUTOFF_PRECISION = 19; + + static final int REDSHIFT_MAX_DECIMAL_PRECISION = 38; + + /** + * Maximum size of a {@link BigInteger} storing a Redshift {@code DECIMAL} + * with precision {@link #REDSHIFT_DECIMAL_CUTOFF_PRECISION}. + */ + // actual value is 63 + private static final int REDSHIFT_DECIMAL_CUTOFF_BITS = BigInteger.valueOf(Long.MAX_VALUE).bitLength(); + + /** + * Maximum size of a Redshift CHAR column. + * + * @see + * Redshift documentation + */ + private static final int REDSHIFT_MAX_CHAR = 4096; + + /** + * Maximum size of a Redshift VARCHAR column. 
+ * + * @see + * Redshift documentation + */ + static final int REDSHIFT_MAX_VARCHAR = 65535; + + private static final DateTimeFormatter DATE_FORMATTER = DateTimeFormatter.ofPattern("yyy-MM-dd[ G]"); + private static final DateTimeFormatter DATE_TIME_FORMATTER = new DateTimeFormatterBuilder() + .appendPattern("yyy-MM-dd HH:mm:ss") + .optionalStart() + .appendFraction(NANO_OF_SECOND, 0, 6, true) + .optionalEnd() + .appendPattern("[ G]") + .toFormatter(); + private static final OffsetDateTime REDSHIFT_MIN_SUPPORTED_TIMESTAMP_TZ = OffsetDateTime.of(-4712, 1, 1, 0, 0, 0, 0, ZoneOffset.UTC); + + private final AggregateFunctionRewriter aggregateFunctionRewriter; + private final boolean statisticsEnabled; + private final RedshiftTableStatisticsReader statisticsReader; + @Inject - public RedshiftClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, QueryBuilder queryBuilder, IdentifierMapping identifierMapping, RemoteQueryModifier queryModifier) + public RedshiftClient( + BaseJdbcConfig config, + ConnectionFactory connectionFactory, + JdbcStatisticsConfig statisticsConfig, + QueryBuilder queryBuilder, + IdentifierMapping identifierMapping, + RemoteQueryModifier queryModifier) { super(config, "\"", connectionFactory, queryBuilder, identifierMapping, queryModifier); + ConnectorExpressionRewriter connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder() + .addStandardRules(this::quoted) + .build(); + + JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(Types.BIGINT, Optional.of("bigint"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); + + aggregateFunctionRewriter = new AggregateFunctionRewriter<>( + connectorExpressionRewriter, + ImmutableSet.>builder() + .add(new ImplementCountAll(bigintTypeHandle)) + .add(new ImplementCount(bigintTypeHandle)) + .add(new ImplementCountDistinct(bigintTypeHandle, true)) + .add(new ImplementMinMax(true)) + .add(new ImplementSum(RedshiftClient::toTypeHandle)) + .add(new ImplementAvgFloatingPoint()) + .add(new ImplementRedshiftAvgDecimal()) + .add(new ImplementRedshiftAvgBigint()) + .add(new ImplementStddevSamp()) + .add(new ImplementStddevPop()) + .add(new ImplementVarianceSamp()) + .add(new ImplementVariancePop()) + .build()); + + this.statisticsEnabled = requireNonNull(statisticsConfig, "statisticsConfig is null").isEnabled(); + this.statisticsReader = new RedshiftTableStatisticsReader(connectionFactory); + } + + private static Optional toTypeHandle(DecimalType decimalType) + { + return Optional.of( + new JdbcTypeHandle( + Types.NUMERIC, + Optional.of("decimal"), + Optional.of(decimalType.getPrecision()), + Optional.of(decimalType.getScale()), + Optional.empty(), + Optional.empty())); + } + + @Override + public Connection getConnection(ConnectorSession session, JdbcSplit split, JdbcTableHandle tableHandle) + throws SQLException + { + Connection connection = super.getConnection(session, split, tableHandle); + try { + // super.getConnection sets read-only, since the connection is going to be used only for reads. + // However, for a complex query, Redshift may decide to create some temporary tables behind + // the scenes, and this requires the connection not to be read-only, otherwise Redshift + // may fail with "ERROR: transaction is read-only". 
+ connection.setReadOnly(false); + } + catch (SQLException e) { + connection.close(); + throw e; + } + return connection; } @Override @@ -103,6 +305,87 @@ public Optional getTableComment(ResultSet resultSet) return Optional.empty(); } + @Override + public Optional implementAggregation(ConnectorSession session, AggregateFunction aggregate, Map assignments) + { + return aggregateFunctionRewriter.rewrite(session, aggregate, assignments); + } + + @Override + public TableStatistics getTableStatistics(ConnectorSession session, JdbcTableHandle handle, TupleDomain tupleDomain) + { + if (!statisticsEnabled) { + return TableStatistics.empty(); + } + if (!handle.isNamedRelation()) { + return TableStatistics.empty(); + } + try { + return statisticsReader.readTableStatistics(session, handle, () -> this.getColumns(session, handle)); + } + catch (SQLException | RuntimeException e) { + throwIfInstanceOf(e, TrinoException.class); + throw new TrinoException(JDBC_ERROR, "Failed fetching statistics for table: " + handle, e); + } + } + + @Override + public boolean supportsTopN(ConnectorSession session, JdbcTableHandle handle, List sortOrder) + { + return true; + } + + @Override + protected Optional topNFunction() + { + return Optional.of((query, sortItems, limit) -> { + String orderBy = sortItems.stream() + .map(sortItem -> { + String ordering = sortItem.getSortOrder().isAscending() ? "ASC" : "DESC"; + String nullsHandling = sortItem.getSortOrder().isNullsFirst() ? "NULLS FIRST" : "NULLS LAST"; + return format("%s %s %s", quoted(sortItem.getColumn().getColumnName()), ordering, nullsHandling); + }) + .collect(joining(", ")); + + return format("%s ORDER BY %s LIMIT %d", query, orderBy, limit); + }); + } + + @Override + public boolean isTopNGuaranteed(ConnectorSession session) + { + return true; + } + + @Override + protected boolean isSupportedJoinCondition(ConnectorSession session, JdbcJoinCondition joinCondition) + { + return joinCondition.getOperator() != JoinCondition.Operator.IS_DISTINCT_FROM; + } + + @Override + public Optional implementJoin(ConnectorSession session, + JoinType joinType, + PreparedQuery leftSource, + PreparedQuery rightSource, + List joinConditions, + Map rightAssignments, + Map leftAssignments, + JoinStatistics statistics) + { + if (joinType == JoinType.FULL_OUTER) { + // FULL JOIN is only supported with merge-joinable or hash-joinable join conditions + return Optional.empty(); + } + return implementJoinCostAware( + session, + joinType, + leftSource, + rightSource, + statistics, + () -> super.implementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments, statistics)); + } + @Override protected void renameTable(ConnectorSession session, Connection connection, String catalogName, String remoteSchemaName, String remoteTableName, String newRemoteSchemaName, String newRemoteTableName) throws SQLException @@ -131,7 +414,147 @@ public PreparedStatement getPreparedStatement(Connection connection, String sql) } @Override - public Optional toColumnMapping(ConnectorSession session, Connection connection, JdbcTypeHandle typeHandle) + public OptionalLong delete(ConnectorSession session, JdbcTableHandle handle) + { + checkArgument(handle.isNamedRelation(), "Unable to delete from synthetic table: %s", handle); + checkArgument(handle.getLimit().isEmpty(), "Unable to delete when limit is set: %s", handle); + checkArgument(handle.getSortOrder().isEmpty(), "Unable to delete when sort order is set: %s", handle); + try (Connection connection = 
connectionFactory.openConnection(session)) { + verify(connection.getAutoCommit()); + PreparedQuery preparedQuery = queryBuilder.prepareDeleteQuery(this, session, connection, handle.getRequiredNamedRelation(), handle.getConstraint(), Optional.empty()); + try (PreparedStatement preparedStatement = queryBuilder.prepareStatement(this, session, connection, preparedQuery)) { + int affectedRowsCount = preparedStatement.executeUpdate(); + // connection.getAutoCommit() == true is not enough to make DELETE effective and explicit commit is required + connection.commit(); + return OptionalLong.of(affectedRowsCount); + } + } + catch (SQLException e) { + throw new TrinoException(JDBC_ERROR, e); + } + } + + @Override + protected void verifySchemaName(DatabaseMetaData databaseMetadata, String schemaName) + throws SQLException + { + // Redshift truncates schema name to 127 chars silently + if (schemaName.length() > databaseMetadata.getMaxSchemaNameLength()) { + throw new TrinoException(NOT_SUPPORTED, "Schema name must be shorter than or equal to '%d' characters but got '%d'".formatted(databaseMetadata.getMaxSchemaNameLength(), schemaName.length())); + } + } + + @Override + protected void verifyTableName(DatabaseMetaData databaseMetadata, String tableName) + throws SQLException + { + // Redshift truncates table name to 127 chars silently + if (tableName.length() > databaseMetadata.getMaxTableNameLength()) { + throw new TrinoException(NOT_SUPPORTED, "Table name must be shorter than or equal to '%d' characters but got '%d'".formatted(databaseMetadata.getMaxTableNameLength(), tableName.length())); + } + } + + @Override + protected void verifyColumnName(DatabaseMetaData databaseMetadata, String columnName) + throws SQLException + { + // Redshift truncates table name to 127 chars silently + if (columnName.length() > databaseMetadata.getMaxColumnNameLength()) { + throw new TrinoException(NOT_SUPPORTED, "Column name must be shorter than or equal to '%d' characters but got '%d'".formatted(databaseMetadata.getMaxColumnNameLength(), columnName.length())); + } + } + + @Override + public Optional toColumnMapping(ConnectorSession session, Connection connection, JdbcTypeHandle type) + { + Optional mapping = getForcedMappingToVarchar(type); + if (mapping.isPresent()) { + return mapping; + } + + if ("time".equals(type.getJdbcTypeName().orElse(""))) { + return Optional.of(ColumnMapping.longMapping( + TIME_MICROS, + RedshiftClient::readTime, + RedshiftClient::writeTime)); + } + + switch (type.getJdbcType()) { + case Types.BIT: // Redshift uses this for booleans + return Optional.of(booleanColumnMapping()); + + // case Types.TINYINT: -- Redshift doesn't support tinyint + case Types.SMALLINT: + return Optional.of(smallintColumnMapping()); + case Types.INTEGER: + return Optional.of(integerColumnMapping()); + case Types.BIGINT: + return Optional.of(bigintColumnMapping()); + + case Types.REAL: + return Optional.of(realColumnMapping()); + case Types.DOUBLE: + return Optional.of(doubleColumnMapping()); + + case Types.NUMERIC: { + int precision = type.getRequiredColumnSize(); + int scale = type.getRequiredDecimalDigits(); + DecimalType decimalType = createDecimalType(precision, scale); + if (precision == REDSHIFT_DECIMAL_CUTOFF_PRECISION) { + return Optional.of(ColumnMapping.objectMapping( + decimalType, + longDecimalReadFunction(decimalType), + writeDecimalAtRedshiftCutoff(scale))); + } + return Optional.of(decimalColumnMapping(decimalType, UNNECESSARY)); + } + + case Types.CHAR: + CharType charType = 
createCharType(type.getRequiredColumnSize()); + return Optional.of(ColumnMapping.sliceMapping( + charType, + charReadFunction(charType), + RedshiftClient::writeChar)); + + case Types.VARCHAR: { + int length = type.getRequiredColumnSize(); + return Optional.of(varcharColumnMapping( + length < VarcharType.MAX_LENGTH + ? createVarcharType(length) + : createUnboundedVarcharType(), + true)); + } + + case Types.LONGVARBINARY: + return Optional.of(ColumnMapping.sliceMapping( + VARBINARY, + varbinaryReadFunction(), + varbinaryWriteFunction())); + + case Types.DATE: + return Optional.of(ColumnMapping.longMapping( + DATE, + RedshiftClient::readDate, + RedshiftClient::writeDate)); + + case Types.TIMESTAMP: + return Optional.of(ColumnMapping.longMapping( + TIMESTAMP_MICROS, + RedshiftClient::readTimestamp, + RedshiftClient::writeShortTimestamp)); + + case Types.TIMESTAMP_WITH_TIMEZONE: + return Optional.of(ColumnMapping.objectMapping( + TIMESTAMP_TZ_MICROS, + longTimestampWithTimeZoneReadFunction(), + longTimestampWithTimeZoneWriteFunction())); + } + + // Fall back to default behavior + return legacyToColumnMapping(session, type); + } + + private Optional legacyToColumnMapping(ConnectorSession session, JdbcTypeHandle typeHandle) { Optional mapping = getForcedMappingToVarchar(typeHandle); if (mapping.isPresent()) { @@ -150,6 +573,99 @@ public Optional toColumnMapping(ConnectorSession session, Connect @Override public WriteMapping toWriteMapping(ConnectorSession session, Type type) { + if (BOOLEAN.equals(type)) { + return WriteMapping.booleanMapping("boolean", booleanWriteFunction()); + } + if (TINYINT.equals(type)) { + // Redshift doesn't have tinyint + return WriteMapping.longMapping("smallint", tinyintWriteFunction()); + } + if (SMALLINT.equals(type)) { + return WriteMapping.longMapping("smallint", smallintWriteFunction()); + } + if (INTEGER.equals(type)) { + return WriteMapping.longMapping("integer", integerWriteFunction()); + } + if (BIGINT.equals(type)) { + return WriteMapping.longMapping("bigint", bigintWriteFunction()); + } + if (REAL.equals(type)) { + return WriteMapping.longMapping("real", realWriteFunction()); + } + if (DOUBLE.equals(type)) { + return WriteMapping.doubleMapping("double precision", doubleWriteFunction()); + } + + if (type instanceof DecimalType decimal) { + if (decimal.getPrecision() == REDSHIFT_DECIMAL_CUTOFF_PRECISION) { + // See doc for REDSHIFT_DECIMAL_CUTOFF_PRECISION + return WriteMapping.objectMapping( + format("decimal(%s, %s)", decimal.getPrecision(), decimal.getScale()), + writeDecimalAtRedshiftCutoff(decimal.getScale())); + } + String name = format("decimal(%s, %s)", decimal.getPrecision(), decimal.getScale()); + return decimal.isShort() + ? WriteMapping.longMapping(name, shortDecimalWriteFunction(decimal)) + : WriteMapping.objectMapping(name, longDecimalWriteFunction(decimal)); + } + + if (type instanceof CharType) { + // Redshift has no unbounded text/binary types, so if a CHAR is too + // large for Redshift, we write as VARCHAR. If too large for that, + // we use the largest VARCHAR Redshift supports. 
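+            // For example, given REDSHIFT_MAX_CHAR = 4096 and REDSHIFT_MAX_VARCHAR = 65535:
+            //   char(100)   is written as char(100)
+            //   char(5000)  is written as varchar(5000), with values space-padded to length 5000
+            //   char(65536) is written as varchar(65535)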
+ int size = ((CharType) type).getLength(); + if (size <= REDSHIFT_MAX_CHAR) { + return WriteMapping.sliceMapping( + format("char(%d)", size), + RedshiftClient::writeChar); + } + int redshiftVarcharWidth = min(size, REDSHIFT_MAX_VARCHAR); + return WriteMapping.sliceMapping( + format("varchar(%d)", redshiftVarcharWidth), + (statement, index, value) -> writeCharAsVarchar(statement, index, value, redshiftVarcharWidth)); + } + + if (type instanceof VarcharType) { + // Redshift has no unbounded text/binary types, so if a VARCHAR is + // larger than Redshift's limit, we make it that big instead. + int size = ((VarcharType) type).getLength() + .filter(n -> n <= REDSHIFT_MAX_VARCHAR) + .orElse(REDSHIFT_MAX_VARCHAR); + return WriteMapping.sliceMapping(format("varchar(%d)", size), varcharWriteFunction()); + } + + if (VARBINARY.equals(type)) { + return WriteMapping.sliceMapping("varbyte", varbinaryWriteFunction()); + } + + if (DATE.equals(type)) { + return WriteMapping.longMapping("date", RedshiftClient::writeDate); + } + + if (type instanceof TimeType) { + return WriteMapping.longMapping("time", RedshiftClient::writeTime); + } + + if (type instanceof TimestampType) { + if (((TimestampType) type).isShort()) { + return WriteMapping.longMapping( + "timestamp", + RedshiftClient::writeShortTimestamp); + } + return WriteMapping.objectMapping( + "timestamp", + LongTimestamp.class, + RedshiftClient::writeLongTimestamp); + } + + if (type instanceof TimestampWithTimeZoneType timestampWithTimeZoneType) { + if (timestampWithTimeZoneType.getPrecision() <= TimestampWithTimeZoneType.MAX_SHORT_PRECISION) { + return WriteMapping.longMapping("timestamptz", shortTimestampWithTimeZoneWriteFunction()); + } + return WriteMapping.objectMapping("timestamptz", longTimestampWithTimeZoneWriteFunction()); + } + + // Fall back to legacy behavior return legacyToWriteMapping(type); } @@ -183,9 +699,168 @@ private static String redshiftVarcharLiteral(String value) return "'" + value.replace("'", "''").replace("\\", "\\\\") + "'"; } + private static ObjectReadFunction longTimestampWithTimeZoneReadFunction() + { + return ObjectReadFunction.of( + LongTimestampWithTimeZone.class, + (resultSet, columnIndex) -> { + // Redshift does not store zone information in "timestamp with time zone" data type + OffsetDateTime offsetDateTime = resultSet.getObject(columnIndex, OffsetDateTime.class); + return fromEpochSecondsAndFraction( + offsetDateTime.toEpochSecond(), + (long) offsetDateTime.getNano() * PICOSECONDS_PER_NANOSECOND, + UTC_KEY); + }); + } + + private static LongWriteFunction shortTimestampWithTimeZoneWriteFunction() + { + return (statement, index, value) -> { + // Redshift does not store zone information in "timestamp with time zone" data type + long millisUtc = unpackMillisUtc(value); + long epochSeconds = floorDiv(millisUtc, MILLISECONDS_PER_SECOND); + int nanosOfSecond = floorMod(millisUtc, MILLISECONDS_PER_SECOND) * NANOSECONDS_PER_MILLISECOND; + OffsetDateTime offsetDateTime = OffsetDateTime.ofInstant(Instant.ofEpochSecond(epochSeconds, nanosOfSecond), UTC_KEY.getZoneId()); + verifySupportedTimestampWithTimeZone(offsetDateTime); + statement.setObject(index, offsetDateTime); + }; + } + + private static ObjectWriteFunction longTimestampWithTimeZoneWriteFunction() + { + return ObjectWriteFunction.of( + LongTimestampWithTimeZone.class, + (statement, index, value) -> { + // Redshift does not store zone information in "timestamp with time zone" data type + long epochSeconds = floorDiv(value.getEpochMillis(), MILLISECONDS_PER_SECOND); + 
long nanosOfSecond = ((long) floorMod(value.getEpochMillis(), MILLISECONDS_PER_SECOND) * NANOSECONDS_PER_MILLISECOND) + + (value.getPicosOfMilli() / PICOSECONDS_PER_NANOSECOND); + OffsetDateTime offsetDateTime = OffsetDateTime.ofInstant(Instant.ofEpochSecond(epochSeconds, nanosOfSecond), UTC_KEY.getZoneId()); + verifySupportedTimestampWithTimeZone(offsetDateTime); + statement.setObject(index, offsetDateTime); + }); + } + + private static void verifySupportedTimestampWithTimeZone(OffsetDateTime value) + { + if (value.isBefore(REDSHIFT_MIN_SUPPORTED_TIMESTAMP_TZ)) { + DateTimeFormatter format = DateTimeFormatter.ofPattern("uuuu-MM-dd HH:mm:ss.SSSSSS"); + throw new TrinoException( + INVALID_ARGUMENTS, + format("Minimum timestamp with time zone in Redshift is %s: %s", REDSHIFT_MIN_SUPPORTED_TIMESTAMP_TZ.format(format), value.format(format))); + } + } + + /** + * Decimal write function for precision {@link #REDSHIFT_DECIMAL_CUTOFF_PRECISION}. + * Ensures that values fit in 8 bytes. + */ + private static ObjectWriteFunction writeDecimalAtRedshiftCutoff(int scale) + { + return ObjectWriteFunction.of( + Int128.class, + (statement, index, decimal) -> { + BigInteger unscaled = decimal.toBigInteger(); + if (unscaled.bitLength() > REDSHIFT_DECIMAL_CUTOFF_BITS) { + throw new TrinoException(JDBC_NON_TRANSIENT_ERROR, format( + "Value out of range for Redshift DECIMAL(%d, %d)", + REDSHIFT_DECIMAL_CUTOFF_PRECISION, + scale)); + } + MathContext precision = new MathContext(REDSHIFT_DECIMAL_CUTOFF_PRECISION); + statement.setBigDecimal(index, new BigDecimal(unscaled, scale, precision)); + }); + } + + /** + * Like {@link StandardColumnMappings#charWriteFunction}, but restrict to + * ASCII because Redshift only allows ASCII in {@code CHAR} values. + */ + private static void writeChar(PreparedStatement statement, int index, Slice slice) + throws SQLException + { + String value = slice.toStringUtf8(); + if (!CharMatcher.ascii().matchesAllOf(value)) { + throw new TrinoException( + JDBC_NON_TRANSIENT_ERROR, + format("Value for Redshift CHAR must be ASCII, but found '%s'", value)); + } + statement.setString(index, slice.toStringAscii()); + } + + /** + * Like {@link StandardColumnMappings#charWriteFunction}, but pads + * the value with spaces to simulate {@code CHAR} semantics. + */ + private static void writeCharAsVarchar(PreparedStatement statement, int index, Slice slice, int columnLength) + throws SQLException + { + // Redshift counts varchar size limits in UTF-8 bytes, so this may make the string longer than + // the limit, but Redshift also truncates extra trailing spaces, so that doesn't cause any problems. 
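+        // For example, a char(5000) value whose only non-ASCII character is 'é' pads to 5000 characters,
+        // or 5001 UTF-8 bytes; the byte over the varchar(5000) limit is a trailing space that Redshift trims.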
+ statement.setString(index, Chars.padSpaces(slice, columnLength).toStringUtf8()); + } + + private static void writeDate(PreparedStatement statement, int index, long day) + throws SQLException + { + statement.setObject(index, new RedshiftObject("date", DATE_FORMATTER.format(LocalDate.ofEpochDay(day)))); + } + + private static long readDate(ResultSet results, int index) + throws SQLException + { + // Reading date as string to workaround issues around julian->gregorian calendar switch + return LocalDate.parse(results.getString(index), DATE_FORMATTER).toEpochDay(); + } + + /** + * Write time with microsecond precision + */ + private static void writeTime(PreparedStatement statement, int index, long picos) + throws SQLException + { + statement.setObject(index, LocalTime.ofNanoOfDay((roundDiv(picos, PICOSECONDS_PER_MICROSECOND) % MICROSECONDS_PER_DAY) * NANOSECONDS_PER_MICROSECOND)); + } + + /** + * Read a time value with microsecond precision + */ + private static long readTime(ResultSet results, int index) + throws SQLException + { + return results.getObject(index, LocalTime.class).toNanoOfDay() * PICOSECONDS_PER_NANOSECOND; + } + + private static void writeShortTimestamp(PreparedStatement statement, int index, long epochMicros) + throws SQLException + { + statement.setObject(index, new RedshiftObject("timestamp", DATE_TIME_FORMATTER.format(StandardColumnMappings.fromTrinoTimestamp(epochMicros)))); + } + + private static void writeLongTimestamp(PreparedStatement statement, int index, Object value) + throws SQLException + { + LongTimestamp timestamp = (LongTimestamp) value; + long epochMicros = timestamp.getEpochMicros(); + if (timestamp.getPicosOfMicro() >= PICOSECONDS_PER_MICROSECOND / 2) { + epochMicros += 1; // Add one micro if picos round up + } + statement.setObject(index, new RedshiftObject("timestamp", DATE_TIME_FORMATTER.format(StandardColumnMappings.fromTrinoTimestamp(epochMicros)))); + } + + private static long readTimestamp(ResultSet results, int index) + throws SQLException + { + return StandardColumnMappings.toTrinoTimestamp(TIMESTAMP_MICROS, results.getObject(index, LocalDateTime.class)); + } + + private static SliceWriteFunction varbinaryWriteFunction() + { + return (statement, index, value) -> statement.unwrap(RedshiftPreparedStatement.class).setVarbyte(index, value.getBytes()); + } + private static Optional legacyDefaultColumnMapping(JdbcTypeHandle typeHandle) { - // TODO (https://github.com/trinodb/trino/issues/497) Implement proper type mapping and add test // This method is copied from deprecated StandardColumnMappings.legacyDefaultColumnMapping() switch (typeHandle.getJdbcType()) { case Types.BIT: @@ -251,7 +926,6 @@ private static Optional legacyDefaultColumnMapping(JdbcTypeHandle private static WriteMapping legacyToWriteMapping(Type type) { - // TODO (https://github.com/trinodb/trino/issues/497) Implement proper type mapping and add test // This method is copied from deprecated BaseJdbcClient.legacyToWriteMapping() if (type == BOOLEAN) { return WriteMapping.booleanMapping("boolean", booleanWriteFunction()); diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClientModule.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClientModule.java index 53e1aee6ac29..aeffaac16ff7 100644 --- a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClientModule.java +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClientModule.java @@ -15,29 +15,39 @@ import com.amazon.redshift.Driver; 
import com.google.inject.Binder; -import com.google.inject.Module; import com.google.inject.Provides; -import com.google.inject.Scopes; import com.google.inject.Singleton; +import io.airlift.configuration.AbstractConfigurationAwareModule; import io.trino.plugin.jdbc.BaseJdbcConfig; import io.trino.plugin.jdbc.ConnectionFactory; +import io.trino.plugin.jdbc.DecimalModule; import io.trino.plugin.jdbc.DriverConnectionFactory; import io.trino.plugin.jdbc.ForBaseJdbc; import io.trino.plugin.jdbc.JdbcClient; +import io.trino.plugin.jdbc.JdbcJoinPushdownSupportModule; +import io.trino.plugin.jdbc.JdbcStatisticsConfig; import io.trino.plugin.jdbc.credential.CredentialProvider; import io.trino.plugin.jdbc.ptf.Query; import io.trino.spi.ptf.ConnectorTableFunction; +import java.util.Properties; + +import static com.google.inject.Scopes.SINGLETON; import static com.google.inject.multibindings.Multibinder.newSetBinder; +import static io.airlift.configuration.ConfigBinder.configBinder; public class RedshiftClientModule - implements Module + extends AbstractConfigurationAwareModule { @Override - public void configure(Binder binder) + public void setup(Binder binder) { - binder.bind(JdbcClient.class).annotatedWith(ForBaseJdbc.class).to(RedshiftClient.class).in(Scopes.SINGLETON); - newSetBinder(binder, ConnectorTableFunction.class).addBinding().toProvider(Query.class).in(Scopes.SINGLETON); + binder.bind(JdbcClient.class).annotatedWith(ForBaseJdbc.class).to(RedshiftClient.class).in(SINGLETON); + newSetBinder(binder, ConnectorTableFunction.class).addBinding().toProvider(Query.class).in(SINGLETON); + configBinder(binder).bindConfig(JdbcStatisticsConfig.class); + + install(new DecimalModule()); + install(new JdbcJoinPushdownSupportModule()); } @Singleton @@ -45,6 +55,14 @@ public void configure(Binder binder) @ForBaseJdbc public static ConnectionFactory getConnectionFactory(BaseJdbcConfig config, CredentialProvider credentialProvider) { - return new DriverConnectionFactory(new Driver(), config, credentialProvider); + return new DriverConnectionFactory(new Driver(), config.getConnectionUrl(), getDriverProperties(), credentialProvider); + } + + private static Properties getDriverProperties() + { + Properties properties = new Properties(); + properties.put("reWriteBatchedInserts", "true"); + properties.put("reWriteBatchedInsertsSize", "512"); + return properties; } } diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftTableStatisticsReader.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftTableStatisticsReader.java new file mode 100644 index 000000000000..c576abdd109d --- /dev/null +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftTableStatisticsReader.java @@ -0,0 +1,176 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.redshift; + +import io.trino.plugin.jdbc.ConnectionFactory; +import io.trino.plugin.jdbc.JdbcColumnHandle; +import io.trino.plugin.jdbc.JdbcTableHandle; +import io.trino.plugin.jdbc.RemoteTableName; +import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.statistics.ColumnStatistics; +import io.trino.spi.statistics.Estimate; +import io.trino.spi.statistics.TableStatistics; +import org.jdbi.v3.core.Handle; +import org.jdbi.v3.core.Jdbi; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Supplier; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static java.util.Objects.requireNonNull; +import static java.util.function.Function.identity; + +public class RedshiftTableStatisticsReader +{ + private final ConnectionFactory connectionFactory; + + public RedshiftTableStatisticsReader(ConnectionFactory connectionFactory) + { + this.connectionFactory = requireNonNull(connectionFactory, "connectionFactory is null"); + } + + public TableStatistics readTableStatistics(ConnectorSession session, JdbcTableHandle table, Supplier> columnSupplier) + throws SQLException + { + checkArgument(table.isNamedRelation(), "Relation is not a table: %s", table); + + try (Connection connection = connectionFactory.openConnection(session); + Handle handle = Jdbi.open(connection)) { + StatisticsDao statisticsDao = new StatisticsDao(handle); + + RemoteTableName remoteTableName = table.getRequiredNamedRelation().getRemoteTableName(); + Optional optionalRowCount = readRowCountTableStat(statisticsDao, table); + if (optionalRowCount.isEmpty()) { + // Table not found + return TableStatistics.empty(); + } + long rowCount = optionalRowCount.get(); + + TableStatistics.Builder tableStatistics = TableStatistics.builder() + .setRowCount(Estimate.of(rowCount)); + + if (rowCount == 0) { + return tableStatistics.build(); + } + + Map columnStatistics = statisticsDao.getColumnStatistics(remoteTableName.getSchemaName().orElse(null), remoteTableName.getTableName()).stream() + .collect(toImmutableMap(ColumnStatisticsResult::columnName, identity())); + + for (JdbcColumnHandle column : columnSupplier.get()) { + ColumnStatisticsResult result = columnStatistics.get(column.getColumnName()); + if (result == null) { + continue; + } + + ColumnStatistics statistics = ColumnStatistics.builder() + .setNullsFraction(result.nullsFraction() + .map(Estimate::of) + .orElseGet(Estimate::unknown)) + .setDistinctValuesCount(result.distinctValuesIndicator() + .map(distinctValuesIndicator -> { + // If the distinct value count is an estimate Redshift uses "the negative of the number of distinct values divided by the number of rows + // For example, -1 indicates a unique column in which the number of distinct values is the same as the number of rows." 
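+                                // In that case the indicator is scaled back up by the row count below, e.g. (assumed
+                                // values) an indicator of -0.2 with rowCount = 1000 yields an estimate of 200 distinct values.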
+ // https://www.postgresql.org/docs/9.3/view-pg-stats.html + if (distinctValuesIndicator < 0.0) { + return Math.min(-distinctValuesIndicator * rowCount, rowCount); + } + return distinctValuesIndicator; + }) + .map(Estimate::of) + .orElseGet(Estimate::unknown)) + .setDataSize(result.averageColumnLength() + .flatMap(averageColumnLength -> + result.nullsFraction() + .map(nullsFraction -> 1.0 * averageColumnLength * rowCount * (1 - nullsFraction)) + .map(Estimate::of)) + .orElseGet(Estimate::unknown)) + .build(); + + tableStatistics.setColumnStatistics(column, statistics); + } + + return tableStatistics.build(); + } + } + + private static Optional readRowCountTableStat(StatisticsDao statisticsDao, JdbcTableHandle table) + { + RemoteTableName remoteTableName = table.getRequiredNamedRelation().getRemoteTableName(); + Optional rowCount = statisticsDao.getRowCountFromPgClass(remoteTableName.getSchemaName().orElse(null), remoteTableName.getTableName()); + if (rowCount.isEmpty()) { + // Table not found + return Optional.empty(); + } + + if (rowCount.get() == 0) { + // `pg_class.reltuples = 0` may mean an empty table or a recently populated table (CTAS, LOAD or INSERT) + // The `pg_stat_all_tables` view can be way off, so we use it only as a fallback + rowCount = statisticsDao.getRowCountFromPgStat(remoteTableName.getSchemaName().orElse(null), remoteTableName.getTableName()); + } + + return rowCount; + } + + private static class StatisticsDao + { + private final Handle handle; + + public StatisticsDao(Handle handle) + { + this.handle = requireNonNull(handle, "handle is null"); + } + + Optional getRowCountFromPgClass(String schema, String tableName) + { + return handle.createQuery("SELECT reltuples FROM pg_class WHERE relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = :schema) AND relname = :table_name") + .bind("schema", schema) + .bind("table_name", tableName) + .mapTo(Long.class) + .findOne(); + } + + Optional getRowCountFromPgStat(String schema, String tableName) + { + // Redshift does not have the Postgres `n_live_tup`, so estimate from `inserts - deletes` + return handle.createQuery("SELECT n_tup_ins - n_tup_del FROM pg_stat_all_tables WHERE schemaname = :schema AND relname = :table_name") + .bind("schema", schema) + .bind("table_name", tableName) + .mapTo(Long.class) + .findOne(); + } + + List getColumnStatistics(String schema, String tableName) + { + return handle.createQuery("SELECT attname, null_frac, n_distinct, avg_width FROM pg_stats WHERE schemaname = :schema AND tablename = :table_name") + .bind("schema", schema) + .bind("table_name", tableName) + .map((rs, ctx) -> + new ColumnStatisticsResult( + requireNonNull(rs.getString("attname"), "attname is null"), + Optional.of(rs.getFloat("null_frac")), + Optional.of(rs.getFloat("n_distinct")), + Optional.of(rs.getInt("avg_width")))) + .list(); + } + } + + // TODO remove when error prone is updated for Java 17 records + @SuppressWarnings("unused") + private record ColumnStatisticsResult(String columnName, Optional nullsFraction, Optional distinctValuesIndicator, Optional averageColumnLength) {} +} diff --git a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/RedshiftQueryRunner.java b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/RedshiftQueryRunner.java new file mode 100644 index 000000000000..3e96738e7ba1 --- /dev/null +++ b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/RedshiftQueryRunner.java @@ -0,0 +1,271 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you 
may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.redshift; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Streams; +import io.airlift.log.Logger; +import io.airlift.log.Logging; +import io.trino.Session; +import io.trino.metadata.QualifiedObjectName; +import io.trino.plugin.tpch.TpchPlugin; +import io.trino.spi.security.Identity; +import io.trino.testing.DistributedQueryRunner; +import io.trino.testing.MaterializedResult; +import io.trino.testing.QueryRunner; +import io.trino.tpch.TpchTable; +import net.jodah.failsafe.Failsafe; +import net.jodah.failsafe.RetryPolicy; +import org.jdbi.v3.core.HandleCallback; +import org.jdbi.v3.core.HandleConsumer; +import org.jdbi.v3.core.Jdbi; + +import java.time.Duration; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +import static io.airlift.testing.Closeables.closeAllSuppress; +import static io.trino.plugin.tpch.TpchMetadata.TINY_SCHEMA_NAME; +import static io.trino.testing.QueryAssertions.copyTable; +import static io.trino.testing.TestingSession.testSessionBuilder; +import static io.trino.testing.assertions.Assert.assertEquals; +import static java.lang.String.format; +import static java.util.Locale.ENGLISH; +import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.toUnmodifiableSet; + +public final class RedshiftQueryRunner +{ + private static final Logger log = Logger.get(RedshiftQueryRunner.class); + private static final String JDBC_ENDPOINT = requireSystemProperty("test.redshift.jdbc.endpoint"); + static final String JDBC_USER = requireSystemProperty("test.redshift.jdbc.user"); + static final String JDBC_PASSWORD = requireSystemProperty("test.redshift.jdbc.password"); + private static final String S3_TPCH_TABLES_ROOT = requireSystemProperty("test.redshift.s3.tpch.tables.root"); + private static final String IAM_ROLE = requireSystemProperty("test.redshift.iam.role"); + + private static final String TEST_DATABASE = "testdb"; + private static final String TEST_CATALOG = "redshift"; + static final String TEST_SCHEMA = "test_schema"; + + static final String JDBC_URL = "jdbc:redshift://" + JDBC_ENDPOINT + TEST_DATABASE; + + private static final String CONNECTOR_NAME = "redshift"; + private static final String TPCH_CATALOG = "tpch"; + + private static final String GRANTED_USER = "alice"; + private static final String NON_GRANTED_USER = "bob"; + + private RedshiftQueryRunner() {} + + public static DistributedQueryRunner createRedshiftQueryRunner( + Map extraProperties, + Map connectorProperties, + Iterable> tables) + throws Exception + { + return createRedshiftQueryRunner( + createSession(), + extraProperties, + connectorProperties, + tables); + } + + public static DistributedQueryRunner createRedshiftQueryRunner( + Session session, + Map extraProperties, + Map connectorProperties, + Iterable> tables) + throws Exception + { + DistributedQueryRunner.Builder builder = DistributedQueryRunner.builder(session); + 
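        // extraProperties override coordinator configuration; for example, main() below passes
+        // "http-server.http.port" = "8080" so a locally started query runner listens on a fixed port.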
extraProperties.forEach(builder::addExtraProperty); + DistributedQueryRunner runner = builder.build(); + try { + runner.installPlugin(new TpchPlugin()); + runner.createCatalog(TPCH_CATALOG, "tpch", Map.of()); + + Map properties = new HashMap<>(connectorProperties); + properties.putIfAbsent("connection-url", JDBC_URL); + properties.putIfAbsent("connection-user", JDBC_USER); + properties.putIfAbsent("connection-password", JDBC_PASSWORD); + + runner.installPlugin(new RedshiftPlugin()); + runner.createCatalog(TEST_CATALOG, CONNECTOR_NAME, properties); + + executeInRedshift("CREATE SCHEMA IF NOT EXISTS " + TEST_SCHEMA); + createUserIfNotExists(NON_GRANTED_USER, JDBC_PASSWORD); + createUserIfNotExists(GRANTED_USER, JDBC_PASSWORD); + + executeInRedshiftWithRetry(format("GRANT ALL PRIVILEGES ON DATABASE %s TO %s", TEST_DATABASE, GRANTED_USER)); + executeInRedshiftWithRetry(format("GRANT ALL PRIVILEGES ON SCHEMA %s TO %s", TEST_SCHEMA, GRANTED_USER)); + + provisionTables(session, runner, tables); + + // This step is necessary for product tests + executeInRedshiftWithRetry(format("GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA %s TO %s", TEST_SCHEMA, GRANTED_USER)); + } + catch (Throwable e) { + closeAllSuppress(e, runner); + throw e; + } + return runner; + } + + private static Session createSession() + { + return createSession(GRANTED_USER); + } + + private static Session createSession(String user) + { + return testSessionBuilder() + .setCatalog(TEST_CATALOG) + .setSchema(TEST_SCHEMA) + .setIdentity(Identity.ofUser(user)) + .build(); + } + + private static void createUserIfNotExists(String user, String password) + { + try { + executeInRedshift("CREATE USER " + user + " PASSWORD " + "'" + password + "'"); + } + catch (Exception e) { + // if user already exists, swallow the exception + if (!e.getMessage().matches(".*user \"" + user + "\" already exists.*")) { + throw e; + } + } + } + + private static void executeInRedshiftWithRetry(String sql) + { + Failsafe.with(new RetryPolicy<>() + .handleIf(e -> e.getMessage().matches(".* concurrent transaction .*")) + .withDelay(Duration.ofSeconds(10)) + .withMaxRetries(3)) + .run(() -> executeInRedshift(sql)); + } + + public static void executeInRedshift(String sql, Object... 
parameters) + { + executeInRedshift(handle -> handle.execute(sql, parameters)); + } + + public static void executeInRedshift(HandleConsumer consumer) + throws E + { + executeWithRedshift(consumer.asCallback()); + } + + public static T executeWithRedshift(HandleCallback callback) + throws E + { + return Jdbi.create(JDBC_URL, JDBC_USER, JDBC_PASSWORD).withHandle(callback); + } + + private static synchronized void provisionTables(Session session, QueryRunner queryRunner, Iterable> tables) + { + Set existingTables = queryRunner.listTables(session, session.getCatalog().orElseThrow(), session.getSchema().orElseThrow()) + .stream() + .map(QualifiedObjectName::getObjectName) + .collect(toUnmodifiableSet()); + + Streams.stream(tables) + .map(table -> table.getTableName().toLowerCase(ENGLISH)) + .filter(name -> !existingTables.contains(name)) + .forEach(name -> copyFromS3(queryRunner, session, name)); + + for (TpchTable tpchTable : tables) { + verifyLoadedDataHasSameSchema(session, queryRunner, tpchTable); + } + } + + private static void copyFromS3(QueryRunner queryRunner, Session session, String name) + { + String s3Path = format("%s/%s/%s.parquet", S3_TPCH_TABLES_ROOT, TPCH_CATALOG, name); + log.info("Creating table %s in Redshift copying from %s", name, s3Path); + + // Create table in ephemeral Redshift cluster with no data + String createTableSql = format("CREATE TABLE %s.%s.%s AS ", session.getCatalog().orElseThrow(), session.getSchema().orElseThrow(), name) + + format("SELECT * FROM %s.%s.%s WITH NO DATA", TPCH_CATALOG, TINY_SCHEMA_NAME, name); + queryRunner.execute(session, createTableSql); + + // Copy data from S3 bucket to ephemeral Redshift + String copySql = "COPY " + TEST_SCHEMA + "." + name + + " FROM '" + s3Path + "'" + + " IAM_ROLE '" + IAM_ROLE + "'" + + " FORMAT PARQUET"; + executeInRedshiftWithRetry(copySql); + } + + private static void copyFromTpchCatalog(QueryRunner queryRunner, Session session, String name) + { + // This function exists in case we need to copy data from the TPCH catalog rather than S3, + // such as moving to a new AWS account or if the schema changes. We can swap this method out for + // copyFromS3 in provisionTables and then export the data again to S3. + copyTable(queryRunner, TPCH_CATALOG, TINY_SCHEMA_NAME, name, session); + } + + private static void verifyLoadedDataHasSameSchema(Session session, QueryRunner queryRunner, TpchTable tpchTable) + { + // We want to verify that the loaded data has the same schema as if we created a fresh table from the TPC-H catalog + // If this assertion fails, we may need to recreate the Redshift tables from the TPC-H catalog and unload the data to S3 + try { + long expectedCount = (long) queryRunner.execute("SELECT count(*) FROM " + format("%s.%s.%s", TPCH_CATALOG, TINY_SCHEMA_NAME, tpchTable.getTableName())).getOnlyValue(); + long actualCount = (long) queryRunner.execute( + "SELECT count(*) FROM " + format( + "%s.%s.%s", + session.getCatalog().orElseThrow(), + session.getSchema().orElseThrow(), + tpchTable.getTableName())).getOnlyValue(); + + if (expectedCount != actualCount) { + throw new RuntimeException(format("Table %s is not loaded correctly. 
Expected %s rows got %s", tpchTable.getTableName(), expectedCount, actualCount)); + } + + log.info("Checking column types on table %s", tpchTable.getTableName()); + MaterializedResult expectedColumns = queryRunner.execute(format("DESCRIBE %s.%s.%s", TPCH_CATALOG, TINY_SCHEMA_NAME, tpchTable.getTableName())); + MaterializedResult actualColumns = queryRunner.execute("DESCRIBE " + tpchTable.getTableName()); + assertEquals(actualColumns, expectedColumns); + } + catch (Exception e) { + throw new RuntimeException("Failed to assert columns for TPC-H table " + tpchTable.getTableName(), e); + } + } + + /** + * Get the named system property, throwing an exception if it is not set. + */ + private static String requireSystemProperty(String property) + { + return requireNonNull(System.getProperty(property), property + " is not set"); + } + + public static void main(String[] args) + throws Exception + { + Logging.initialize(); + + DistributedQueryRunner queryRunner = createRedshiftQueryRunner( + ImmutableMap.of("http-server.http.port", "8080"), + ImmutableMap.of(), + ImmutableList.of()); + + log.info("======== SERVER STARTED ========"); + log.info("\n====\n%s\n====", queryRunner.getCoordinator().getBaseUrl()); + } +} diff --git a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftAutomaticJoinPushdown.java b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftAutomaticJoinPushdown.java new file mode 100644 index 000000000000..3509f8dd8b9c --- /dev/null +++ b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftAutomaticJoinPushdown.java @@ -0,0 +1,72 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.redshift; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.plugin.jdbc.BaseAutomaticJoinPushdownTest; +import io.trino.testing.QueryRunner; +import org.testng.SkipException; + +import static io.trino.plugin.redshift.RedshiftQueryRunner.TEST_SCHEMA; +import static io.trino.plugin.redshift.RedshiftQueryRunner.createRedshiftQueryRunner; +import static io.trino.plugin.redshift.RedshiftQueryRunner.executeInRedshift; +import static java.lang.String.format; +import static java.util.Locale.ENGLISH; + +public class TestRedshiftAutomaticJoinPushdown + extends BaseAutomaticJoinPushdownTest +{ + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + return createRedshiftQueryRunner( + ImmutableMap.of(), + ImmutableMap.of(), + ImmutableList.of()); + } + + @Override + public void testJoinPushdownWithEmptyStatsInitially() + { + throw new SkipException("Redshift table statistics are automatically populated"); + } + + @Override + protected void gatherStats(String tableName) + { + executeInRedshift(handle -> { + handle.execute(format("ANALYZE VERBOSE %s.%s", TEST_SCHEMA, tableName)); + for (int i = 0; i < 5; i++) { + long actualCount = handle.createQuery(format("SELECT count(*) FROM %s.%s", TEST_SCHEMA, tableName)) + .mapTo(Long.class) + .one(); + long estimatedCount = handle.createQuery( + "SELECT reltuples FROM pg_class " + + "WHERE relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = :schema) " + + "AND relname = :table_name") + .bind("schema", TEST_SCHEMA) + .bind("table_name", tableName.toLowerCase(ENGLISH).replace("\"", "")) + .mapTo(Long.class) + .one(); + if (actualCount == estimatedCount) { + return; + } + handle.execute(format("ANALYZE VERBOSE %s.%s", TEST_SCHEMA, tableName)); + } + throw new IllegalStateException("Stats not gathered"); + }); + } +} diff --git a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java new file mode 100644 index 000000000000..1b16e335bd82 --- /dev/null +++ b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java @@ -0,0 +1,639 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.redshift; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.Session; +import io.trino.plugin.jdbc.BaseJdbcConnectorTest; +import io.trino.testing.QueryRunner; +import io.trino.testing.TestingConnectorBehavior; +import io.trino.testing.sql.SqlExecutor; +import io.trino.testing.sql.TestTable; +import io.trino.tpch.TpchTable; +import org.testng.SkipException; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.Optional; +import java.util.OptionalInt; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.plugin.redshift.RedshiftQueryRunner.TEST_SCHEMA; +import static io.trino.plugin.redshift.RedshiftQueryRunner.createRedshiftQueryRunner; +import static io.trino.plugin.redshift.RedshiftQueryRunner.executeInRedshift; +import static io.trino.plugin.redshift.RedshiftQueryRunner.executeWithRedshift; +import static io.trino.testing.TestingNames.randomNameSuffix; +import static java.lang.String.format; +import static java.util.Locale.ENGLISH; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class TestRedshiftConnectorTest + extends BaseJdbcConnectorTest +{ + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + return createRedshiftQueryRunner( + ImmutableMap.of(), + ImmutableMap.of(), + // NOTE this can cause tests to time-out if larger tables like + // lineitem and orders need to be re-created. + TpchTable.getTables()); + } + + @Override + @SuppressWarnings("DuplicateBranchesInSwitch") // options here are grouped per-feature + protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) + { + switch (connectorBehavior) { + case SUPPORTS_COMMENT_ON_TABLE: + case SUPPORTS_ADD_COLUMN_WITH_COMMENT: + case SUPPORTS_CREATE_TABLE_WITH_TABLE_COMMENT: + case SUPPORTS_CREATE_TABLE_WITH_COLUMN_COMMENT: + return false; + + case SUPPORTS_ARRAY: + case SUPPORTS_ROW_TYPE: + return false; + + case SUPPORTS_RENAME_TABLE_ACROSS_SCHEMAS: + return false; + + case SUPPORTS_AGGREGATION_PUSHDOWN_STDDEV: + case SUPPORTS_AGGREGATION_PUSHDOWN_VARIANCE: + case SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT: + return true; + + case SUPPORTS_JOIN_PUSHDOWN: + case SUPPORTS_JOIN_PUSHDOWN_WITH_VARCHAR_EQUALITY: + return true; + case SUPPORTS_JOIN_PUSHDOWN_WITH_DISTINCT_FROM: + case SUPPORTS_JOIN_PUSHDOWN_WITH_FULL_JOIN: + return false; + + default: + return super.hasBehavior(connectorBehavior); + } + } + + @Override + protected TestTable createTableWithDefaultColumns() + { + return new TestTable( + onRemoteDatabase(), + format("%s.test_table_with_default_columns", TEST_SCHEMA), + "(col_required BIGINT NOT NULL," + + "col_nullable BIGINT," + + "col_default BIGINT DEFAULT 43," + + "col_nonnull_default BIGINT NOT NULL DEFAULT 42," + + "col_required2 BIGINT NOT NULL)"); + } + + @Override + protected Optional filterDataMappingSmokeTestData(DataMappingTestSetup dataMappingTestSetup) + { + String typeName = dataMappingTestSetup.getTrinoTypeName(); + if ("date".equals(typeName)) { + if (dataMappingTestSetup.getSampleValueLiteral().equals("DATE '1582-10-05'")) { + return Optional.empty(); + } + } + return Optional.of(dataMappingTestSetup); + } + + /** + * Overridden due to Redshift not supporting 
non-ASCII characters in CHAR. + */ + @Override + public void testCreateTableAsSelectWithUnicode() + { + assertThatThrownBy(super::testCreateTableAsSelectWithUnicode) + .hasStackTraceContaining("Value too long for character type"); + // NOTE we add a copy of the above using VARCHAR which supports non-ASCII characters + assertCreateTableAsSelect( + "SELECT CAST('\u2603' AS VARCHAR) unicode", + "SELECT 1"); + } + + @Test(dataProvider = "redshiftTypeToTrinoTypes") + public void testReadFromLateBindingView(String redshiftType, String trinoType) + { + try (TestView view = new TestView(onRemoteDatabase(), TEST_SCHEMA + ".late_schema_binding", "SELECT CAST(NULL AS %s) AS value WITH NO SCHEMA BINDING".formatted(redshiftType))) { + assertThat(query("SELECT value, true FROM %s WHERE value IS NULL".formatted(view.getName()))) + .projected(1) + .containsAll("VALUES (true)"); + + assertThat(query("SHOW COLUMNS FROM %s LIKE 'value'".formatted(view.getName()))) + .projected(1) + .skippingTypesCheck() + .containsAll("VALUES ('%s')".formatted(trinoType)); + } + } + + @DataProvider + public Object[][] redshiftTypeToTrinoTypes() + { + return new Object[][] { + {"SMALLINT", "smallint"}, + {"INTEGER", "integer"}, + {"BIGINT", "bigint"}, + {"DECIMAL", "decimal(18,0)"}, + {"REAL", "real"}, + {"DOUBLE PRECISION", "double"}, + {"BOOLEAN", "boolean"}, + {"CHAR(1)", "char(1)"}, + {"VARCHAR(1)", "varchar(1)"}, + {"TIME", "time(6)"}, + {"TIMESTAMP", "timestamp(6)"}, + {"TIMESTAMPTZ", "timestamp(6) with time zone"}}; + } + + @Override + public void testDelete() + { + // The base test is very slow because Redshift CTAS is really slow, so use a smaller test + try (TestTable table = new TestTable(getQueryRunner()::execute, "test_delete_", "AS SELECT * FROM nation")) { + // delete without matching any rows + assertUpdate("DELETE FROM " + table.getName() + " WHERE nationkey < 0", 0); + + // delete with a predicate that optimizes to false + assertUpdate("DELETE FROM " + table.getName() + " WHERE nationkey > 5 AND nationkey < 4", 0); + + // delete successive parts of the table + assertUpdate("DELETE FROM " + table.getName() + " WHERE nationkey <= 5", "SELECT count(*) FROM nation WHERE nationkey <= 5"); + assertQuery("SELECT * FROM " + table.getName(), "SELECT * FROM nation WHERE nationkey > 5"); + + assertUpdate("DELETE FROM " + table.getName() + " WHERE nationkey <= 10", "SELECT count(*) FROM nation WHERE nationkey > 5 AND nationkey <= 10"); + assertQuery("SELECT * FROM " + table.getName(), "SELECT * FROM nation WHERE nationkey > 10"); + + assertUpdate("DELETE FROM " + table.getName() + " WHERE nationkey <= 15", "SELECT count(*) FROM nation WHERE nationkey > 10 AND nationkey <= 15"); + assertQuery("SELECT * FROM " + table.getName(), "SELECT * FROM nation WHERE nationkey > 15"); + + // delete remaining + assertUpdate("DELETE FROM " + table.getName(), "SELECT count(*) FROM nation WHERE nationkey > 15"); + assertQuery("SELECT * FROM " + table.getName(), "SELECT * FROM nation WHERE false"); + } + } + + @Test(dataProvider = "testCaseColumnNamesDataProvider") + public void testCaseColumnNames(String tableName) + { + try { + assertUpdate( + "CREATE TABLE " + TEST_SCHEMA + "."
+ tableName + + " AS SELECT " + + " custkey AS CASE_UNQUOTED_UPPER, " + + " name AS case_unquoted_lower, " + + " address AS cASe_uNQuoTeD_miXED, " + + " nationkey AS \"CASE_QUOTED_UPPER\", " + + " phone AS \"case_quoted_lower\"," + + " acctbal AS \"CasE_QuoTeD_miXED\" " + + "FROM customer", + 1500); + gatherStats(tableName); + assertQuery( + "SHOW STATS FOR " + TEST_SCHEMA + "." + tableName, + "VALUES " + + "('case_unquoted_upper', NULL, 1485, 0, null, null, null)," + + "('case_unquoted_lower', 33000, 1470, 0, null, null, null)," + + "('case_unquoted_mixed', 42000, 1500, 0, null, null, null)," + + "('case_quoted_upper', NULL, 25, 0, null, null, null)," + + "('case_quoted_lower', 28500, 1483, 0, null, null, null)," + + "('case_quoted_mixed', NULL, 1483, 0, null, null, null)," + + "(null, null, null, null, 1500, null, null)"); + } + finally { + assertUpdate("DROP TABLE IF EXISTS " + tableName); + } + } + + /** + * Tries to create a situation where Redshift would decide to materialize a temporary table for the query we send to it. + * Such a temporary table requires that our Connection is not read-only. + */ + @Test + public void testComplexPushdownThatMayElicitTemporaryTable() + { + int subqueries = 10; + String subquery = "SELECT custkey, count(*) c FROM orders GROUP BY custkey"; + StringBuilder sql = new StringBuilder(); + sql.append(format( + "SELECT t0.custkey, %s c_sum ", + IntStream.range(0, subqueries) + .mapToObj(i -> format("t%s.c", i)) + .collect(Collectors.joining("+")))); + sql.append(format("FROM (%s) t0 ", subquery)); + for (int i = 1; i < subqueries; i++) { + sql.append(format("JOIN (%s) t%s ON t0.custkey = t%s.custkey ", subquery, i, i)); + } + sql.append("WHERE t0.custkey = 1045 OR rand() = 42"); + + Session forceJoinPushdown = Session.builder(getSession()) + .setCatalogSessionProperty(getSession().getCatalog().orElseThrow(), "join_pushdown_strategy", "EAGER") + .build(); + + assertThat(query(forceJoinPushdown, sql.toString())) + .matches(format("SELECT max(custkey), count(*) * %s FROM tpch.tiny.orders WHERE custkey = 1045", subqueries)); + } + + private static void gatherStats(String tableName) + { + executeInRedshift(handle -> { + handle.execute("ANALYZE VERBOSE " + TEST_SCHEMA + "." + tableName); + for (int i = 0; i < 5; i++) { + long actualCount = handle.createQuery("SELECT count(*) FROM " + TEST_SCHEMA + "." + tableName) + .mapTo(Long.class) + .one(); + long estimatedCount = handle.createQuery(""" + SELECT reltuples FROM pg_class + WHERE relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = :schema) + AND relname = :table_name + """) + .bind("schema", TEST_SCHEMA) + .bind("table_name", tableName.toLowerCase(ENGLISH).replace("\"", "")) + .mapTo(Long.class) + .one(); + if (actualCount == estimatedCount) { + return; + } + handle.execute("ANALYZE VERBOSE " + TEST_SCHEMA + "."
+ tableName); + } + throw new IllegalStateException("Stats not gathered"); // for small test tables reltuples should be exact + }); + } + + @DataProvider + public Object[][] testCaseColumnNamesDataProvider() + { + return new Object[][] { + {"TEST_STATS_MIXED_UNQUOTED_UPPER_" + randomNameSuffix()}, + {"test_stats_mixed_unquoted_lower_" + randomNameSuffix()}, + {"test_stats_mixed_uNQuoTeD_miXED_" + randomNameSuffix()}, + {"\"TEST_STATS_MIXED_QUOTED_UPPER_" + randomNameSuffix() + "\""}, + {"\"test_stats_mixed_quoted_lower_" + randomNameSuffix() + "\""}, + {"\"test_stats_mixed_QuoTeD_miXED_" + randomNameSuffix() + "\""} + }; + } + + @Override + public void testCountDistinctWithStringTypes() + { + // cannot test using generic method as Redshift does not allow non-ASCII characters in CHAR values. + assertThatThrownBy(super::testCountDistinctWithStringTypes).hasMessageContaining("Value for Redshift CHAR must be ASCII, but found 'ą'"); + + List<String> rows = Stream.of("a", "b", "A", "B", " a ", "a", "b", " b ") + .map(value -> format("'%1$s', '%1$s'", value)) + .collect(toImmutableList()); + String tableName = "distinct_strings" + randomNameSuffix(); + + try (TestTable testTable = new TestTable(getQueryRunner()::execute, tableName, "(t_char CHAR(5), t_varchar VARCHAR(5))", rows)) { + // Single count(DISTINCT ...) can be pushed down even if SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT == false as GROUP BY + assertThat(query("SELECT count(DISTINCT t_varchar) FROM " + testTable.getName())) + .matches("VALUES BIGINT '6'") + .isFullyPushedDown(); + + // Single count(DISTINCT ...) can be pushed down even if SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT == false as GROUP BY + assertThat(query("SELECT count(DISTINCT t_char) FROM " + testTable.getName())) + .matches("VALUES BIGINT '6'") + .isFullyPushedDown(); + + assertThat(query("SELECT count(DISTINCT t_char), count(DISTINCT t_varchar) FROM " + testTable.getName())) + .matches("VALUES (BIGINT '6', BIGINT '6')") + .isFullyPushedDown(); + } + } + + @Override + public void testAggregationPushdown() + { + throw new SkipException("tested in testAggregationPushdown(String)"); + } + + @Test(dataProvider = "testAggregationPushdownDistStylesDataProvider") + public void testAggregationPushdown(String distStyle) + { + String nation = format("%s.nation_%s_%s", TEST_SCHEMA, distStyle, randomNameSuffix()); + String customer = format("%s.customer_%s_%s", TEST_SCHEMA, distStyle, randomNameSuffix()); + try { + copyWithDistStyle(TEST_SCHEMA + ".nation", nation, distStyle, Optional.of("regionkey")); + copyWithDistStyle(TEST_SCHEMA + ".customer", customer, distStyle, Optional.of("nationkey")); + + // TODO support aggregation pushdown with GROUPING SETS + // TODO support aggregation over expressions + + // count() + assertThat(query("SELECT count(*) FROM " + nation)).isFullyPushedDown(); + assertThat(query("SELECT count(nationkey) FROM " + nation)).isFullyPushedDown(); + assertThat(query("SELECT count(1) FROM " + nation)).isFullyPushedDown(); + assertThat(query("SELECT count() FROM " + nation)).isFullyPushedDown(); + assertThat(query("SELECT regionkey, count(1) FROM " + nation + " GROUP BY regionkey")).isFullyPushedDown(); + try (TestTable emptyTable = createAggregationTestTable(getSession().getSchema().orElseThrow() + ".empty_table", ImmutableList.of())) { + String emptyTableName = emptyTable.getName() + "_" + distStyle; + copyWithDistStyle(emptyTable.getName(), emptyTableName, distStyle, Optional.of("a_bigint")); + + assertThat(query("SELECT count(*) FROM " +
emptyTableName)).isFullyPushedDown(); + assertThat(query("SELECT count(a_bigint) FROM " + emptyTableName)).isFullyPushedDown(); + assertThat(query("SELECT count(1) FROM " + emptyTableName)).isFullyPushedDown(); + assertThat(query("SELECT count() FROM " + emptyTableName)).isFullyPushedDown(); + assertThat(query("SELECT a_bigint, count(1) FROM " + emptyTableName + " GROUP BY a_bigint")).isFullyPushedDown(); + } + + // GROUP BY + assertThat(query("SELECT regionkey, min(nationkey) FROM " + nation + " GROUP BY regionkey")).isFullyPushedDown(); + assertThat(query("SELECT regionkey, max(nationkey) FROM " + nation + " GROUP BY regionkey")).isFullyPushedDown(); + assertThat(query("SELECT regionkey, sum(nationkey) FROM " + nation + " GROUP BY regionkey")).isFullyPushedDown(); + assertThat(query("SELECT regionkey, avg(nationkey) FROM " + nation + " GROUP BY regionkey")).isFullyPushedDown(); + try (TestTable emptyTable = createAggregationTestTable(getSession().getSchema().orElseThrow() + ".empty_table", ImmutableList.of())) { + String emptyTableName = emptyTable.getName() + "_" + distStyle; + copyWithDistStyle(emptyTable.getName(), emptyTableName, distStyle, Optional.of("a_bigint")); + + assertThat(query("SELECT t_double, min(a_bigint) FROM " + emptyTableName + " GROUP BY t_double")).isFullyPushedDown(); + assertThat(query("SELECT t_double, max(a_bigint) FROM " + emptyTableName + " GROUP BY t_double")).isFullyPushedDown(); + assertThat(query("SELECT t_double, sum(a_bigint) FROM " + emptyTableName + " GROUP BY t_double")).isFullyPushedDown(); + assertThat(query("SELECT t_double, avg(a_bigint) FROM " + emptyTableName + " GROUP BY t_double")).isFullyPushedDown(); + } + + // GROUP BY and WHERE on bigint column + // GROUP BY and WHERE on aggregation key + assertThat(query("SELECT regionkey, sum(nationkey) FROM " + nation + " WHERE regionkey < 4 GROUP BY regionkey")).isFullyPushedDown(); + + // GROUP BY and WHERE on varchar column + // GROUP BY and WHERE on "other" (not aggregation key, not aggregation input) + assertThat(query("SELECT regionkey, sum(nationkey) FROM " + nation + " WHERE regionkey < 4 AND name > 'AAA' GROUP BY regionkey")).isFullyPushedDown(); + // GROUP BY above WHERE and LIMIT + assertThat(query("SELECT regionkey, sum(nationkey) FROM (SELECT * FROM " + nation + " WHERE regionkey < 2 LIMIT 11) GROUP BY regionkey")).isFullyPushedDown(); + // GROUP BY above TopN + assertThat(query("SELECT regionkey, sum(nationkey) FROM (SELECT regionkey, nationkey FROM " + nation + " ORDER BY nationkey ASC LIMIT 10) GROUP BY regionkey")).isFullyPushedDown(); + // GROUP BY with JOIN + assertThat(query( + joinPushdownEnabled(getSession()), + "SELECT n.regionkey, sum(c.acctbal) acctbals FROM " + nation + " n LEFT JOIN " + customer + " c USING (nationkey) GROUP BY 1")) + .isFullyPushedDown(); + // GROUP BY with WHERE on neither grouping nor aggregation column + assertThat(query("SELECT nationkey, min(regionkey) FROM " + nation + " WHERE name = 'ARGENTINA' GROUP BY nationkey")).isFullyPushedDown(); + // aggregation on varchar column + assertThat(query("SELECT count(name) FROM " + nation)).isFullyPushedDown(); + // aggregation on varchar column with GROUPING + assertThat(query("SELECT nationkey, count(name) FROM " + nation + " GROUP BY nationkey")).isFullyPushedDown(); + // aggregation on varchar column with WHERE + assertThat(query("SELECT count(name) FROM " + nation + " WHERE name = 'ARGENTINA'")).isFullyPushedDown(); + } + finally { + executeInRedshift("DROP TABLE IF EXISTS " + nation); + executeInRedshift("DROP 
TABLE IF EXISTS " + customer); + } + } + + @Override + public void testNumericAggregationPushdown() + { + throw new SkipException("tested in testNumericAggregationPushdown(String)"); + } + + @Test(dataProvider = "testAggregationPushdownDistStylesDataProvider") + public void testNumericAggregationPushdown(String distStyle) + { + String schemaName = getSession().getSchema().orElseThrow(); + // empty table + try (TestTable emptyTable = createAggregationTestTable(schemaName + ".test_aggregation_pushdown", ImmutableList.of())) { + String emptyTableName = emptyTable.getName() + "_" + distStyle; + copyWithDistStyle(emptyTable.getName(), emptyTableName, distStyle, Optional.of("a_bigint")); + + assertThat(query("SELECT min(short_decimal), min(long_decimal), min(a_bigint), min(t_double) FROM " + emptyTableName)).isFullyPushedDown(); + assertThat(query("SELECT max(short_decimal), max(long_decimal), max(a_bigint), max(t_double) FROM " + emptyTableName)).isFullyPushedDown(); + assertThat(query("SELECT sum(short_decimal), sum(long_decimal), sum(a_bigint), sum(t_double) FROM " + emptyTableName)).isFullyPushedDown(); + assertThat(query("SELECT avg(short_decimal), avg(long_decimal), avg(a_bigint), avg(t_double) FROM " + emptyTableName)).isFullyPushedDown(); + } + + try (TestTable testTable = createAggregationTestTable(schemaName + ".test_aggregation_pushdown", + ImmutableList.of("100.000, 100000000.000000000, 100.000, 100000000", "123.321, 123456789.987654321, 123.321, 123456789"))) { + String testTableName = testTable.getName() + "_" + distStyle; + copyWithDistStyle(testTable.getName(), testTableName, distStyle, Optional.of("a_bigint")); + + assertThat(query("SELECT min(short_decimal), min(long_decimal), min(a_bigint), min(t_double) FROM " + testTableName)).isFullyPushedDown(); + assertThat(query("SELECT max(short_decimal), max(long_decimal), max(a_bigint), max(t_double) FROM " + testTableName)).isFullyPushedDown(); + assertThat(query("SELECT sum(short_decimal), sum(long_decimal), sum(a_bigint), sum(t_double) FROM " + testTableName)).isFullyPushedDown(); + assertThat(query("SELECT avg(short_decimal), avg(long_decimal), avg(a_bigint), avg(t_double) FROM " + testTableName)).isFullyPushedDown(); + + // smoke testing of more complex cases + // WHERE on aggregation column + assertThat(query("SELECT min(short_decimal), min(long_decimal) FROM " + testTableName + " WHERE short_decimal < 110 AND long_decimal < 124")).isFullyPushedDown(); + // WHERE on non-aggregation column + assertThat(query("SELECT min(long_decimal) FROM " + testTableName + " WHERE short_decimal < 110")).isFullyPushedDown(); + // GROUP BY + assertThat(query("SELECT short_decimal, min(long_decimal) FROM " + testTableName + " GROUP BY short_decimal")).isFullyPushedDown(); + // GROUP BY with WHERE on both grouping and aggregation column + assertThat(query("SELECT short_decimal, min(long_decimal) FROM " + testTableName + " WHERE short_decimal < 110 AND long_decimal < 124 GROUP BY short_decimal")).isFullyPushedDown(); + // GROUP BY with WHERE on grouping column + assertThat(query("SELECT short_decimal, min(long_decimal) FROM " + testTableName + " WHERE short_decimal < 110 GROUP BY short_decimal")).isFullyPushedDown(); + // GROUP BY with WHERE on aggregation column + assertThat(query("SELECT short_decimal, min(long_decimal) FROM " + testTableName + " WHERE long_decimal < 124 GROUP BY short_decimal")).isFullyPushedDown(); + } + } + + private static void copyWithDistStyle(String sourceTableName, String destTableName, String distStyle, Optional distKey) 
+ { + if (distStyle.equals("AUTO")) { + // NOTE: Redshift doesn't support setting diststyle AUTO in CTAS statements + executeInRedshift("CREATE TABLE " + destTableName + " AS SELECT * FROM " + sourceTableName); + // Redshift doesn't allow ALTER DISTSTYLE if original and new style are same, so we need to check current diststyle of table + boolean isDistStyleAuto = executeWithRedshift(handle -> { + Optional currentDistStyle = handle.createQuery("" + + "SELECT releffectivediststyle " + + "FROM pg_class_info AS a LEFT JOIN pg_namespace AS b ON a.relnamespace = b.oid " + + "WHERE lower(nspname) = lower(:schema_name) AND lower(relname) = lower(:table_name)") + .bind("schema_name", TEST_SCHEMA) + // destTableName = TEST_SCHEMA + "." + tableName + .bind("table_name", destTableName.substring(destTableName.indexOf(".") + 1)) + .mapTo(Long.class) + .findOne(); + + // 10 means AUTO(ALL) and 11 means AUTO(EVEN). See https://docs.aws.amazon.com/redshift/latest/dg/r_PG_CLASS_INFO.html. + return currentDistStyle.isPresent() && (currentDistStyle.get() == 10 || currentDistStyle.get() == 11); + }); + if (!isDistStyleAuto) { + executeInRedshift("ALTER TABLE " + destTableName + " ALTER DISTSTYLE " + distStyle); + } + } + else { + String copyWithDistStyleSql = "CREATE TABLE " + destTableName + " DISTSTYLE " + distStyle; + if (distStyle.equals("KEY")) { + copyWithDistStyleSql += format(" DISTKEY(%s)", distKey.orElseThrow()); + } + copyWithDistStyleSql += " AS SELECT * FROM " + sourceTableName; + executeInRedshift(copyWithDistStyleSql); + } + } + + @DataProvider + public Object[][] testAggregationPushdownDistStylesDataProvider() + { + return new Object[][] { + {"EVEN"}, + {"KEY"}, + {"ALL"}, + {"AUTO"}, + }; + } + + @Test + public void testDecimalAvgPushdownForMaximumDecimalScale() + { + List rows = ImmutableList.of( + "12345789.9876543210", + format("%s.%s", "1".repeat(28), "9".repeat(10))); + + try (TestTable testTable = new TestTable(getQueryRunner()::execute, TEST_SCHEMA + ".test_agg_pushdown_avg_max_decimal", + "(t_decimal DECIMAL(38, 10))", rows)) { + // Redshift avg rounds down decimal result which doesn't match Presto semantics + assertThatThrownBy(() -> assertThat(query("SELECT avg(t_decimal) FROM " + testTable.getName())).isFullyPushedDown()) + .isInstanceOf(AssertionError.class) + .hasMessageContaining(""" + elements not found: + <(555555555555555555561728450.9938271605)> + and elements not expected: + <(555555555555555555561728450.9938271604)> + """); + } + } + + @Test + public void testDecimalAvgPushdownFoShortDecimalScale() + { + List rows = ImmutableList.of( + "0.987654321234567890", + format("0.%s", "1".repeat(18))); + + try (TestTable testTable = new TestTable(getQueryRunner()::execute, TEST_SCHEMA + ".test_agg_pushdown_avg_max_decimal", + "(t_decimal DECIMAL(18, 18))", rows)) { + assertThat(query("SELECT avg(t_decimal) FROM " + testTable.getName())).isFullyPushedDown(); + } + } + + @Override + @Test + public void testReadMetadataWithRelationsConcurrentModifications() + { + throw new SkipException("Test fails with a timeout sometimes and is flaky"); + } + + @Override + public void testInsertRowConcurrently() + { + throw new SkipException("Test fails with a timeout sometimes and is flaky"); + } + + @Override + protected Session joinPushdownEnabled(Session session) + { + return Session.builder(super.joinPushdownEnabled(session)) + // strategy is AUTOMATIC by default and would not work for certain test cases (even if statistics are collected) + 
.setCatalogSessionProperty(session.getCatalog().orElseThrow(), "join_pushdown_strategy", "EAGER") + .build(); + } + + @Override + protected String errorMessageForInsertIntoNotNullColumn(String columnName) + { + return format("(?s).*Cannot insert a NULL value into column %s.*", columnName); + } + + @Override + protected OptionalInt maxSchemaNameLength() + { + return OptionalInt.of(127); + } + + @Override + protected void verifySchemaNameLengthFailurePermissible(Throwable e) + { + assertThat(e).hasMessage("Schema name must be shorter than or equal to '127' characters but got '128'"); + } + + @Override + protected OptionalInt maxTableNameLength() + { + return OptionalInt.of(127); + } + + @Override + protected void verifyTableNameLengthFailurePermissible(Throwable e) + { + assertThat(e).hasMessage("Table name must be shorter than or equal to '127' characters but got '128'"); + } + + @Override + protected OptionalInt maxColumnNameLength() + { + return OptionalInt.of(127); + } + + @Override + protected void verifyColumnNameLengthFailurePermissible(Throwable e) + { + assertThat(e).hasMessage("Column name must be shorter than or equal to '127' characters but got '128'"); + } + + @Override + protected SqlExecutor onRemoteDatabase() + { + return RedshiftQueryRunner::executeInRedshift; + } + + @Override + public void testDeleteWithLike() + { + assertThatThrownBy(super::testDeleteWithLike) + .hasStackTraceContaining("TrinoException: This connector does not support modifying table rows"); + } + + @Test + @Override + public void testAddNotNullColumnToNonEmptyTable() + { + throw new SkipException("Redshift ALTER TABLE ADD COLUMN defined as NOT NULL must have a non-null default expression"); + } + + private static class TestView + implements AutoCloseable + { + private final String name; + private final SqlExecutor executor; + + public TestView(SqlExecutor executor, String namePrefix, String viewDefinition) + { + this.executor = executor; + this.name = namePrefix + "_" + randomNameSuffix(); + executor.execute("CREATE OR REPLACE VIEW " + name + " AS " + viewDefinition); + } + + @Override + public void close() + { + executor.execute("DROP VIEW " + name); + } + + public String getName() + { + return name; + } + } +} diff --git a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftTableStatisticsReader.java b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftTableStatisticsReader.java new file mode 100644 index 000000000000..ff713337ea53 --- /dev/null +++ b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftTableStatisticsReader.java @@ -0,0 +1,349 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.redshift; + +import com.amazon.redshift.Driver; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.plugin.jdbc.BaseJdbcConfig; +import io.trino.plugin.jdbc.DriverConnectionFactory; +import io.trino.plugin.jdbc.JdbcColumnHandle; +import io.trino.plugin.jdbc.JdbcTableHandle; +import io.trino.plugin.jdbc.JdbcTypeHandle; +import io.trino.plugin.jdbc.RemoteTableName; +import io.trino.plugin.jdbc.credential.StaticCredentialProvider; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.SchemaTableName; +import io.trino.spi.statistics.ColumnStatistics; +import io.trino.spi.statistics.Estimate; +import io.trino.spi.statistics.TableStatistics; +import io.trino.spi.type.VarcharType; +import io.trino.testing.AbstractTestQueryFramework; +import io.trino.testing.QueryRunner; +import io.trino.testing.sql.TestTable; +import org.assertj.core.api.InstanceOfAssertFactories; +import org.assertj.core.api.SoftAssertions; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.sql.Types; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Consumer; + +import static io.trino.plugin.redshift.RedshiftQueryRunner.JDBC_PASSWORD; +import static io.trino.plugin.redshift.RedshiftQueryRunner.JDBC_URL; +import static io.trino.plugin.redshift.RedshiftQueryRunner.JDBC_USER; +import static io.trino.plugin.redshift.RedshiftQueryRunner.TEST_SCHEMA; +import static io.trino.plugin.redshift.RedshiftQueryRunner.createRedshiftQueryRunner; +import static io.trino.plugin.redshift.RedshiftQueryRunner.executeInRedshift; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.DoubleType.DOUBLE; +import static io.trino.testing.TestingConnectorSession.SESSION; +import static io.trino.testing.TestingNames.randomNameSuffix; +import static io.trino.testing.sql.TestTable.fromColumns; +import static io.trino.tpch.TpchTable.CUSTOMER; +import static java.util.Collections.emptyMap; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.from; +import static org.assertj.core.api.Assertions.withinPercentage; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNull; + +public class TestRedshiftTableStatisticsReader + extends AbstractTestQueryFramework +{ + private static final JdbcTypeHandle BIGINT_TYPE_HANDLE = new JdbcTypeHandle(Types.BIGINT, Optional.of("int8"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); + private static final JdbcTypeHandle DOUBLE_TYPE_HANDLE = new JdbcTypeHandle(Types.DOUBLE, Optional.of("double"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); + + private static final List CUSTOMER_COLUMNS = ImmutableList.of( + new JdbcColumnHandle("custkey", BIGINT_TYPE_HANDLE, BIGINT), + createVarcharJdbcColumnHandle("name", 25), + createVarcharJdbcColumnHandle("address", 48), + new JdbcColumnHandle("nationkey", BIGINT_TYPE_HANDLE, BIGINT), + createVarcharJdbcColumnHandle("phone", 15), + new JdbcColumnHandle("acctbal", DOUBLE_TYPE_HANDLE, DOUBLE), + createVarcharJdbcColumnHandle("mktsegment", 10), + createVarcharJdbcColumnHandle("comment", 117)); + + private RedshiftTableStatisticsReader statsReader; + + @BeforeClass + public void setup() + { + DriverConnectionFactory connectionFactory = new DriverConnectionFactory( + new Driver(), + new 
BaseJdbcConfig().setConnectionUrl(JDBC_URL), + new StaticCredentialProvider(Optional.of(JDBC_USER), Optional.of(JDBC_PASSWORD))); + statsReader = new RedshiftTableStatisticsReader(connectionFactory); + } + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + return createRedshiftQueryRunner(Map.of(), Map.of(), ImmutableList.of(CUSTOMER)); + } + + @Test + public void testCustomerTable() + throws Exception + { + assertThat(collectStats("SELECT * FROM " + TEST_SCHEMA + ".customer", CUSTOMER_COLUMNS)) + .returns(Estimate.of(1500), from(TableStatistics::getRowCount)) + .extracting(TableStatistics::getColumnStatistics, InstanceOfAssertFactories.map(ColumnHandle.class, ColumnStatistics.class)) + .hasEntrySatisfying(CUSTOMER_COLUMNS.get(0), statsCloseTo(1500.0, 0.0, 8.0 * 1500)) + .hasEntrySatisfying(CUSTOMER_COLUMNS.get(1), statsCloseTo(1500.0, 0.0, 33000.0)) + .hasEntrySatisfying(CUSTOMER_COLUMNS.get(3), statsCloseTo(25.000, 0.0, 8.0 * 1500)) + .hasEntrySatisfying(CUSTOMER_COLUMNS.get(5), statsCloseTo(1499.0, 0.0, 8.0 * 1500)); + } + + @Test + public void testEmptyTable() + throws Exception + { + TableStatistics tableStatistics = collectStats("SELECT * FROM " + TEST_SCHEMA + ".customer WHERE false", CUSTOMER_COLUMNS); + assertThat(tableStatistics) + .returns(Estimate.of(0.0), from(TableStatistics::getRowCount)) + .returns(emptyMap(), from(TableStatistics::getColumnStatistics)); + } + + @Test + public void testAllNulls() + throws Exception + { + String tableName = "testallnulls_" + randomNameSuffix(); + String schemaAndTable = TEST_SCHEMA + "." + tableName; + try { + executeInRedshift("CREATE TABLE " + schemaAndTable + " (i BIGINT)"); + executeInRedshift("INSERT INTO " + schemaAndTable + " (i) VALUES (NULL)"); + executeInRedshift("ANALYZE VERBOSE " + schemaAndTable); + + TableStatistics stats = statsReader.readTableStatistics( + SESSION, + new JdbcTableHandle( + new SchemaTableName(TEST_SCHEMA, tableName), + new RemoteTableName(Optional.empty(), Optional.of(TEST_SCHEMA), tableName), + Optional.empty()), + () -> ImmutableList.of(new JdbcColumnHandle("i", BIGINT_TYPE_HANDLE, BIGINT))); + assertThat(stats) + .returns(Estimate.of(1.0), from(TableStatistics::getRowCount)) + .returns(emptyMap(), from(TableStatistics::getColumnStatistics)); + } + finally { + executeInRedshift("DROP TABLE IF EXISTS " + schemaAndTable); + } + } + + @Test + public void testNullsFraction() + throws Exception + { + JdbcColumnHandle custkeyColumnHandle = CUSTOMER_COLUMNS.get(0); + TableStatistics stats = collectStats( + "SELECT CASE custkey % 3 WHEN 0 THEN NULL ELSE custkey END FROM " + TEST_SCHEMA + ".customer", + ImmutableList.of(custkeyColumnHandle)); + assertEquals(stats.getRowCount(), Estimate.of(1500)); + + ColumnStatistics columnStatistics = stats.getColumnStatistics().get(custkeyColumnHandle); + assertThat(columnStatistics.getNullsFraction().getValue()).isCloseTo(1.0 / 3, withinPercentage(1)); + } + + @Test + public void testAverageColumnLength() + throws Exception + { + List columns = ImmutableList.of( + new JdbcColumnHandle("custkey", BIGINT_TYPE_HANDLE, BIGINT), + createVarcharJdbcColumnHandle("v3_in_3", 3), + createVarcharJdbcColumnHandle("v3_in_42", 42), + createVarcharJdbcColumnHandle("single_10v_value", 10), + createVarcharJdbcColumnHandle("half_10v_value", 10), + createVarcharJdbcColumnHandle("half_distinct_20v_value", 20), + createVarcharJdbcColumnHandle("all_nulls", 10)); + + assertThat( + collectStats( + "SELECT " + + " custkey, " + + " 'abc' v3_in_3, " + + " CAST('abc' AS 
varchar(42)) v3_in_42, " + + " CASE custkey WHEN 1 THEN '0123456789' ELSE NULL END single_10v_value, " + + " CASE custkey % 2 WHEN 0 THEN '0123456789' ELSE NULL END half_10v_value, " + + " CASE custkey % 2 WHEN 0 THEN CAST((1000000 - custkey) * (1000000 - custkey) AS varchar(20)) ELSE NULL END half_distinct_20v_value, " + // 12 chars each + " CAST(NULL AS varchar(10)) all_nulls " + + "FROM " + TEST_SCHEMA + ".customer " + + "ORDER BY custkey LIMIT 100", + columns)) + .returns(Estimate.of(100), from(TableStatistics::getRowCount)) + .extracting(TableStatistics::getColumnStatistics, InstanceOfAssertFactories.map(ColumnHandle.class, ColumnStatistics.class)) + .hasEntrySatisfying(columns.get(0), statsCloseTo(100.0, 0.0, 800)) + .hasEntrySatisfying(columns.get(1), statsCloseTo(1.0, 0.0, 700.0)) + .hasEntrySatisfying(columns.get(2), statsCloseTo(1.0, 0.0, 700)) + .hasEntrySatisfying(columns.get(3), statsCloseTo(1.0, 0.99, 14)) + .hasEntrySatisfying(columns.get(4), statsCloseTo(1.0, 0.5, 700)) + .hasEntrySatisfying(columns.get(5), statsCloseTo(51, 0.5, 800)) + .satisfies(stats -> assertNull(stats.get(columns.get(6)))); + } + + @Test + public void testView() + throws Exception + { + String tableName = "test_stats_view_" + randomNameSuffix(); + String schemaAndTable = TEST_SCHEMA + "." + tableName; + List columns = ImmutableList.of( + new JdbcColumnHandle("custkey", BIGINT_TYPE_HANDLE, BIGINT), + createVarcharJdbcColumnHandle("mktsegment", 10), + createVarcharJdbcColumnHandle("comment", 117)); + + try { + executeInRedshift("CREATE OR REPLACE VIEW " + schemaAndTable + " AS SELECT custkey, mktsegment, comment FROM " + TEST_SCHEMA + ".customer"); + TableStatistics tableStatistics = statsReader.readTableStatistics( + SESSION, + new JdbcTableHandle( + new SchemaTableName(TEST_SCHEMA, tableName), + new RemoteTableName(Optional.empty(), Optional.of(TEST_SCHEMA), tableName), + Optional.empty()), + () -> columns); + assertThat(tableStatistics).isEqualTo(TableStatistics.empty()); + } + finally { + executeInRedshift("DROP VIEW IF EXISTS " + schemaAndTable); + } + } + + @Test + public void testMaterializedView() + throws Exception + { + String tableName = "test_stats_materialized_view_" + randomNameSuffix(); + String schemaAndTable = TEST_SCHEMA + "." 
+ tableName; + List columns = ImmutableList.of( + new JdbcColumnHandle("custkey", BIGINT_TYPE_HANDLE, BIGINT), + createVarcharJdbcColumnHandle("mktsegment", 10), + createVarcharJdbcColumnHandle("comment", 117)); + + try { + executeInRedshift("CREATE MATERIALIZED VIEW " + schemaAndTable + + " AS SELECT custkey, mktsegment, comment FROM " + TEST_SCHEMA + ".customer"); + executeInRedshift("REFRESH MATERIALIZED VIEW " + schemaAndTable); + executeInRedshift("ANALYZE VERBOSE " + schemaAndTable); + TableStatistics tableStatistics = statsReader.readTableStatistics( + SESSION, + new JdbcTableHandle( + new SchemaTableName(TEST_SCHEMA, tableName), + new RemoteTableName(Optional.empty(), Optional.of(TEST_SCHEMA), tableName), + Optional.empty()), + () -> columns); + assertThat(tableStatistics).isEqualTo(TableStatistics.empty()); + } + finally { + executeInRedshift("DROP MATERIALIZED VIEW " + schemaAndTable); + } + } + + @Test + public void testNumericCornerCases() + { + try (TestTable table = fromColumns( + getQueryRunner()::execute, + "test_numeric_corner_cases_", + ImmutableMap.>builder() + .put("only_negative_infinity double", List.of("-infinity()", "-infinity()", "-infinity()", "-infinity()")) + .put("only_positive_infinity double", List.of("infinity()", "infinity()", "infinity()", "infinity()")) + .put("mixed_infinities double", List.of("-infinity()", "infinity()", "-infinity()", "infinity()")) + .put("mixed_infinities_and_numbers double", List.of("-infinity()", "infinity()", "-5.0", "7.0")) + .put("nans_only double", List.of("nan()", "nan()")) + .put("nans_and_numbers double", List.of("nan()", "nan()", "-5.0", "7.0")) + .put("large_doubles double", List.of("CAST(-50371909150609548946090.0 AS DOUBLE)", "CAST(50371909150609548946090.0 AS DOUBLE)")) // 2^77 DIV 3 + .put("short_decimals_big_fraction decimal(16,15)", List.of("-1.234567890123456", "1.234567890123456")) + .put("short_decimals_big_integral decimal(16,1)", List.of("-123456789012345.6", "123456789012345.6")) + .put("long_decimals_big_fraction decimal(38,37)", List.of("-1.2345678901234567890123456789012345678", "1.2345678901234567890123456789012345678")) + .put("long_decimals_middle decimal(38,16)", List.of("-1234567890123456.7890123456789012345678", "1234567890123456.7890123456789012345678")) + .put("long_decimals_big_integral decimal(38,1)", List.of("-1234567890123456789012345678901234567.8", "1234567890123456789012345678901234567.8")) + .buildOrThrow(), + "null")) { + executeInRedshift("ANALYZE VERBOSE " + TEST_SCHEMA + "." 
+ table.getName()); + assertQuery( + "SHOW STATS FOR " + table.getName(), + "VALUES " + + "('only_negative_infinity', null, 1, 0, null, null, null)," + + "('only_positive_infinity', null, 1, 0, null, null, null)," + + "('mixed_infinities', null, 2, 0, null, null, null)," + + "('mixed_infinities_and_numbers', null, 4.0, 0.0, null, null, null)," + + "('nans_only', null, 1.0, 0.5, null, null, null)," + + "('nans_and_numbers', null, 3.0, 0.0, null, null, null)," + + "('large_doubles', null, 2.0, 0.5, null, null, null)," + + "('short_decimals_big_fraction', null, 2.0, 0.5, null, null, null)," + + "('short_decimals_big_integral', null, 2.0, 0.5, null, null, null)," + + "('long_decimals_big_fraction', null, 2.0, 0.5, null, null, null)," + + "('long_decimals_middle', null, 2.0, 0.5, null, null, null)," + + "('long_decimals_big_integral', null, 2.0, 0.5, null, null, null)," + + "(null, null, null, null, 4, null, null)"); + } + } + + /** + * Assert that the given column is within 5% of each statistic in the parameters, and that it has no range + */ + private static Consumer statsCloseTo(double distinctValues, double nullsFraction, double dataSize) + { + return stats -> { + SoftAssertions softly = new SoftAssertions(); + + softly.assertThat(stats.getDistinctValuesCount().getValue()) + .isCloseTo(distinctValues, withinPercentage(5.0)); + + softly.assertThat(stats.getNullsFraction().getValue()) + .isCloseTo(nullsFraction, withinPercentage(5.0)); + + softly.assertThat(stats.getDataSize().getValue()) + .isCloseTo(dataSize, withinPercentage(5.0)); + + softly.assertThat(stats.getRange()).isEmpty(); + softly.assertAll(); + }; + } + + private TableStatistics collectStats(String values, List columnHandles) + throws Exception + { + String tableName = "testredshiftstatisticsreader_" + randomNameSuffix(); + String schemaAndTable = TEST_SCHEMA + "." + tableName; + try { + executeInRedshift("CREATE TABLE " + schemaAndTable + " AS " + values); + executeInRedshift("ANALYZE VERBOSE " + schemaAndTable); + return statsReader.readTableStatistics( + SESSION, + new JdbcTableHandle( + new SchemaTableName(TEST_SCHEMA, tableName), + new RemoteTableName(Optional.empty(), Optional.of(TEST_SCHEMA), tableName), + Optional.empty()), + () -> columnHandles); + } + finally { + executeInRedshift("DROP TABLE IF EXISTS " + schemaAndTable); + } + } + + private static JdbcColumnHandle createVarcharJdbcColumnHandle(String name, int length) + { + return new JdbcColumnHandle( + name, + new JdbcTypeHandle(Types.VARCHAR, Optional.of("varchar"), Optional.of(length), Optional.empty(), Optional.empty(), Optional.empty()), + VarcharType.createVarcharType(length)); + } +} diff --git a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftTypeMapping.java b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftTypeMapping.java new file mode 100644 index 000000000000..26938c3b6532 --- /dev/null +++ b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftTypeMapping.java @@ -0,0 +1,994 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.redshift; + +import com.google.common.base.Utf8; +import com.google.common.collect.ImmutableList; +import io.trino.Session; +import io.trino.testing.AbstractTestQueryFramework; +import io.trino.testing.QueryRunner; +import io.trino.testing.TestingSession; +import io.trino.testing.datatype.CreateAndInsertDataSetup; +import io.trino.testing.datatype.CreateAsSelectDataSetup; +import io.trino.testing.datatype.DataSetup; +import io.trino.testing.datatype.SqlDataTypeTest; +import io.trino.testing.sql.JdbcSqlExecutor; +import io.trino.testing.sql.SqlExecutor; +import io.trino.testing.sql.TestTable; +import io.trino.testing.sql.TrinoSqlExecutor; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.sql.SQLException; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.LocalTime; +import java.time.ZoneId; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Properties; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; + +import static com.google.common.base.Verify.verify; +import static com.google.common.io.BaseEncoding.base16; +import static io.trino.plugin.redshift.RedshiftClient.REDSHIFT_MAX_VARCHAR; +import static io.trino.plugin.redshift.RedshiftQueryRunner.JDBC_PASSWORD; +import static io.trino.plugin.redshift.RedshiftQueryRunner.JDBC_URL; +import static io.trino.plugin.redshift.RedshiftQueryRunner.JDBC_USER; +import static io.trino.plugin.redshift.RedshiftQueryRunner.TEST_SCHEMA; +import static io.trino.plugin.redshift.RedshiftQueryRunner.createRedshiftQueryRunner; +import static io.trino.plugin.redshift.RedshiftQueryRunner.executeInRedshift; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.spi.type.CharType.createCharType; +import static io.trino.spi.type.DateType.DATE; +import static io.trino.spi.type.DecimalType.createDecimalType; +import static io.trino.spi.type.DoubleType.DOUBLE; +import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.spi.type.RealType.REAL; +import static io.trino.spi.type.SmallintType.SMALLINT; +import static io.trino.spi.type.TimeType.createTimeType; +import static io.trino.spi.type.TimeZoneKey.getTimeZoneKey; +import static io.trino.spi.type.TimestampType.createTimestampType; +import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_MICROS; +import static io.trino.spi.type.VarbinaryType.VARBINARY; +import static io.trino.spi.type.VarcharType.createUnboundedVarcharType; +import static io.trino.spi.type.VarcharType.createVarcharType; +import static io.trino.testing.TestingNames.randomNameSuffix; +import static java.lang.String.format; +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.time.ZoneOffset.UTC; +import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.groupingBy; +import static java.util.stream.Collectors.joining; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class TestRedshiftTypeMapping + extends AbstractTestQueryFramework +{ + private static final ZoneId testZone = TestingSession.DEFAULT_TIME_ZONE_KEY.getZoneId(); + + private final ZoneId 
jvmZone = ZoneId.systemDefault(); + private final LocalDateTime timeGapInJvmZone = LocalDate.EPOCH.atStartOfDay(); + private final LocalDateTime timeDoubledInJvmZone = LocalDateTime.of(2018, 10, 28, 1, 33, 17, 456_789_000); + + // using two non-JVM zones so that we don't need to worry what the backend's system zone is + + // no DST in 1970, but has DST in later years (e.g. 2018) + private final ZoneId vilnius = ZoneId.of("Europe/Vilnius"); + private final LocalDateTime timeGapInVilnius = LocalDateTime.of(2018, 3, 25, 3, 17, 17); + private final LocalDateTime timeDoubledInVilnius = LocalDateTime.of(2018, 10, 28, 3, 33, 33, 333_333_000); + + // Size of offset changed since 1970-01-01, no DST + private final ZoneId kathmandu = ZoneId.of("Asia/Kathmandu"); + private final LocalDateTime timeGapInKathmandu = LocalDateTime.of(1986, 1, 1, 0, 13, 7); + + private final LocalDate dayOfMidnightGapInJvmZone = LocalDate.EPOCH; + private final LocalDate dayOfMidnightGapInVilnius = LocalDate.of(1983, 4, 1); + private final LocalDate dayAfterMidnightSetBackInVilnius = LocalDate.of(1983, 10, 1); + + @BeforeClass + public void checkRanges() + { + // Timestamps + checkIsGap(jvmZone, timeGapInJvmZone); + checkIsDoubled(jvmZone, timeDoubledInJvmZone); + checkIsGap(vilnius, timeGapInVilnius); + checkIsDoubled(vilnius, timeDoubledInVilnius); + checkIsGap(kathmandu, timeGapInKathmandu); + + // Times + checkIsGap(jvmZone, LocalTime.of(0, 0, 0).atDate(LocalDate.EPOCH)); + + // Dates + checkIsGap(jvmZone, dayOfMidnightGapInJvmZone.atStartOfDay()); + checkIsGap(vilnius, dayOfMidnightGapInVilnius.atStartOfDay()); + checkIsDoubled(vilnius, dayAfterMidnightSetBackInVilnius.atStartOfDay().minusNanos(1)); + } + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + return createRedshiftQueryRunner(Map.of(), Map.of(), List.of()); + } + + @Test + public void testBasicTypes() + { + // Assume that if these types work at all, they have standard semantics. + SqlDataTypeTest.create() + .addRoundTrip("boolean", "true", BOOLEAN, "true") + .addRoundTrip("boolean", "false", BOOLEAN, "false") + .addRoundTrip("bigint", "123456789012", BIGINT, "123456789012") + .addRoundTrip("integer", "1234567890", INTEGER, "1234567890") + .addRoundTrip("smallint", "32456", SMALLINT, "SMALLINT '32456'") + .addRoundTrip("double", "123.45", DOUBLE, "DOUBLE '123.45'") + .addRoundTrip("real", "123.45", REAL, "REAL '123.45'") + // If we map tinyint to smallint: + .addRoundTrip("tinyint", "5", SMALLINT, "SMALLINT '5'") + .execute(getQueryRunner(), trinoCreateAsSelect("test_basic_types")); + } + + @Test + public void testVarchar() + { + SqlDataTypeTest.create() + .addRoundTrip("varchar(65535)", "'varchar max'", createVarcharType(65535), "CAST('varchar max' AS varchar(65535))") + .addRoundTrip("varchar(40)", "'攻殻機動隊'", createVarcharType(40), "CAST('攻殻機動隊' AS varchar(40))") + .addRoundTrip("varchar(8)", "'隊'", createVarcharType(8), "CAST('隊' AS varchar(8))") + .addRoundTrip("varchar(16)", "'😂'", createVarcharType(16), "CAST('😂' AS varchar(16))") + .addRoundTrip("varchar(88)", "'Ну, погоди!'", createVarcharType(88), "CAST('Ну, погоди!' 
AS varchar(88))") + .addRoundTrip("varchar(10)", "'text_a'", createVarcharType(10), "CAST('text_a' AS varchar(10))") + .addRoundTrip("varchar(255)", "'text_b'", createVarcharType(255), "CAST('text_b' AS varchar(255))") + .addRoundTrip("varchar(4096)", "'char max'", createVarcharType(4096), "CAST('char max' AS varchar(4096))") + .execute(getQueryRunner(), trinoCreateAsSelect("trino_test_varchar")) + .execute(getQueryRunner(), redshiftCreateAndInsert("jdbc_test_varchar")); + } + + @Test + public void testChar() + { + SqlDataTypeTest.create() + .addRoundTrip("char(10)", "'text_a'", createCharType(10), "CAST('text_a' AS char(10))") + .addRoundTrip("char(255)", "'text_b'", createCharType(255), "CAST('text_b' AS char(255))") + .addRoundTrip("char(4096)", "'char max'", createCharType(4096), "CAST('char max' AS char(4096))") + .execute(getQueryRunner(), trinoCreateAsSelect("trino_test_char")) + .execute(getQueryRunner(), redshiftCreateAndInsert("jdbc_test_char")); + + // Test with types larger than Redshift's char(max) + SqlDataTypeTest.create() + .addRoundTrip("char(65535)", "'varchar max'", createVarcharType(65535), format("CAST('varchar max%s' AS varchar(65535))", " ".repeat(65535 - "varchar max".length()))) + .addRoundTrip("char(4136)", "'攻殻機動隊'", createVarcharType(4136), format("CAST('%s' AS varchar(4136))", padVarchar(4136).apply("攻殻機動隊"))) + .addRoundTrip("char(4104)", "'隊'", createVarcharType(4104), format("CAST('%s' AS varchar(4104))", padVarchar(4104).apply("隊"))) + .addRoundTrip("char(4112)", "'😂'", createVarcharType(4112), format("CAST('%s' AS varchar(4112))", padVarchar(4112).apply("😂"))) + .addRoundTrip("varchar(88)", "'Ну, погоди!'", createVarcharType(88), "CAST('Ну, погоди!' AS varchar(88))") + .addRoundTrip("char(4106)", "'text_a'", createVarcharType(4106), format("CAST('%s' AS varchar(4106))", padVarchar(4106).apply("text_a"))) + .addRoundTrip("char(4351)", "'text_b'", createVarcharType(4351), format("CAST('%s' AS varchar(4351))", padVarchar(4351).apply("text_b"))) + .addRoundTrip("char(8192)", "'char max'", createVarcharType(8192), format("CAST('%s' AS varchar(8192))", padVarchar(8192).apply("char max"))) + .execute(getQueryRunner(), trinoCreateAsSelect("trino_test_large_char")); + } + + /** + * Test handling of data outside Redshift's normal bounds. + * + *
<p>
Redshift sometimes returns unbounded {@code VARCHAR} data, apparently + * when it returns directly from a Postgres function. + */ + @Test + public void testPostgresText() + { + try (TestView view1 = new TestView("postgres_text_view", "SELECT lpad('x', 1)"); + TestView view2 = new TestView("pg_catalog_view", "SELECT relname FROM pg_class")) { + // Test data and type from a function + assertThat(query(format("SELECT * FROM %s", view1.name))) + .matches("VALUES CAST('x' AS varchar)"); + + // Test the type of an internal table + assertThat(query(format("SELECT * FROM %s LIMIT 1", view2.name))) + .hasOutputTypes(List.of(createUnboundedVarcharType())); + } + } + + // Make sure that Redshift still maps NCHAR and NVARCHAR to CHAR and VARCHAR. + @Test + public void checkNCharAndNVarchar() + { + SqlDataTypeTest.create() + .addRoundTrip("nvarchar(65535)", "'varchar max'", createVarcharType(65535), "CAST('varchar max' AS varchar(65535))") + .addRoundTrip("nvarchar(40)", "'攻殻機動隊'", createVarcharType(40), "CAST('攻殻機動隊' AS varchar(40))") + .addRoundTrip("nvarchar(8)", "'隊'", createVarcharType(8), "CAST('隊' AS varchar(8))") + .addRoundTrip("nvarchar(16)", "'😂'", createVarcharType(16), "CAST('😂' AS varchar(16))") + .addRoundTrip("nvarchar(88)", "'Ну, погоди!'", createVarcharType(88), "CAST('Ну, погоди!' AS varchar(88))") + .addRoundTrip("nvarchar(10)", "'text_a'", createVarcharType(10), "CAST('text_a' AS varchar(10))") + .addRoundTrip("nvarchar(255)", "'text_b'", createVarcharType(255), "CAST('text_b' AS varchar(255))") + .addRoundTrip("nvarchar(4096)", "'char max'", createVarcharType(4096), "CAST('char max' AS varchar(4096))") + .execute(getQueryRunner(), redshiftCreateAndInsert("jdbc_test_nvarchar")); + + SqlDataTypeTest.create() + .addRoundTrip("nchar(10)", "'text_a'", createCharType(10), "CAST('text_a' AS char(10))") + .addRoundTrip("nchar(255)", "'text_b'", createCharType(255), "CAST('text_b' AS char(255))") + .addRoundTrip("nchar(4096)", "'char max'", createCharType(4096), "CAST('char max' AS char(4096))") + .execute(getQueryRunner(), redshiftCreateAndInsert("jdbc_test_nchar")); + } + + @Test + public void testUnicodeChar() // Redshift doesn't allow multibyte chars in CHAR + { + try (TestTable table = testTable("test_multibyte_char", "(c char(32))")) { + assertQueryFails( + format("INSERT INTO %s VALUES ('\u968A')", table.getName()), + "^Value for Redshift CHAR must be ASCII, but found '\u968A'$"); + } + + assertCreateFails( + "test_multibyte_char_ctas", + "AS SELECT CAST('\u968A' AS char(32)) c", + "^Value for Redshift CHAR must be ASCII, but found '\u968A'$"); + } + + // Make sure Redshift really doesn't allow multibyte characters in CHAR + @Test + public void checkUnicodeCharInRedshift() + { + try (TestTable table = testTable("check_multibyte_char", "(c char(32))")) { + assertThatThrownBy(() -> getRedshiftExecutor() + .execute(format("INSERT INTO %s VALUES ('\u968a')", table.getName()))) + .getCause() + .isInstanceOf(SQLException.class) + .hasMessageContaining("CHAR string contains invalid ASCII character"); + } + } + + @Test + public void testOversizedCharacterTypes() + { + // Test that character types too large for Redshift map to the maximum size + SqlDataTypeTest.create() + .addRoundTrip("varchar", "'unbounded'", createVarcharType(65535), "CAST('unbounded' AS varchar(65535))") + .addRoundTrip(format("varchar(%s)", REDSHIFT_MAX_VARCHAR + 1), "'oversized varchar'", createVarcharType(65535), "CAST('oversized varchar' AS varchar(65535))") + .addRoundTrip(format("char(%s)", REDSHIFT_MAX_VARCHAR + 
1), "'oversized char'", createVarcharType(65535), format("CAST('%s' AS varchar(65535))", padVarchar(65535).apply("oversized char"))) + .execute(getQueryRunner(), trinoCreateAsSelect("oversized_character_types")); + } + + @Test + public void testVarbinary() + { + // Redshift's VARBYTE is mapped to Trino VARBINARY. Redshift does not have VARBINARY type. + SqlDataTypeTest.create() + // varbyte + .addRoundTrip("varbyte", "NULL", VARBINARY, "CAST(NULL AS varbinary)") + .addRoundTrip("varbyte", "to_varbyte('', 'hex')", VARBINARY, "X''") + .addRoundTrip("varbyte", utf8VarbyteLiteral("hello"), VARBINARY, "to_utf8('hello')") + .addRoundTrip("varbyte", utf8VarbyteLiteral("Piękna łąka w 東京都"), VARBINARY, "to_utf8('Piękna łąka w 東京都')") + .addRoundTrip("varbyte", utf8VarbyteLiteral("Bag full of 💰"), VARBINARY, "to_utf8('Bag full of 💰')") + .addRoundTrip("varbyte", "to_varbyte('0001020304050607080DF9367AA7000000', 'hex')", VARBINARY, "X'0001020304050607080DF9367AA7000000'") // non-text + .addRoundTrip("varbyte", "to_varbyte('000000000000', 'hex')", VARBINARY, "X'000000000000'") + .addRoundTrip("varbyte(1)", "to_varbyte('00', 'hex')", VARBINARY, "X'00'") // minimum length + .addRoundTrip("varbyte(1024000)", "to_varbyte('00', 'hex')", VARBINARY, "X'00'") // maximum length + // varbinary + .addRoundTrip("varbinary", "NULL", VARBINARY, "CAST(NULL AS varbinary)") + .addRoundTrip("varbinary", utf8VarbyteLiteral("Bag full of 💰"), VARBINARY, "to_utf8('Bag full of 💰')") + .addRoundTrip("varbinary", "to_varbyte('0001020304050607080DF9367AA7000000', 'hex')", VARBINARY, "X'0001020304050607080DF9367AA7000000'") // non-text + .addRoundTrip("varbinary(1)", "to_varbyte('00', 'hex')", VARBINARY, "X'00'") // minimum length + .addRoundTrip("varbinary(1024000)", "to_varbyte('00', 'hex')", VARBINARY, "X'00'") // maximum length + // binary varying + .addRoundTrip("binary varying", "NULL", VARBINARY, "CAST(NULL AS varbinary)") + .addRoundTrip("binary varying", utf8VarbyteLiteral("Bag full of 💰"), VARBINARY, "to_utf8('Bag full of 💰')") + .addRoundTrip("binary varying", "to_varbyte('0001020304050607080DF9367AA7000000', 'hex')", VARBINARY, "X'0001020304050607080DF9367AA7000000'") // non-text + .addRoundTrip("binary varying(1)", "to_varbyte('00', 'hex')", VARBINARY, "X'00'") // minimum length + .addRoundTrip("binary varying(1024000)", "to_varbyte('00', 'hex')", VARBINARY, "X'00'") // maximum length + .execute(getQueryRunner(), redshiftCreateAndInsert("test_varbinary")); + + SqlDataTypeTest.create() + .addRoundTrip("varbinary", "NULL", VARBINARY, "CAST(NULL AS varbinary)") + .addRoundTrip("varbinary", "X''", VARBINARY, "X''") + .addRoundTrip("varbinary", "X'68656C6C6F'", VARBINARY, "to_utf8('hello')") + .addRoundTrip("varbinary", "X'5069C4996B6E6120C582C4856B61207720E69DB1E4BAACE983BD'", VARBINARY, "to_utf8('Piękna łąka w 東京都')") + .addRoundTrip("varbinary", "X'4261672066756C6C206F6620F09F92B0'", VARBINARY, "to_utf8('Bag full of 💰')") + .addRoundTrip("varbinary", "X'0001020304050607080DF9367AA7000000'", VARBINARY, "X'0001020304050607080DF9367AA7000000'") // non-text + .addRoundTrip("varbinary", "X'000000000000'", VARBINARY, "X'000000000000'") + .execute(getQueryRunner(), trinoCreateAsSelect("test_varbinary")); + } + + private static String utf8VarbyteLiteral(String string) + { + return format("to_varbyte('%s', 'hex')", base16().encode(string.getBytes(UTF_8))); + } + + @Test + public void testDecimal() + { + SqlDataTypeTest.create() + .addRoundTrip("decimal(3, 0)", "CAST('193' AS decimal(3, 0))", createDecimalType(3, 0), 
"CAST('193' AS decimal(3, 0))") + .addRoundTrip("decimal(3, 0)", "CAST('19' AS decimal(3, 0))", createDecimalType(3, 0), "CAST('19' AS decimal(3, 0))") + .addRoundTrip("decimal(3, 0)", "CAST('-193' AS decimal(3, 0))", createDecimalType(3, 0), "CAST('-193' AS decimal(3, 0))") + .addRoundTrip("decimal(3, 1)", "CAST('10.0' AS decimal(3, 1))", createDecimalType(3, 1), "CAST('10.0' AS decimal(3, 1))") + .addRoundTrip("decimal(3, 1)", "CAST('10.1' AS decimal(3, 1))", createDecimalType(3, 1), "CAST('10.1' AS decimal(3, 1))") + .addRoundTrip("decimal(3, 1)", "CAST('-10.1' AS decimal(3, 1))", createDecimalType(3, 1), "CAST('-10.1' AS decimal(3, 1))") + .addRoundTrip("decimal(4, 2)", "CAST('2' AS decimal(4, 2))", createDecimalType(4, 2), "CAST('2' AS decimal(4, 2))") + .addRoundTrip("decimal(4, 2)", "CAST('2.3' AS decimal(4, 2))", createDecimalType(4, 2), "CAST('2.3' AS decimal(4, 2))") + .addRoundTrip("decimal(24, 2)", "CAST('2' AS decimal(24, 2))", createDecimalType(24, 2), "CAST('2' AS decimal(24, 2))") + .addRoundTrip("decimal(24, 2)", "CAST('2.3' AS decimal(24, 2))", createDecimalType(24, 2), "CAST('2.3' AS decimal(24, 2))") + .addRoundTrip("decimal(24, 2)", "CAST('123456789.3' AS decimal(24, 2))", createDecimalType(24, 2), "CAST('123456789.3' AS decimal(24, 2))") + .addRoundTrip("decimal(24, 4)", "CAST('12345678901234567890.31' AS decimal(24, 4))", createDecimalType(24, 4), "CAST('12345678901234567890.31' AS decimal(24, 4))") + .addRoundTrip("decimal(30, 5)", "CAST('3141592653589793238462643.38327' AS decimal(30, 5))", createDecimalType(30, 5), "CAST('3141592653589793238462643.38327' AS decimal(30, 5))") + .addRoundTrip("decimal(30, 5)", "CAST('-3141592653589793238462643.38327' AS decimal(30, 5))", createDecimalType(30, 5), "CAST('-3141592653589793238462643.38327' AS decimal(30, 5))") + .addRoundTrip("decimal(31, 0)", "CAST('2718281828459045235360287471352' AS decimal(31, 0))", createDecimalType(31, 0), "CAST('2718281828459045235360287471352' AS decimal(31, 0))") + .addRoundTrip("decimal(31, 0)", "CAST('-2718281828459045235360287471352' AS decimal(31, 0))", createDecimalType(31, 0), "CAST('-2718281828459045235360287471352' AS decimal(31, 0))") + .addRoundTrip("decimal(3, 0)", "NULL", createDecimalType(3, 0), "CAST(NULL AS decimal(3, 0))") + .addRoundTrip("decimal(31, 0)", "NULL", createDecimalType(31, 0), "CAST(NULL AS decimal(31, 0))") + .execute(getQueryRunner(), redshiftCreateAndInsert("test_decimal")) + .execute(getQueryRunner(), trinoCreateAsSelect("test_decimal")); + } + + @Test + public void testRedshiftDecimalCutoff() + { + String columns = "(d19 decimal(19, 0), d18 decimal(19, 18), d0 decimal(19, 19))"; + try (TestTable table = testTable("test_decimal_range", columns)) { + assertQueryFails( + format("INSERT INTO %s (d19) VALUES (DECIMAL'9991999999999999999')", table.getName()), + "^Value out of range for Redshift DECIMAL\\(19, 0\\)$"); + assertQueryFails( + format("INSERT INTO %s (d18) VALUES (DECIMAL'9.991999999999999999')", table.getName()), + "^Value out of range for Redshift DECIMAL\\(19, 18\\)$"); + assertQueryFails( + format("INSERT INTO %s (d0) VALUES (DECIMAL'.9991999999999999999')", table.getName()), + "^Value out of range for Redshift DECIMAL\\(19, 19\\)$"); + } + } + + @Test + public void testRedshiftDecimalScaleLimit() + { + assertCreateFails( + "test_overlarge_decimal_scale", + "(d DECIMAL(38, 38))", + "^ERROR: DECIMAL scale 38 must be between 0 and 37$"); + } + + @Test + public void testUnsupportedTrinoDataTypes() + { + assertCreateFails( + "test_unsupported_type", + 
"(col json)", + "Unsupported column type: json"); + } + + @Test(dataProvider = "datetime_test_parameters") + public void testDate(ZoneId sessionZone) + { + Session session = Session.builder(getSession()) + .setTimeZoneKey(getTimeZoneKey(sessionZone.getId())) + .build(); + SqlDataTypeTest.create() + .addRoundTrip("date", "DATE '0001-01-01'", DATE, "DATE '0001-01-01'") // first day of AD + .addRoundTrip("date", "DATE '1500-01-01'", DATE, "DATE '1500-01-01'") // sometime before julian->gregorian switch + .addRoundTrip("date", "DATE '1600-01-01'", DATE, "DATE '1600-01-01'") // long ago but after julian->gregorian switch + .addRoundTrip("date", "DATE '1952-04-03'", DATE, "DATE '1952-04-03'") // before epoch + .addRoundTrip("date", "DATE '1970-01-01'", DATE, "DATE '1970-01-01'") + .addRoundTrip("date", "DATE '1970-02-03'", DATE, "DATE '1970-02-03'") // after epoch + .addRoundTrip("date", "DATE '2017-07-01'", DATE, "DATE '2017-07-01'") // summer in northern hemisphere (possible DST) + .addRoundTrip("date", "DATE '2017-01-01'", DATE, "DATE '2017-01-01'") // winter in northern hemisphere (possible DST in southern hemisphere) + .addRoundTrip("date", "DATE '1970-01-01'", DATE, "DATE '1970-01-01'") // day of midnight gap in JVM + .addRoundTrip("date", "DATE '1983-04-01'", DATE, "DATE '1983-04-01'") // day of midnight gap in Vilnius + .addRoundTrip("date", "DATE '1983-10-01'", DATE, "DATE '1983-10-01'") // day after midnight setback in Vilnius + .execute(getQueryRunner(), session, trinoCreateAsSelect(session, "test_date")) + .execute(getQueryRunner(), session, redshiftCreateAndInsert("test_date")); + + // some time BC + SqlDataTypeTest.create() + .addRoundTrip("date", "DATE '-0100-01-01'", DATE, "DATE '-0100-01-01'") + .execute(getQueryRunner(), session, trinoCreateAsSelect(session, "test_date")); + SqlDataTypeTest.create() + .addRoundTrip("date", "DATE '0101-01-01 BC'", DATE, "DATE '-0100-01-01'") + .execute(getQueryRunner(), session, redshiftCreateAndInsert("test_date")); + } + + @Test(dataProvider = "datetime_test_parameters") + public void testTime(ZoneId sessionZone) + { + // Redshift gets bizarre errors if you try to insert after + // specifying precision for a time column. 
+ Session session = Session.builder(getSession()) + .setTimeZoneKey(getTimeZoneKey(sessionZone.getId())) + .build(); + timeTypeTests("time(6)") + .execute(getQueryRunner(), session, trinoCreateAsSelect(session, "time_from_trino")); + timeTypeTests("time") + .execute(getQueryRunner(), session, redshiftCreateAndInsert("time_from_jdbc")); + } + + private static SqlDataTypeTest timeTypeTests(String inputType) + { + return SqlDataTypeTest.create() + .addRoundTrip(inputType, "TIME '00:00:00.000000'", createTimeType(6), "TIME '00:00:00.000000'") // gap in JVM zone on Epoch day + .addRoundTrip(inputType, "TIME '00:13:42.000000'", createTimeType(6), "TIME '00:13:42.000000'") // gap in JVM zone on Epoch day + .addRoundTrip(inputType, "TIME '01:33:17.000000'", createTimeType(6), "TIME '01:33:17.000000'") + .addRoundTrip(inputType, "TIME '03:17:17.000000'", createTimeType(6), "TIME '03:17:17.000000'") + .addRoundTrip(inputType, "TIME '10:01:17.100000'", createTimeType(6), "TIME '10:01:17.100000'") + .addRoundTrip(inputType, "TIME '13:18:03.000000'", createTimeType(6), "TIME '13:18:03.000000'") + .addRoundTrip(inputType, "TIME '14:18:03.000000'", createTimeType(6), "TIME '14:18:03.000000'") + .addRoundTrip(inputType, "TIME '15:18:03.000000'", createTimeType(6), "TIME '15:18:03.000000'") + .addRoundTrip(inputType, "TIME '16:18:03.123456'", createTimeType(6), "TIME '16:18:03.123456'") + .addRoundTrip(inputType, "TIME '19:01:17.000000'", createTimeType(6), "TIME '19:01:17.000000'") + .addRoundTrip(inputType, "TIME '20:01:17.000000'", createTimeType(6), "TIME '20:01:17.000000'") + .addRoundTrip(inputType, "TIME '21:01:17.000001'", createTimeType(6), "TIME '21:01:17.000001'") + .addRoundTrip(inputType, "TIME '22:59:59.000000'", createTimeType(6), "TIME '22:59:59.000000'") + .addRoundTrip(inputType, "TIME '23:59:59.000000'", createTimeType(6), "TIME '23:59:59.000000'") + .addRoundTrip(inputType, "TIME '23:59:59.999999'", createTimeType(6), "TIME '23:59:59.999999'"); + } + + @Test(dataProvider = "datetime_test_parameters") + public void testTimestamp(ZoneId sessionZone) + { + Session session = Session.builder(getSession()) + .setTimeZoneKey(getTimeZoneKey(sessionZone.getId())) + .build(); + // Redshift doesn't allow timestamp precision to be specified + timestampTypeTests("timestamp(6)") + .execute(getQueryRunner(), session, trinoCreateAsSelect(session, "timestamp_from_trino")); + timestampTypeTests("timestamp") + .execute(getQueryRunner(), session, redshiftCreateAndInsert("timestamp_from_jdbc")); + + // some time BC + SqlDataTypeTest.create() + .addRoundTrip("timestamp(6)", "TIMESTAMP '-0100-01-01 00:00:00.000000'", createTimestampType(6), "TIMESTAMP '-0100-01-01 00:00:00.000000'") + .execute(getQueryRunner(), session, trinoCreateAsSelect(session, "test_date")); + SqlDataTypeTest.create() + .addRoundTrip("timestamp", "TIMESTAMP '0101-01-01 00:00:00 BC'", createTimestampType(6), "TIMESTAMP '-0100-01-01 00:00:00.000000'") + .execute(getQueryRunner(), session, redshiftCreateAndInsert("test_date")); + } + + private static SqlDataTypeTest timestampTypeTests(String inputType) + { + return SqlDataTypeTest.create() + .addRoundTrip(inputType, "TIMESTAMP '0001-01-01 00:00:00.000000'", createTimestampType(6), "TIMESTAMP '0001-01-01 00:00:00.000000'") // first day of AD + .addRoundTrip(inputType, "TIMESTAMP '1500-01-01 00:00:00.000000'", createTimestampType(6), "TIMESTAMP '1500-01-01 00:00:00.000000'") // sometime before julian->gregorian switch + .addRoundTrip(inputType, "TIMESTAMP '1600-01-01 00:00:00.000000'", 
createTimestampType(6), "TIMESTAMP '1600-01-01 00:00:00.000000'") // long ago but after julian->gregorian switch + .addRoundTrip(inputType, "TIMESTAMP '1958-01-01 13:18:03.123456'", createTimestampType(6), "TIMESTAMP '1958-01-01 13:18:03.123456'") // before epoch + .addRoundTrip(inputType, "TIMESTAMP '2019-03-18 10:09:17.987654'", createTimestampType(6), "TIMESTAMP '2019-03-18 10:09:17.987654'") // after epoch + .addRoundTrip(inputType, "TIMESTAMP '2018-10-28 01:33:17.456789'", createTimestampType(6), "TIMESTAMP '2018-10-28 01:33:17.456789'") // time doubled in JVM + .addRoundTrip(inputType, "TIMESTAMP '2018-10-28 03:33:33.333333'", createTimestampType(6), "TIMESTAMP '2018-10-28 03:33:33.333333'") // time doubled in Vilnius + .addRoundTrip(inputType, "TIMESTAMP '1970-01-01 00:00:00.000000'", createTimestampType(6), "TIMESTAMP '1970-01-01 00:00:00.000000'") // time gap in JVM + .addRoundTrip(inputType, "TIMESTAMP '2018-03-25 03:17:17.000000'", createTimestampType(6), "TIMESTAMP '2018-03-25 03:17:17.000000'") // time gap in Vilnius + .addRoundTrip(inputType, "TIMESTAMP '1986-01-01 00:13:07.000000'", createTimestampType(6), "TIMESTAMP '1986-01-01 00:13:07.000000'") // time gap in Kathmandu + // Full time precision + .addRoundTrip(inputType, "TIMESTAMP '1969-12-31 23:59:59.999999'", createTimestampType(6), "TIMESTAMP '1969-12-31 23:59:59.999999'") + .addRoundTrip(inputType, "TIMESTAMP '1970-01-01 00:00:00.999999'", createTimestampType(6), "TIMESTAMP '1970-01-01 00:00:00.999999'"); + } + + @Test(dataProvider = "datetime_test_parameters") + public void testTimestampWithTimeZone(ZoneId sessionZone) + { + Session session = Session.builder(getSession()) + .setTimeZoneKey(getTimeZoneKey(sessionZone.getId())) + .build(); + + SqlDataTypeTest.create() + // test arbitrary time for all supported precisions + .addRoundTrip("timestamp(0) with time zone", "TIMESTAMP '2022-09-27 12:34:56 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2022-09-27 12:34:56.000000 UTC'") + .addRoundTrip("timestamp(1) with time zone", "TIMESTAMP '2022-09-27 12:34:56.1 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2022-09-27 12:34:56.100000 UTC'") + .addRoundTrip("timestamp(2) with time zone", "TIMESTAMP '2022-09-27 12:34:56.12 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2022-09-27 12:34:56.120000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '2022-09-27 12:34:56.123 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2022-09-27 12:34:56.123000 UTC'") + .addRoundTrip("timestamp(4) with time zone", "TIMESTAMP '2022-09-27 12:34:56.1234 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2022-09-27 12:34:56.123400 UTC'") + .addRoundTrip("timestamp(5) with time zone", "TIMESTAMP '2022-09-27 12:34:56.12345 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2022-09-27 12:34:56.123450 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '2022-09-27 12:34:56.123456 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2022-09-27 12:34:56.123456 UTC'") + + // short timestamp with time zone + // .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '-4712-01-01 00:00:00 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '-4712-01-01 00:00:00.000000 UTC'") // min value in Redshift + // .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '0001-01-01 00:00:00 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '0001-01-01 00:00:00.000000 UTC'") // first day of AD + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1582-10-04 23:59:59.999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-04 23:59:59.999000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1582-10-05 
00:00:00.000 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-05 00:00:00.000000 UTC'") // begin julian->gregorian switch + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1582-10-14 23:59:59.999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-14 23:59:59.999000 UTC'") // end julian->gregorian switch + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1582-10-15 00:00:00.000 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-15 00:00:00.000000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.000000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00.1 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.100000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00.9 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.900000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00.123 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00.999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.999000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00.123 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1986-01-01 00:13:07 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1986-01-01 00:13:07.000000 UTC'") // time gap in Kathmandu + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '2018-10-28 01:33:17.456 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2018-10-28 01:33:17.456000 UTC'") // time doubled in JVM + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '2018-10-28 03:33:33.333 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2018-10-28 03:33:33.333000 UTC'") // time doubled in Vilnius + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '2018-03-25 03:17:17.000 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2018-03-25 03:17:17.000000 UTC'") // time gap in Vilnius + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '2020-09-27 12:34:56.1 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.100000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '2020-09-27 12:34:56.9 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.900000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '2020-09-27 12:34:56.123 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.123000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '2020-09-27 12:34:56.999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.999000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '2020-09-27 12:34:56.123 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.123000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '73326-09-11 20:14:45.247 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '73326-09-11 20:14:45.247000 UTC'") // max value in Trino + .addRoundTrip("timestamp(3) with time zone", "NULL", TIMESTAMP_TZ_MICROS, "CAST(NULL AS timestamp(6) with time zone)") + + // long timestamp with time zone + // .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '0001-01-01 00:00:00 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '0001-01-01 00:00:00.000000 UTC'") // first day of AD + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1582-10-04 23:59:59.999999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-04 
23:59:59.999999 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1582-10-05 00:00:00.000000 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-05 00:00:00.000000 UTC'") // begin julian->gregorian switch + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1582-10-14 23:59:59.999999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-14 23:59:59.999999 UTC'") // end julian->gregorian switch + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1582-10-15 00:00:00.000000 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-15 00:00:00.000000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.000000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.1 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.100000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.9 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.900000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.123 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.123000 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.999000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.123456 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123456 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1986-01-01 00:13:07.000000 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1986-01-01 00:13:07.000000 UTC'") // time gap in Kathmandu (long timestamp_tz) + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '2018-10-28 01:33:17.456789 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2018-10-28 01:33:17.456789 UTC'") // time doubled in JVM + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '2018-10-28 03:33:33.333333 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2018-10-28 03:33:33.333333 UTC'") // time doubled in Vilnius + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '2018-03-25 03:17:17.000000 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2018-03-25 03:17:17.000000 UTC'") // time gap in Vilnius + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '2020-09-27 12:34:56.1 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.100000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '2020-09-27 12:34:56.9 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.900000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '2020-09-27 12:34:56.123 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.123000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '2020-09-27 12:34:56.123000 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.123000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '2020-09-27 12:34:56.999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.999000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '2020-09-27 12:34:56.123456 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.123456 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '73326-09-11 20:14:45.247999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '73326-09-11 20:14:45.247999 UTC'") // max value in Trino + 
.addRoundTrip("timestamp(6) with time zone", "NULL", TIMESTAMP_TZ_MICROS, "CAST(NULL AS timestamp(6) with time zone)") + .execute(getQueryRunner(), session, trinoCreateAsSelect(getSession(), "test_timestamp_tz")); + + redshiftTimestampWithTimeZoneTests("timestamptz") + .execute(getQueryRunner(), session, redshiftCreateAndInsert("test_timestamp_tz")); + redshiftTimestampWithTimeZoneTests("timestamp with time zone") + .execute(getQueryRunner(), session, redshiftCreateAndInsert("test_timestamp_tz")); + } + + private static SqlDataTypeTest redshiftTimestampWithTimeZoneTests(String inputType) + { + return SqlDataTypeTest.create() + // .addRoundTrip(inputType, "TIMESTAMP '4713-01-01 00:00:00 BC' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '-4712-01-01 00:00:00.000000 UTC'") // min value in Redshift + // .addRoundTrip(inputType, "TIMESTAMP '0001-01-01 00:00:00' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '0001-01-01 00:00:00.000000 UTC'") // first day of AD + .addRoundTrip(inputType, "TIMESTAMP '1582-10-04 23:59:59.999999' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-04 23:59:59.999999 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '1582-10-05 00:00:00.000000' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-05 00:00:00.000000 UTC'") // begin julian->gregorian switch + .addRoundTrip(inputType, "TIMESTAMP '1582-10-14 23:59:59.999999' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-14 23:59:59.999999 UTC'") // end julian->gregorian switch + .addRoundTrip(inputType, "TIMESTAMP '1582-10-15 00:00:00.000000' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-15 00:00:00.000000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '1970-01-01 00:00:00' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.000000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '1970-01-01 00:00:00.1' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.100000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '1970-01-01 00:00:00.9' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.900000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '1970-01-01 00:00:00.123' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '1970-01-01 00:00:00.123000' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '1970-01-01 00:00:00.999' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.999000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '1970-01-01 00:00:00.123456' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123456 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '1986-01-01 00:13:07.000000' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1986-01-01 00:13:07.000000 UTC'") // time gap in Kathmandu + .addRoundTrip(inputType, "TIMESTAMP '2018-10-28 01:33:17.456789' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2018-10-28 01:33:17.456789 UTC'") // time doubled in JVM + .addRoundTrip(inputType, "TIMESTAMP '2018-10-28 03:33:33.333333' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2018-10-28 03:33:33.333333 UTC'") // time doubled in Vilnius + .addRoundTrip(inputType, "TIMESTAMP '2018-03-25 03:17:17.000000' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2018-03-25 03:17:17.000000 UTC'") // time gap in Vilnius + .addRoundTrip(inputType, "TIMESTAMP '2020-09-27 12:34:56.1' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, 
"TIMESTAMP '2020-09-27 12:34:56.100000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '2020-09-27 12:34:56.9' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.900000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '2020-09-27 12:34:56.123' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.123000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '2020-09-27 12:34:56.123000' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.123000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '2020-09-27 12:34:56.999' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.999000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '2020-09-27 12:34:56.123456' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.123456 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '73326-09-11 20:14:45.247999' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '73326-09-11 20:14:45.247999 UTC'"); // max value in Trino + } + + @Test + public void testTimestampWithTimeZoneCoercion() + { + SqlDataTypeTest.create() + // short timestamp with time zone + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00.12341 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123000 UTC'") // round down + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00.123499 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123000 UTC'") // round up, end result rounds down + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00.1235 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.124000 UTC'") // round up + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00.111222333444 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.111000 UTC'") // max precision + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00.9999995 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:01.000000 UTC'") // round up to next second + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 23:59:59.9999995 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-02 00:00:00.000000 UTC'") // round up to next day + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1969-12-31 23:59:59.9999995 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.000000 UTC'") // negative epoch + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1969-12-31 23:59:59.999499999999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1969-12-31 23:59:59.999000 UTC'") // negative epoch + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1969-12-31 23:59:59.9994 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1969-12-31 23:59:59.999000 UTC'") // negative epoch + + // long timestamp with time zone + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.1234561 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123456 UTC'") // round down + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.123456499 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123456 UTC'") // nanoc round up, end result rounds down + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.1234565 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123457 UTC'") // round up + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.111222333444 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.111222 UTC'") // max precision + 
.addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.9999995 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:01.000000 UTC'") // round up to next second + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 23:59:59.9999995 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-02 00:00:00.000000 UTC'") // round up to next day + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1969-12-31 23:59:59.9999995 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.000000 UTC'") // negative epoch + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1969-12-31 23:59:59.999999499999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1969-12-31 23:59:59.999999 UTC'") // negative epoch + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1969-12-31 23:59:59.9999994 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1969-12-31 23:59:59.999999 UTC'") // negative epoch + .execute(getQueryRunner(), trinoCreateAsSelect(getSession(), "test_timestamp_tz")); + } + + @Test + public void testTimestampWithTimeZoneOverflow() + { + // The min timestamp with time zone value in Trino is smaller than Redshift + try (TestTable table = new TestTable(getTrinoExecutor(), "timestamp_tz_min", "(ts timestamp(3) with time zone)")) { + assertQueryFails( + format("INSERT INTO %s VALUES (TIMESTAMP '-69387-04-22 03:45:14.752 UTC')", table.getName()), + "\\QMinimum timestamp with time zone in Redshift is -4712-01-01 00:00:00.000000: -69387-04-22 03:45:14.752000"); + } + try (TestTable table = new TestTable(getTrinoExecutor(), "timestamp_tz_min", "(ts timestamp(6) with time zone)")) { + assertQueryFails( + format("INSERT INTO %s VALUES (TIMESTAMP '-69387-04-22 03:45:14.752000 UTC')", table.getName()), + "\\QMinimum timestamp with time zone in Redshift is -4712-01-01 00:00:00.000000: -69387-04-22 03:45:14.752000"); + } + + // The max timestamp with time zone value in Redshift is larger than Trino + try (TestTable table = new TestTable(getRedshiftExecutor(), TEST_SCHEMA + ".timestamp_tz_max", "(ts timestamptz)", ImmutableList.of("TIMESTAMP '294276-12-31 23:59:59' AT TIME ZONE 'UTC'"))) { + assertThatThrownBy(() -> query("SELECT * FROM " + table.getName())) + .hasMessage("Millis overflow: 9224318015999000"); + } + } + + @DataProvider(name = "datetime_test_parameters") + public Object[][] dataProviderForDatetimeTests() + { + return new Object[][] { + {UTC}, + {jvmZone}, + {vilnius}, + {kathmandu}, + {testZone}, + }; + } + + @Test + public void testUnsupportedDateTimeTypes() + { + assertCreateFails( + "test_time_with_time_zone", + "(value TIME WITH TIME ZONE)", + "Unsupported column type: (?i)time.* with time zone"); + } + + @Test + public void testDateLimits() + { + // We can't test the exact date limits because Redshift doesn't say + // what they are, so we test one date on either side. 
+ try (TestTable table = testTable("test_date_limits", "(d date)")) { + // First day of smallest year that Redshift supports (based on its documentation) + assertUpdate(format("INSERT INTO %s VALUES (DATE '-4712-01-01')", table.getName()), 1); + // Small date observed to not work + assertThatThrownBy(() -> computeActual(format("INSERT INTO %s VALUES (DATE '-4713-06-01')", table.getName()))) + .hasStackTraceContaining("ERROR: date out of range: \"4714-06-01 BC\""); + + // Last day of the largest year that Redshift supports (based on in its documentation) + assertUpdate(format("INSERT INTO %s VALUES (DATE '294275-12-31')", table.getName()), 1); + // Large date observed to not work + assertThatThrownBy(() -> computeActual(format("INSERT INTO %s VALUES (DATE '5875000-01-01')", table.getName()))) + .hasStackTraceContaining("ERROR: date out of range: \"5875000-01-01 AD\""); + } + } + + @Test + public void testLimitedTimePrecision() + { + Map> testCasesByPrecision = groupTestCasesByInput( + "TIME '\\d{2}:\\d{2}:\\d{2}(\\.\\d{1,12})?'", + input -> input.length() - "TIME '00:00:00'".length() - (input.contains(".") ? 1 : 0), + List.of( + // No rounding + new TestCase("TIME '00:00:00'", "TIME '00:00:00'"), + new TestCase("TIME '00:00:00.000000'", "TIME '00:00:00.000000'"), + new TestCase("TIME '00:00:00.123456'", "TIME '00:00:00.123456'"), + new TestCase("TIME '12:34:56'", "TIME '12:34:56'"), + new TestCase("TIME '12:34:56.123456'", "TIME '12:34:56.123456'"), + new TestCase("TIME '23:59:59'", "TIME '23:59:59'"), + new TestCase("TIME '23:59:59.9'", "TIME '23:59:59.9'"), + new TestCase("TIME '23:59:59.999'", "TIME '23:59:59.999'"), + new TestCase("TIME '23:59:59.999999'", "TIME '23:59:59.999999'"), + // round down + new TestCase("TIME '00:00:00.0000001'", "TIME '00:00:00.000000'"), + new TestCase("TIME '00:00:00.000000000001'", "TIME '00:00:00.000000'"), + new TestCase("TIME '12:34:56.1234561'", "TIME '12:34:56.123456'"), + // round down, maximal value + new TestCase("TIME '00:00:00.0000004'", "TIME '00:00:00.000000'"), + new TestCase("TIME '00:00:00.000000449'", "TIME '00:00:00.000000'"), + new TestCase("TIME '00:00:00.000000444449'", "TIME '00:00:00.000000'"), + // round up, minimal value + new TestCase("TIME '00:00:00.0000005'", "TIME '00:00:00.000001'"), + new TestCase("TIME '00:00:00.000000500'", "TIME '00:00:00.000001'"), + new TestCase("TIME '00:00:00.000000500000'", "TIME '00:00:00.000001'"), + // round up, maximal value + new TestCase("TIME '00:00:00.0000009'", "TIME '00:00:00.000001'"), + new TestCase("TIME '00:00:00.000000999'", "TIME '00:00:00.000001'"), + new TestCase("TIME '00:00:00.000000999999'", "TIME '00:00:00.000001'"), + // round up to next day, minimal value + new TestCase("TIME '23:59:59.9999995'", "TIME '00:00:00.000000'"), + new TestCase("TIME '23:59:59.999999500'", "TIME '00:00:00.000000'"), + new TestCase("TIME '23:59:59.999999500000'", "TIME '00:00:00.000000'"), + // round up to next day, maximal value + new TestCase("TIME '23:59:59.9999999'", "TIME '00:00:00.000000'"), + new TestCase("TIME '23:59:59.999999999'", "TIME '00:00:00.000000'"), + new TestCase("TIME '23:59:59.999999999999'", "TIME '00:00:00.000000'"), + // don't round to next day (round down near upper bound) + new TestCase("TIME '23:59:59.9999994'", "TIME '23:59:59.999999'"), + new TestCase("TIME '23:59:59.999999499'", "TIME '23:59:59.999999'"), + new TestCase("TIME '23:59:59.999999499999'", "TIME '23:59:59.999999'"))); + + for (Entry> entry : testCasesByPrecision.entrySet()) { + String tableName = 
format("test_time_precision_%d_%s", entry.getKey(), randomNameSuffix()); + runTestCases(tableName, entry.getValue()); + } + } + + @Test + public void testLimitedTimestampPrecision() + { + Map> testCasesByPrecision = groupTestCasesByInput( + "TIMESTAMP '\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}(\\.\\d{1,12})?'", + input -> input.length() - "TIMESTAMP '0000-00-00 00:00:00'".length() - (input.contains(".") ? 1 : 0), + // No rounding + new TestCase("TIMESTAMP '1970-01-01 00:00:00'", "TIMESTAMP '1970-01-01 00:00:00'"), + new TestCase("TIMESTAMP '2020-11-03 12:34:56'", "TIMESTAMP '2020-11-03 12:34:56'"), + new TestCase("TIMESTAMP '1969-12-31 00:00:00.000000'", "TIMESTAMP '1969-12-31 00:00:00.000000'"), + + new TestCase("TIMESTAMP '1970-01-01 00:00:00.123456'", "TIMESTAMP '1970-01-01 00:00:00.123456'"), + new TestCase("TIMESTAMP '2020-11-03 12:34:56.123456'", "TIMESTAMP '2020-11-03 12:34:56.123456'"), + new TestCase("TIMESTAMP '1969-12-31 23:59:59'", "TIMESTAMP '1969-12-31 23:59:59'"), + + new TestCase("TIMESTAMP '1970-01-01 23:59:59.9'", "TIMESTAMP '1970-01-01 23:59:59.9'"), + new TestCase("TIMESTAMP '2020-11-03 23:59:59.999'", "TIMESTAMP '2020-11-03 23:59:59.999'"), + new TestCase("TIMESTAMP '1969-12-31 23:59:59.999999'", "TIMESTAMP '1969-12-31 23:59:59.999999'"), + // round down + new TestCase("TIMESTAMP '1969-12-31 00:00:00.0000001'", "TIMESTAMP '1969-12-31 00:00:00.000000'"), + new TestCase("TIMESTAMP '1970-01-01 00:00:00.000000000001'", "TIMESTAMP '1970-01-01 00:00:00.000000'"), + new TestCase("TIMESTAMP '2020-11-03 12:34:56.1234561'", "TIMESTAMP '2020-11-03 12:34:56.123456'"), + // round down, maximal value + new TestCase("TIMESTAMP '2020-11-03 00:00:00.0000004'", "TIMESTAMP '2020-11-03 00:00:00.000000'"), + new TestCase("TIMESTAMP '1969-12-31 00:00:00.000000449'", "TIMESTAMP '1969-12-31 00:00:00.000000'"), + new TestCase("TIMESTAMP '1970-01-01 00:00:00.000000444449'", "TIMESTAMP '1970-01-01 00:00:00.000000'"), + // round up, minimal value + new TestCase("TIMESTAMP '1970-01-01 00:00:00.0000005'", "TIMESTAMP '1970-01-01 00:00:00.000001'"), + new TestCase("TIMESTAMP '2020-11-03 00:00:00.000000500'", "TIMESTAMP '2020-11-03 00:00:00.000001'"), + new TestCase("TIMESTAMP '1969-12-31 00:00:00.000000500000'", "TIMESTAMP '1969-12-31 00:00:00.000001'"), + // round up, maximal value + new TestCase("TIMESTAMP '1969-12-31 00:00:00.0000009'", "TIMESTAMP '1969-12-31 00:00:00.000001'"), + new TestCase("TIMESTAMP '1970-01-01 00:00:00.000000999'", "TIMESTAMP '1970-01-01 00:00:00.000001'"), + new TestCase("TIMESTAMP '2020-11-03 00:00:00.000000999999'", "TIMESTAMP '2020-11-03 00:00:00.000001'"), + // round up to next year, minimal value + new TestCase("TIMESTAMP '2020-12-31 23:59:59.9999995'", "TIMESTAMP '2021-01-01 00:00:00.000000'"), + new TestCase("TIMESTAMP '1969-12-31 23:59:59.999999500'", "TIMESTAMP '1970-01-01 00:00:00.000000'"), + new TestCase("TIMESTAMP '1970-01-01 23:59:59.999999500000'", "TIMESTAMP '1970-01-02 00:00:00.000000'"), + // round up to next day/year, maximal value + new TestCase("TIMESTAMP '1970-01-01 23:59:59.9999999'", "TIMESTAMP '1970-01-02 00:00:00.000000'"), + new TestCase("TIMESTAMP '2020-12-31 23:59:59.999999999'", "TIMESTAMP '2021-01-01 00:00:00.000000'"), + new TestCase("TIMESTAMP '1969-12-31 23:59:59.999999999999'", "TIMESTAMP '1970-01-01 00:00:00.000000'"), + // don't round to next year (round down near upper bound) + new TestCase("TIMESTAMP '1969-12-31 23:59:59.9999994'", "TIMESTAMP '1969-12-31 23:59:59.999999'"), + new TestCase("TIMESTAMP '1970-01-01 23:59:59.999999499'", 
"TIMESTAMP '1970-01-01 23:59:59.999999'"), + new TestCase("TIMESTAMP '2020-12-31 23:59:59.999999499999'", "TIMESTAMP '2020-12-31 23:59:59.999999'")); + + for (Entry> entry : testCasesByPrecision.entrySet()) { + String tableName = format("test_timestamp_precision_%d_%s", entry.getKey(), randomNameSuffix()); + runTestCases(tableName, entry.getValue()); + } + } + + private static Map> groupTestCasesByInput(String inputRegex, Function classifier, TestCase... testCases) + { + return groupTestCasesByInput(inputRegex, classifier, Arrays.asList(testCases)); + } + + private static Map> groupTestCasesByInput(String inputRegex, Function classifier, List testCases) + { + return testCases.stream() + .peek(test -> { + if (!test.input().matches(inputRegex)) { + throw new RuntimeException("Bad test case input format: " + test.input()); + } + }) + .collect(groupingBy(classifier.compose(TestCase::input))); + } + + private void runTestCases(String tableName, List testCases) + { + // Must use CTAS instead of TestTable because if the table is created before the insert, + // the type mapping will treat it as TIME(6) no matter what it was created as. + getTrinoExecutor().execute(format( + "CREATE TABLE %s AS SELECT * FROM (VALUES %s) AS t (id, value)", + tableName, + testCases.stream() + .map(testCase -> format("(%d, %s)", testCase.id(), testCase.input())) + .collect(joining("), (", "(", ")")))); + try { + assertQuery( + format("SELECT value FROM %s ORDER BY id", tableName), + testCases.stream() + .map(TestCase::expected) + .collect(joining("), (", "VALUES (", ")"))); + } + finally { + getTrinoExecutor().execute("DROP TABLE " + tableName); + } + } + + @Test + public static void checkIllegalRedshiftTimePrecision() + { + assertRedshiftCreateFails( + "check_redshift_time_precision_error", + "(t TIME(6))", + "ERROR: time column does not support precision."); + } + + @Test + public static void checkIllegalRedshiftTimestampPrecision() + { + assertRedshiftCreateFails( + "check_redshift_timestamp_precision_error", + "(t TIMESTAMP(6))", + "ERROR: timestamp column does not support precision."); + } + + /** + * Assert that a {@code CREATE TABLE} statement made from Redshift fails, + * and drop the table if it doesn't fail. + */ + private static void assertRedshiftCreateFails(String tableNamePrefix, String tableBody, String message) + { + String tableName = tableNamePrefix + "_" + randomNameSuffix(); + try { + assertThatThrownBy(() -> getRedshiftExecutor() + .execute(format("CREATE TABLE %s %s", tableName, tableBody))) + .getCause() + .as("Redshift create fails for %s %s", tableName, tableBody) + .isInstanceOf(SQLException.class) + .hasMessage(message); + } + catch (AssertionError failure) { + // If the table was created, clean it up because the tests run on a shared Redshift instance + try { + getRedshiftExecutor().execute("DROP TABLE IF EXISTS " + tableName); + } + catch (Throwable e) { + failure.addSuppressed(e); + } + throw failure; + } + } + + /** + * Assert that a {@code CREATE TABLE} statement fails, and drop the table + * if it doesn't fail. 
+ */ + private void assertCreateFails(String tableNamePrefix, String tableBody, String expectedMessageRegExp) + { + String tableName = tableNamePrefix + "_" + randomNameSuffix(); + try { + assertQueryFails(format("CREATE TABLE %s %s", tableName, tableBody), expectedMessageRegExp); + } + catch (AssertionError failure) { + // If the table was created, clean it up because the tests run on a shared Redshift instance + try { + getRedshiftExecutor().execute("DROP TABLE " + tableName); + } + catch (Throwable e) { + failure.addSuppressed(e); + } + throw failure; + } + } + + private DataSetup trinoCreateAsSelect(String tableNamePrefix) + { + return trinoCreateAsSelect(getQueryRunner().getDefaultSession(), tableNamePrefix); + } + + private DataSetup trinoCreateAsSelect(Session session, String tableNamePrefix) + { + return new CreateAsSelectDataSetup(new TrinoSqlExecutor(getQueryRunner(), session), tableNamePrefix); + } + + private static DataSetup redshiftCreateAndInsert(String tableNamePrefix) + { + return new CreateAndInsertDataSetup(getRedshiftExecutor(), TEST_SCHEMA + "." + tableNamePrefix); + } + + /** + * Create a table in the test schema using the JDBC. + * + *

<p>Creating a test table normally doesn't use the correct schema. + */ + private static TestTable testTable(String namePrefix, String body) + { + return new TestTable(getRedshiftExecutor(), TEST_SCHEMA + "." + namePrefix, body); + } + + private SqlExecutor getTrinoExecutor() + { + return new TrinoSqlExecutor(getQueryRunner()); + } + + private static SqlExecutor getRedshiftExecutor() + { + Properties properties = new Properties(); + properties.setProperty("user", JDBC_USER); + properties.setProperty("password", JDBC_PASSWORD); + return new JdbcSqlExecutor(JDBC_URL, properties); + } + + private static void checkIsGap(ZoneId zone, LocalDateTime dateTime) + { + verify( + zone.getRules().getValidOffsets(dateTime).isEmpty(), + "Expected %s to be a gap in %s", dateTime, zone); + } + + private static void checkIsDoubled(ZoneId zone, LocalDateTime dateTime) + { + verify( + zone.getRules().getValidOffsets(dateTime).size() == 2, + "Expected %s to be doubled in %s", dateTime, zone); + } + + private static Function<String, String> padVarchar(int length) + { + // Add the same padding as RedshiftClient.writeCharAsVarchar, but start from String, not Slice + return (input) -> input + " ".repeat(length - Utf8.encodedLength(input)); + } + + /** + * A pair of input and expected output from a test. + * Each instance has a unique ID. + */ + private static class TestCase + { + private static final AtomicInteger LAST_ID = new AtomicInteger(); + + private final int id; + private final String input; + private final String expected; + + private TestCase(String input, String expected) + { + this.id = LAST_ID.incrementAndGet(); + this.input = input; + this.expected = expected; + } + + public int id() + { + return this.id; + } + + public String input() + { + return this.input; + } + + public String expected() + { + return this.expected; + } + } + + private static class TestView + implements AutoCloseable + { + final String name; + + TestView(String namePrefix, String definition) + { + name = requireNonNull(namePrefix) + "_" + randomNameSuffix(); + executeInRedshift(format("CREATE VIEW %s.%s AS %s", TEST_SCHEMA, name, definition)); + } + + @Override + public void close() + { + executeInRedshift(format("DROP VIEW IF EXISTS %s.%s", TEST_SCHEMA, name)); + } + } +}
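For context on the precision bucketing used by testLimitedTimePrecision and testLimitedTimestampPrecision: the classifier lambda derives the fractional-second precision of a literal purely from its string length, and the test cases are then grouped so each CTAS table holds literals of a single precision. A minimal standalone sketch of that length arithmetic (class and variable names here are illustrative only, not part of this change):

```
import java.util.function.ToIntFunction;

public class PrecisionClassifierSketch
{
    public static void main(String[] args)
    {
        // Same arithmetic as the classifier lambda passed to groupTestCasesByInput:
        // subtract the length of a precision-0 literal, and one more character for the dot if present.
        ToIntFunction<String> timePrecision = input ->
                input.length() - "TIME '00:00:00'".length() - (input.contains(".") ? 1 : 0);

        System.out.println(timePrecision.applyAsInt("TIME '23:59:59'"));        // 0
        System.out.println(timePrecision.applyAsInt("TIME '23:59:59.999'"));    // 3
        System.out.println(timePrecision.applyAsInt("TIME '00:00:00.000000'")); // 6
    }
}
```

Keeping each bucket to a single precision appears to matter because the VALUES clause in runTestCases determines the column type of the CTAS table, so mixed precisions in one table would change the rounding being asserted.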