From 16570473cd225513154585159f1bef63186c1bcd Mon Sep 17 00:00:00 2001 From: Sam Guymer Date: Tue, 27 Dec 2022 23:23:59 +1000 Subject: [PATCH] Improve MySQL Java time instances --- docker-compose.yml | 4 +- init/mysql/test-table.sql | 11 ++++ init/{ => postgres}/test-db.sql | 0 .../src/main/scala/doobie/hi/connection.scala | 51 +++++++++------- .../src/main/scala/doobie/util/analysis.scala | 61 ++++++++++--------- .../doobie/mysql/JavaTimeInstances.scala | 53 +++++++++------- .../test/scala/doobie/mysql/CheckSuite.scala | 58 ++++++++++-------- .../doobie/mysql/MySQLTestTransactor.scala | 3 +- .../test/scala/doobie/mysql/TypesSuite.scala | 35 +++++------ .../util/arbitraries/TimeArbitraries.scala | 32 ++++------ .../test/scala/doobie/postgres/LOSuite.scala | 2 +- 11 files changed, 173 insertions(+), 137 deletions(-) create mode 100644 init/mysql/test-table.sql rename init/{ => postgres}/test-db.sql (100%) diff --git a/docker-compose.yml b/docker-compose.yml index 845bc1658..fdb91d089 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -11,7 +11,7 @@ services: ports: - 5432:5432 volumes: - - ./init/:/docker-entrypoint-initdb.d/ + - ./init/postgres/:/docker-entrypoint-initdb.d/ mysql: image: mysql:5.7-debian @@ -20,3 +20,5 @@ services: MYSQL_DATABASE: world ports: - 3306:3306 + volumes: + - ./init/mysql/:/docker-entrypoint-initdb.d/ diff --git a/init/mysql/test-table.sql b/init/mysql/test-table.sql new file mode 100644 index 000000000..8852905cf --- /dev/null +++ b/init/mysql/test-table.sql @@ -0,0 +1,11 @@ + +CREATE TABLE IF NOT EXISTS test ( + c_integer INTEGER NOT NULL, + c_varchar VARCHAR(1024) NOT NULL, + c_date DATE NOT NULL, + c_datetime DATETIME(6) NOT NULL, + c_time TIME(6) NOT NULL, + c_timestamp TIMESTAMP(6) NOT NULL +); +INSERT INTO test(c_integer, c_varchar, c_date, c_datetime, c_time, c_timestamp) +VALUES (123, 'str', '2019-02-13', '2019-02-13 22:03:21.051', '22:03:21.051', '2019-02-13 22:03:21.051'); diff --git a/init/test-db.sql b/init/postgres/test-db.sql similarity index 100% rename from init/test-db.sql rename to init/postgres/test-db.sql diff --git a/modules/core/src/main/scala/doobie/hi/connection.scala b/modules/core/src/main/scala/doobie/hi/connection.scala index 688e224c8..d4e2d5f5f 100644 --- a/modules/core/src/main/scala/doobie/hi/connection.scala +++ b/modules/core/src/main/scala/doobie/hi/connection.scala @@ -4,26 +4,28 @@ package doobie.hi -import doobie.util.compat.propertiesToScala +import cats.Foldable +import cats.data.Ior +import cats.effect.kernel.syntax.monadCancel._ +import cats.syntax.all._ +import doobie.enumerated.AutoGeneratedKeys import doobie.enumerated.Holdability -import doobie.enumerated.ResultSetType +import doobie.enumerated.Nullability import doobie.enumerated.ResultSetConcurrency +import doobie.enumerated.ResultSetType import doobie.enumerated.TransactionIsolation -import doobie.enumerated.AutoGeneratedKeys -import doobie.util.{ Read, Write } import doobie.util.analysis.Analysis +import doobie.util.analysis.ColumnMeta +import doobie.util.analysis.ParameterMeta +import doobie.util.compat.propertiesToScala import doobie.util.stream.repeatEvalChunks +import doobie.util.{ Get, Put, Read, Write } +import fs2.Stream +import fs2.Stream.{ eval, bracket } import java.sql.{ Savepoint, PreparedStatement, ResultSet } - import scala.collection.immutable.Map -import cats.Foldable -import cats.syntax.all._ -import cats.effect.kernel.syntax.monadCancel._ -import fs2.Stream -import fs2.Stream.{ eval, bracket } - /** * Module of high-level constructors for `ConnectionIO` actions. * @group Modules @@ -92,24 +94,29 @@ object connection { * readable resultset row type `B`. */ def prepareQueryAnalysis[A: Write, B: Read](sql: String): ConnectionIO[Analysis] = - prepareStatement(sql) { - (HPS.getParameterMappings[A], HPS.getColumnMappings[B]) mapN (Analysis(sql, _, _)) - } + prepareAnalysis(sql, HPS.getParameterMappings[A], HPS.getColumnMappings[B]) def prepareQueryAnalysis0[B: Read](sql: String): ConnectionIO[Analysis] = - prepareStatement(sql) { - HPS.getColumnMappings[B] map (cm => Analysis(sql, Nil, cm)) - } + prepareAnalysis(sql, FPS.pure(Nil), HPS.getColumnMappings[B]) def prepareUpdateAnalysis[A: Write](sql: String): ConnectionIO[Analysis] = - prepareStatement(sql) { - HPS.getParameterMappings[A] map (pm => Analysis(sql, pm, Nil)) - } + prepareAnalysis(sql, HPS.getParameterMappings[A], FPS.pure(Nil)) def prepareUpdateAnalysis0(sql: String): ConnectionIO[Analysis] = - prepareStatement(sql) { - Analysis(sql, Nil, Nil).pure[PreparedStatementIO] + prepareAnalysis(sql, FPS.pure(Nil), FPS.pure(Nil)) + + private def prepareAnalysis( + sql: String, + params: PreparedStatementIO[List[(Put[_], Nullability.NullabilityKnown) Ior ParameterMeta]], + columns: PreparedStatementIO[List[(Get[_], Nullability.NullabilityKnown) Ior ColumnMeta]], + ) = { + val mappings = prepareStatement(sql) { + (params, columns).tupled } + (HC.getMetaData(FDMD.getDriverName), mappings).mapN { case (driver, (p, c)) => + Analysis(driver, sql, p, c) + } + } /** @group Statements */ diff --git a/modules/core/src/main/scala/doobie/util/analysis.scala b/modules/core/src/main/scala/doobie/util/analysis.scala index e0d06696e..ff0524bd3 100644 --- a/modules/core/src/main/scala/doobie/util/analysis.scala +++ b/modules/core/src/main/scala/doobie/util/analysis.scala @@ -20,28 +20,9 @@ object analysis { /** Metadata for the JDBC end of a column/parameter mapping. */ final case class ColumnMeta(jdbcType: JdbcType, vendorTypeName: String, nullability: Nullability, name: String) - object ColumnMeta { - def apply(jdbcType: JdbcType, vendorTypeName: String, nullability: Nullability, name: String): ColumnMeta = { - new ColumnMeta(tweakJdbcType(jdbcType, vendorTypeName), vendorTypeName, nullability, name) - } - } /** Metadata for the JDBC end of a column/parameter mapping. */ final case class ParameterMeta(jdbcType: JdbcType, vendorTypeName: String, nullability: Nullability, mode: ParameterMode) - object ParameterMeta { - def apply(jdbcType: JdbcType, vendorTypeName: String, nullability: Nullability, mode: ParameterMode): ParameterMeta = { - new ParameterMeta(tweakJdbcType(jdbcType, vendorTypeName), vendorTypeName, nullability, mode) - } - } - - private def tweakJdbcType(jdbcType: JdbcType, vendorTypeName: String) = jdbcType match { - // the Postgres driver does not return *WithTimezone types but they are pretty much required for proper analysis - // https://github.com/pgjdbc/pgjdbc/issues/2485 - // https://github.com/pgjdbc/pgjdbc/issues/1766 - case JdbcType.Time if vendorTypeName.compareToIgnoreCase("timetz") == 0 => JdbcType.TimeWithTimezone - case JdbcType.Timestamp if vendorTypeName.compareToIgnoreCase("timestamptz") == 0 => JdbcType.TimestampWithTimezone - case t => t - } sealed trait AlignmentError extends Product with Serializable { def tag: String @@ -122,30 +103,39 @@ object analysis { /** Compatibility analysis for the given statement and aligned mappings. */ final case class Analysis( + driver: String, sql: String, parameterAlignment: List[(Put[_], NullabilityKnown) Ior ParameterMeta], - columnAlignment: List[(Get[_], NullabilityKnown) Ior ColumnMeta]) { + columnAlignment: List[(Get[_], NullabilityKnown) Ior ColumnMeta] + ) { + + private val parameterAlignment_ = parameterAlignment.map(_.map { m => + m.copy(jdbcType = tweakMetaJdbcType(driver, m.jdbcType, vendorTypeName = m.vendorTypeName)) + }) + private val columnAlignment_ = columnAlignment.map(_.map { m => + m.copy(jdbcType = tweakMetaJdbcType(driver, m.jdbcType, vendorTypeName = m.vendorTypeName)) + }) def parameterMisalignments: List[ParameterMisalignment] = - parameterAlignment.zipWithIndex.collect { + parameterAlignment_.zipWithIndex.collect { case (Ior.Left(_), n) => ParameterMisalignment(n + 1, None) case (Ior.Right(p), n) => ParameterMisalignment(n + 1, Some(p)) } def parameterTypeErrors: List[ParameterTypeError] = - parameterAlignment.zipWithIndex.collect { + parameterAlignment_.zipWithIndex.collect { case (Ior.Both((j, n1), p), n) if !j.jdbcTargets.contains_(p.jdbcType) => ParameterTypeError(n + 1, j, n1, p.jdbcType, p.vendorTypeName) } def columnMisalignments: List[ColumnMisalignment] = - columnAlignment.zipWithIndex.collect { + columnAlignment_.zipWithIndex.collect { case (Ior.Left(j), n) => ColumnMisalignment(n + 1, Left(j)) case (Ior.Right(p), n) => ColumnMisalignment(n + 1, Right(p)) } def columnTypeErrors: List[ColumnTypeError] = - columnAlignment.zipWithIndex.collect { + columnAlignment_.zipWithIndex.collect { case (Ior.Both((j, n1), p), n) if !(j.jdbcSources.toList ++ j.fold(_.jdbcSourceSecondary.toList, _ => Nil)).contains_(p.jdbcType) => ColumnTypeError(n + 1, j, n1, p) case (Ior.Both((j, n1), p), n) if (p.jdbcType === JdbcType.JavaObject || p.jdbcType === JdbcType.Other) && !j.fold(_ => None, a => Some(a.schemaTypes.head)).contains_(p.vendorTypeName) => @@ -153,13 +143,13 @@ object analysis { } def columnTypeWarnings: List[ColumnTypeWarning] = - columnAlignment.zipWithIndex.collect { + columnAlignment_.zipWithIndex.collect { case (Ior.Both((j, n1), p), n) if j.fold(_.jdbcSourceSecondary.toList, _ => Nil).contains_(p.jdbcType) => ColumnTypeWarning(n + 1, j, n1, p) } def nullabilityMisalignments: List[NullabilityMisalignment] = - columnAlignment.zipWithIndex.collect { + columnAlignment_.zipWithIndex.collect { // We can't do anything helpful with NoNulls .. it means "might not be nullable" // case (Ior.Both((st, Nullable), ColumnMeta(_, _, NoNulls, col)), n) => NullabilityMisalignment(n + 1, col, st, NoNulls, Nullable) case (Ior.Both((st, NoNulls), ColumnMeta(_, _, Nullable, col)), n) => NullabilityMisalignment(n + 1, col, st.typeStack.last, Nullable, NoNulls) @@ -179,7 +169,7 @@ object analysis { /** Description of each parameter, paired with its errors. */ lazy val paramDescriptions: List[(String, List[AlignmentError])] = { val params: Block = - parameterAlignment.zipWithIndex.map { + parameterAlignment_.zipWithIndex.map { case (Ior.Both((j1, n1), ParameterMeta(j2, s2, _, _)), i) => List(f"P${i+1}%02d", show"${typeName(j1.typeStack.last, n1)}", " → ", j2.show.toUpperCase, show"($s2)") case (Ior.Left((j1, n1)), i) => List(f"P${i+1}%02d", show"${typeName(j1.typeStack.last, n1)}", " → ", "", "") case (Ior.Right( ParameterMeta(j2, s2, _, _)), i) => List(f"P${i+1}%02d", "", " → ", j2.show.toUpperCase, show"($s2)") @@ -193,7 +183,7 @@ object analysis { lazy val columnDescriptions: List[(String, List[AlignmentError])] = { import pretty._ val cols: Block = - columnAlignment.zipWithIndex.map { + columnAlignment_.zipWithIndex.map { case (Ior.Both((j1, n1), ColumnMeta(j2, s2, n2, m)), i) => List(f"C${i+1}%02d", m, j2.show.toUpperCase, show"(${s2.toString})", formatNullability(n2), " → ", typeName(j1.typeStack.last, n1)) case (Ior.Left((j1, n1)), i) => List(f"C${i+1}%02d", "", "", "", "", " → ", typeName(j1.typeStack.last, n1)) case (Ior.Right( ColumnMeta(j2, s2, n2, m)), i) => List(f"C${i+1}%02d", m, j2.show.toUpperCase, show"(${s2.toString})", formatNullability(n2), " → ", "") @@ -225,5 +215,20 @@ object analysis { case NullableUnknown => "NULL?" } + private val MySQLDriverName = "MySQL Connector/J" + + // tweaks to the types returned by JDBC to improve analysis + private def tweakMetaJdbcType(driver: String, jdbcType: JdbcType, vendorTypeName: String) = jdbcType match { + // the Postgres driver does not return *WithTimezone JDBC types for *tz column types + // https://github.com/pgjdbc/pgjdbc/issues/2485 + // https://github.com/pgjdbc/pgjdbc/issues/1766 + case JdbcType.Time if vendorTypeName.compareToIgnoreCase("timetz") == 0 => JdbcType.TimeWithTimezone + case JdbcType.Timestamp if vendorTypeName.compareToIgnoreCase("timestamptz") == 0 => JdbcType.TimestampWithTimezone + + // MySQL timestamp columns are returned as Timestamp + case JdbcType.Timestamp + if vendorTypeName.compareToIgnoreCase("timestamp") == 0 && driver == MySQLDriverName => JdbcType.TimestampWithTimezone + case t => t + } } diff --git a/modules/mysql/src/main/scala/doobie/mysql/JavaTimeInstances.scala b/modules/mysql/src/main/scala/doobie/mysql/JavaTimeInstances.scala index c8c5c7b55..1a281531d 100644 --- a/modules/mysql/src/main/scala/doobie/mysql/JavaTimeInstances.scala +++ b/modules/mysql/src/main/scala/doobie/mysql/JavaTimeInstances.scala @@ -4,50 +4,61 @@ package doobie.mysql -import java.time.OffsetDateTime -import java.time.ZoneOffset - import doobie.Meta import doobie.enumerated.{JdbcType => JT} import doobie.util.meta.MetaConstructors +import java.time.Instant +import java.time.LocalDate +import java.time.LocalDateTime +import java.time.LocalTime +import java.time.OffsetDateTime +import java.time.OffsetTime +import java.time.ZoneOffset + +/** + * Instances for JSR-310 date time types. + * + * Note that to ensure instants are preserved you may need to use one of the solutions described + * in [[https://docs.oracle.com/cd/E17952_01/connector-j-8.0-en/connector-j-time-instants.html]]. + */ trait JavaTimeInstances extends MetaConstructors { - implicit val JavaTimeOffsetDateTimeMeta: Meta[java.time.OffsetDateTime] = + implicit val JavaTimeOffsetDateTimeMeta: Meta[OffsetDateTime] = Basic.oneObject( - JT.Timestamp, - List(JT.VarChar, JT.Date, JT.Time), - classOf[java.time.OffsetDateTime] + JT.TimestampWithTimezone, + List(JT.VarChar, JT.Date, JT.Time, JT.Timestamp), + classOf[OffsetDateTime] ) - implicit val JavaTimeInstantMeta: Meta[java.time.Instant] = + implicit val JavaTimeInstantMeta: Meta[Instant] = JavaTimeOffsetDateTimeMeta.timap(_.toInstant)(OffsetDateTime.ofInstant(_, ZoneOffset.UTC)) - implicit val JavaTimeLocalDateTimeMeta: Meta[java.time.LocalDateTime] = + implicit val JavaTimeLocalDateTimeMeta: Meta[LocalDateTime] = Basic.oneObject( JT.Timestamp, - List(JT.VarChar, JT.Date, JT.Time), - classOf[java.time.LocalDateTime] + List(JT.VarChar, JT.Date, JT.Time, JT.TimestampWithTimezone), + classOf[LocalDateTime] ) - implicit val JavaTimeLocalDateMeta: Meta[java.time.LocalDate] = + implicit val JavaTimeLocalDateMeta: Meta[LocalDate] = Basic.oneObject( JT.Date, - List(JT.VarChar, JT.Time, JT.Timestamp), - classOf[java.time.LocalDate] + List(JT.VarChar, JT.Time, JT.Timestamp, JT.TimestampWithTimezone), + classOf[LocalDate] ) - implicit val JavaTimeLocalTimeMeta: Meta[java.time.LocalTime] = + implicit val JavaTimeLocalTimeMeta: Meta[LocalTime] = Basic.oneObject( JT.Time, - List(JT.Date, JT.Timestamp), - classOf[java.time.LocalTime] + List(JT.VarChar, JT.Date, JT.Timestamp, JT.TimestampWithTimezone), + classOf[LocalTime] ) - implicit val JavaTimeOffsetTimeMeta: Meta[java.time.OffsetTime] = + implicit val JavaTimeOffsetTimeMeta: Meta[OffsetTime] = Basic.oneObject( - JT.Timestamp, - List(JT.Date, JT.Time), - classOf[java.time.OffsetTime] + JT.TimestampWithTimezone, + List(JT.VarChar, JT.Date, JT.Time, JT.Timestamp), + classOf[OffsetTime] ) } diff --git a/modules/mysql/src/test/scala/doobie/mysql/CheckSuite.scala b/modules/mysql/src/test/scala/doobie/mysql/CheckSuite.scala index 9c18f0028..f74323016 100644 --- a/modules/mysql/src/test/scala/doobie/mysql/CheckSuite.scala +++ b/modules/mysql/src/test/scala/doobie/mysql/CheckSuite.scala @@ -16,49 +16,57 @@ class CheckSuite extends munit.FunSuite { import cats.effect.unsafe.implicits.global import MySQLTestTransactor.xa + // note selecting from a table because a value cannot be cast to a timestamp + // and casting returns a nullable column + test("OffsetDateTime Read typechecks") { - successRead[Option[OffsetDateTime]](sql"SELECT CAST('2019-02-13 22:03:21.051' AS DATETIME)") + successRead[OffsetDateTime](sql"SELECT c_timestamp FROM test LIMIT 1") - warnRead[Option[OffsetDateTime]](sql"SELECT '2019-02-13 22:03:21.051'") - warnRead[Option[OffsetDateTime]](sql"SELECT CAST('03:21' AS TIME)") - warnRead[Option[OffsetDateTime]](sql"SELECT CAST('2019-02-13' AS DATE)") - failedRead[Option[OffsetDateTime]](sql"SELECT 123") + warnRead[OffsetDateTime](sql"SELECT '2019-02-13 22:03:21.051'") + warnRead[OffsetDateTime](sql"SELECT c_date FROM test LIMIT 1") + warnRead[OffsetDateTime](sql"SELECT c_time FROM test LIMIT 1") + warnRead[OffsetDateTime](sql"SELECT c_datetime FROM test LIMIT 1") + failedRead[OffsetDateTime](sql"SELECT c_integer FROM test LIMIT 1") } test("LocalDateTime Read typechecks") { - successRead[Option[LocalDateTime]](sql"SELECT CAST('2019-02-13 22:03:21.051' AS DATETIME)") + successRead[LocalDateTime](sql"SELECT c_datetime FROM test LIMIT 1") - warnRead[Option[LocalDateTime]](sql"SELECT '2019-02-13 22:03:21.051'") - warnRead[Option[LocalDateTime]](sql"SELECT CAST('03:21' AS TIME)") - warnRead[Option[LocalDateTime]](sql"SELECT CAST('2019-02-13' AS DATE)") - failedRead[Option[LocalDateTime]](sql"SELECT 123") + warnRead[LocalDateTime](sql"SELECT '2019-02-13 22:03:21.051'") + warnRead[LocalDateTime](sql"SELECT c_date FROM test LIMIT 1") + warnRead[LocalDateTime](sql"SELECT c_time FROM test LIMIT 1") + warnRead[LocalDateTime](sql"SELECT c_timestamp FROM test LIMIT 1") + failedRead[LocalDateTime](sql"SELECT 123") } test("LocalDate Read typechecks") { - successRead[Option[LocalDate]](sql"SELECT CAST('2015-02-23' AS DATE)") + successRead[LocalDate](sql"SELECT c_date FROM test LIMIT 1") - warnRead[Option[LocalDate]](sql"SELECT CAST('2019-02-13 22:03:21.051' AS DATETIME)") - warnRead[Option[LocalDate]](sql"SELECT CAST('03:21' AS TIME)") - warnRead[Option[LocalDate]](sql"SELECT '2015-02-23'") - failedRead[Option[LocalDate]](sql"SELECT 123") + warnRead[LocalDate](sql"SELECT '2019-02-13'") + warnRead[LocalDate](sql"SELECT c_time FROM test LIMIT 1") + warnRead[LocalDate](sql"SELECT c_datetime FROM test LIMIT 1") + warnRead[LocalDate](sql"SELECT c_timestamp FROM test LIMIT 1") + failedRead[LocalDate](sql"SELECT 123") } test("LocalTime Read typechecks") { - successRead[Option[LocalTime]](sql"SELECT CAST('03:21' AS TIME)") + successRead[LocalTime](sql"SELECT c_time FROM test LIMIT 1") - warnRead[Option[LocalTime]](sql"SELECT CAST('2019-02-13 22:03:21.051' AS DATETIME)") - warnRead[Option[LocalTime]](sql"SELECT CAST('2015-02-23' AS DATE)") - failedRead[Option[LocalTime]](sql"SELECT '03:21'") - failedRead[Option[LocalTime]](sql"SELECT 123") + warnRead[LocalTime](sql"SELECT c_date FROM test LIMIT 1") + warnRead[LocalTime](sql"SELECT c_datetime FROM test LIMIT 1") + warnRead[LocalTime](sql"SELECT c_timestamp FROM test LIMIT 1") + warnRead[LocalTime](sql"SELECT '22:03:21'") + failedRead[LocalTime](sql"SELECT 123") } test("OffsetTime Read typechecks") { - successRead[Option[OffsetTime]](sql"SELECT CAST('2019-02-13 22:03:21.051' AS DATETIME)") + successRead[OffsetTime](sql"SELECT c_timestamp FROM test LIMIT 1") - warnRead[Option[OffsetTime]](sql"SELECT CAST('03:21' AS TIME)") - warnRead[Option[OffsetTime]](sql"SELECT CAST('2015-02-23' AS DATE)") - failedRead[Option[OffsetTime]](sql"SELECT '03:21'") - failedRead[Option[OffsetTime]](sql"SELECT 123") + warnRead[OffsetTime](sql"SELECT '22:03:21'") + warnRead[OffsetTime](sql"SELECT c_date FROM test LIMIT 1") + warnRead[OffsetTime](sql"SELECT c_time FROM test LIMIT 1") + warnRead[OffsetTime](sql"SELECT c_datetime FROM test LIMIT 1") + failedRead[OffsetTime](sql"SELECT 123") } private def successRead[A: Read](frag: Fragment): Unit = { diff --git a/modules/mysql/src/test/scala/doobie/mysql/MySQLTestTransactor.scala b/modules/mysql/src/test/scala/doobie/mysql/MySQLTestTransactor.scala index 46cad1ccf..f7018a8c2 100644 --- a/modules/mysql/src/test/scala/doobie/mysql/MySQLTestTransactor.scala +++ b/modules/mysql/src/test/scala/doobie/mysql/MySQLTestTransactor.scala @@ -11,7 +11,8 @@ object MySQLTestTransactor { val xa = Transactor.fromDriverManager[IO]( "com.mysql.cj.jdbc.Driver", - "jdbc:mysql://localhost:3306/world", + // args from solution 2a https://docs.oracle.com/cd/E17952_01/connector-j-8.0-en/connector-j-time-instants.html + "jdbc:mysql://localhost:3306/world?preserveInstants=true&connectionTimeZone=SERVER", "root", "password" ) } diff --git a/modules/mysql/src/test/scala/doobie/mysql/TypesSuite.scala b/modules/mysql/src/test/scala/doobie/mysql/TypesSuite.scala index 97b159f57..6d025c4f3 100644 --- a/modules/mysql/src/test/scala/doobie/mysql/TypesSuite.scala +++ b/modules/mysql/src/test/scala/doobie/mysql/TypesSuite.scala @@ -13,7 +13,6 @@ import doobie.mysql.implicits._ import doobie.mysql.util.arbitraries.SQLArbitraries._ import doobie.mysql.util.arbitraries.TimeArbitraries._ import org.scalacheck.Arbitrary -import org.scalacheck.Gen import org.scalacheck.Prop.forAll class TypesSuite extends munit.ScalaCheckSuite { @@ -33,15 +32,17 @@ class TypesSuite extends munit.ScalaCheckSuite { a0 <- Query0[Option[A]](s"SELECT value FROM test", None).unique } yield a0 - def testInOut[A](col: String)(implicit m: Get[A], p: Put[A], arbitrary: Arbitrary[A]) = { - testInOutWithCustomGen(col, arbitrary.arbitrary) + private def testInOut[A](col: String)(implicit m: Get[A], p: Put[A], arbitrary: Arbitrary[A]): Unit = { + testInOutCustomize(col ) } - def testInOutNormalize[A](col: String)(f: A => A)(implicit m: Get[A], p: Put[A], arbitrary: Arbitrary[A]) = { - testInOutWithCustomGen(col, arbitrary.arbitrary, skipNone = false, f) - } + private def testInOutCustomize[A]( + col: String, + skipNone: Boolean = false, + expected: A => A = identity[A](_) + )(implicit m: Get[A], p: Put[A], arbitrary: Arbitrary[A]): Unit = { + val gen = arbitrary.arbitrary - def testInOutWithCustomGen[A](col: String, gen: Gen[A], skipNone: Boolean = false, expected: A => A = identity[A](_))(implicit m: Get[A], p: Put[A]) = { test(s"Mapping for $col as ${m.typeStack} - write+read $col as ${m.typeStack}") { forAll(gen) { (t: A) => val actual = inOut(col, t).transact(xa).attempt.unsafeRunSync() @@ -61,19 +62,19 @@ class TypesSuite extends munit.ScalaCheckSuite { } } - def skip(col: String, msg: String = "not yet implemented") = - test(s"Mapping for $col ($msg)".ignore) {} - - testInOut[java.sql.Timestamp]("datetime(6)") - testInOutNormalize[java.time.OffsetDateTime]("datetime(6)")(_.withOffsetSameInstant(ZoneOffset.UTC)) - testInOut[java.time.Instant]("datetime(6)") - testInOut[java.time.LocalDateTime]("datetime(6)") - testInOutWithCustomGen[java.time.LocalDateTime]( + testInOutCustomize[java.time.OffsetDateTime]( "timestamp(6)", - arbitraryLocalDateTimeTimestamp.arbitrary, - skipNone = true // returns the current timestamp, lol + skipNone = true, // returns the current timestamp, lol + _.withOffsetSameInstant(ZoneOffset.UTC) ) + testInOutCustomize[java.time.Instant]( + "timestamp(6)", + skipNone = true, // returns the current timestamp, lol + ) + + testInOut[java.sql.Timestamp]("datetime(6)") + testInOut[java.time.LocalDateTime]("datetime(6)") testInOut[java.sql.Date]("date") testInOut[java.time.LocalDate]("date") diff --git a/modules/mysql/src/test/scala/doobie/mysql/util/arbitraries/TimeArbitraries.scala b/modules/mysql/src/test/scala/doobie/mysql/util/arbitraries/TimeArbitraries.scala index 362d6b5dc..72f0f2ce5 100644 --- a/modules/mysql/src/test/scala/doobie/mysql/util/arbitraries/TimeArbitraries.scala +++ b/modules/mysql/src/test/scala/doobie/mysql/util/arbitraries/TimeArbitraries.scala @@ -13,10 +13,6 @@ import org.scalacheck.Gen // https://dev.mysql.com/doc/refman/5.7/en/datetime.html object TimeArbitraries { - // plus and minus 2 days to avoid dealing with offsets - val MinDate = LocalDate.of(1000, 1, 3) - val MaxDate = LocalDate.of(9999, 12, 29) - // max resolution is 1 microsecond private def micros(nanos: Long) = Math.floorDiv(nanos, 1000) @@ -25,8 +21,9 @@ object TimeArbitraries { override def compare(x: LocalDate, y: LocalDate): Int = x compareTo y } + // 1000-01-01 to 9999-12-31 implicit val arbitraryLocalDate: Arbitrary[LocalDate] = Arbitrary { - GenHelpers.chooseT(MinDate, MaxDate, LocalDate.of(1970, 1, 1)) + GenHelpers.chooseT(LocalDate.of(1000, 1, 1), LocalDate.of(9999, 12, 31), LocalDate.of(1970, 1, 1)) } // 00:00:00.000000 to 23:59:59.999999 @@ -40,32 +37,25 @@ object TimeArbitraries { // '1000-01-01 00:00:00.000000' to '9999-12-31 23:59:59.999999' implicit val arbitraryLocalDateTime: Arbitrary[LocalDateTime] = Arbitrary { for { - date <- GenHelpers.chooseT(MinDate, MaxDate) + date <- arbitraryLocalDate.arbitrary time <- arbitraryLocalTime.arbitrary } yield LocalDateTime.of(date, time) } // '1970-01-01 00:00:01.000000' to '2038-01-19 03:14:07.999999 - val arbitraryLocalDateTimeTimestamp: Arbitrary[LocalDateTime] = Arbitrary { - val min = LocalDate.of(1970, 1, 2) // avoid not starting at 0 seconds on the 1st - val max = LocalDate.of(2038, 1, 18) // avoid ending at 3am on the 19th - - for { - date <- GenHelpers.chooseT(min, max) - time <- arbitraryLocalTime.arbitrary - } yield LocalDateTime.of(date, time) - } - - implicit val arbitraryInstant: Arbitrary[Instant] = Arbitrary { - arbitraryLocalDateTime.arbitrary.map(_.toInstant(ZoneOffset.UTC)) + val min = 1 * 1000000L + 0 + val max = 2147483647 * 1000000L + 999999 + + Gen.chooseNum(min, max).map { micros => + Instant.ofEpochSecond(micros / 1000000, micros % 1000000 * 1000) + } } implicit val arbitraryOffsetDateTime: Arbitrary[OffsetDateTime] = Arbitrary { for { - dateTime <- arbitraryLocalDateTime.arbitrary + instant <- arbitraryInstant.arbitrary offset <- Arbitrary.arbitrary[ZoneOffset] - } yield dateTime.atOffset(offset) + } yield instant.atOffset(offset) } - } diff --git a/modules/postgres/src/test/scala/doobie/postgres/LOSuite.scala b/modules/postgres/src/test/scala/doobie/postgres/LOSuite.scala index 84a188a25..7ce763574 100644 --- a/modules/postgres/src/test/scala/doobie/postgres/LOSuite.scala +++ b/modules/postgres/src/test/scala/doobie/postgres/LOSuite.scala @@ -14,7 +14,7 @@ class LOSuite extends munit.FunSuite with FileEquality { import PostgresTestTransactor.xa // A big file. Contents are irrelevant. - val in = new File("init/test-db.sql") + val in = new File("init/postgres/test-db.sql") test("large object support should allow round-trip from file to large object and back") { val out = File.createTempFile("doobie", "tst")