Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More precise typechecking using vendor types #2023

Merged
merged 1 commit into from
Apr 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions docker-compose.yml
Original file line number Diff line number Diff line change
@@ -1,14 +1,32 @@
version: '3.1'

services:

postgres:
image: postgis/postgis:11-3.3
image: postgis/postgis:16-3.4
environment:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: password
POSTGRES_DB: world
ports:
- 5432:5432
volumes:
- ./init/:/docker-entrypoint-initdb.d/
- ./init/postgres/:/docker-entrypoint-initdb.d/
deploy:
resources:
limits:
memory: 500M


mysql:
image: mysql:8.0-debian
environment:
MYSQL_ROOT_PASSWORD: password
MYSQL_DATABASE: world
ports:
- 3306:3306
volumes:
- ./init/mysql/:/docker-entrypoint-initdb.d/
deploy:
resources:
limits:
memory: 500M
11 changes: 11 additions & 0 deletions init/mysql/test-table.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@

CREATE TABLE IF NOT EXISTS test (
c_integer INTEGER NOT NULL,
c_varchar VARCHAR(1024) NOT NULL,
c_date DATE NOT NULL,
c_datetime DATETIME(6) NOT NULL,
c_time TIME(6) NOT NULL,
c_timestamp TIMESTAMP(6) NOT NULL
);
INSERT INTO test(c_integer, c_varchar, c_date, c_datetime, c_time, c_timestamp)
VALUES (123, 'str', '2019-02-13', '2019-02-13 22:03:21.051', '22:03:21.051', '2019-02-13 22:03:21.051');
File renamed without changes.
51 changes: 29 additions & 22 deletions modules/core/src/main/scala/doobie/hi/connection.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,28 @@

package doobie.hi

import doobie.util.compat.propertiesToScala
import cats.Foldable
import cats.data.Ior
import cats.effect.kernel.syntax.monadCancel._
import cats.syntax.all._
import doobie.enumerated.AutoGeneratedKeys
import doobie.enumerated.Holdability
import doobie.enumerated.ResultSetType
import doobie.enumerated.Nullability
import doobie.enumerated.ResultSetConcurrency
import doobie.enumerated.ResultSetType
import doobie.enumerated.TransactionIsolation
import doobie.enumerated.AutoGeneratedKeys
import doobie.util.{ Read, Write }
import doobie.util.analysis.Analysis
import doobie.util.analysis.ColumnMeta
import doobie.util.analysis.ParameterMeta
import doobie.util.compat.propertiesToScala
import doobie.util.stream.repeatEvalChunks
import doobie.util.{ Get, Put, Read, Write }
import fs2.Stream
import fs2.Stream.{ eval, bracket }

import java.sql.{ Savepoint, PreparedStatement, ResultSet }

import scala.collection.immutable.Map

import cats.Foldable
import cats.syntax.all._
import cats.effect.kernel.syntax.monadCancel._
import fs2.Stream
import fs2.Stream.{ eval, bracket }

/**
* Module of high-level constructors for `ConnectionIO` actions.
* @group Modules
Expand Down Expand Up @@ -92,24 +94,29 @@ object connection {
* readable resultset row type `B`.
*/
def prepareQueryAnalysis[A: Write, B: Read](sql: String): ConnectionIO[Analysis] =
prepareStatement(sql) {
(HPS.getParameterMappings[A], HPS.getColumnMappings[B]) mapN (Analysis(sql, _, _))
}
prepareAnalysis(sql, HPS.getParameterMappings[A], HPS.getColumnMappings[B])

def prepareQueryAnalysis0[B: Read](sql: String): ConnectionIO[Analysis] =
prepareStatement(sql) {
HPS.getColumnMappings[B] map (cm => Analysis(sql, Nil, cm))
}
prepareAnalysis(sql, FPS.pure(Nil), HPS.getColumnMappings[B])

def prepareUpdateAnalysis[A: Write](sql: String): ConnectionIO[Analysis] =
prepareStatement(sql) {
HPS.getParameterMappings[A] map (pm => Analysis(sql, pm, Nil))
}
prepareAnalysis(sql, HPS.getParameterMappings[A], FPS.pure(Nil))

def prepareUpdateAnalysis0(sql: String): ConnectionIO[Analysis] =
prepareStatement(sql) {
Analysis(sql, Nil, Nil).pure[PreparedStatementIO]
prepareAnalysis(sql, FPS.pure(Nil), FPS.pure(Nil))

private def prepareAnalysis(
sql: String,
params: PreparedStatementIO[List[(Put[_], Nullability.NullabilityKnown) Ior ParameterMeta]],
columns: PreparedStatementIO[List[(Get[_], Nullability.NullabilityKnown) Ior ColumnMeta]],
) = {
val mappings = prepareStatement(sql) {
(params, columns).tupled
}
(HC.getMetaData(FDMD.getDriverName), mappings).mapN { case (driver, (p, c)) =>
Analysis(driver, sql, p, c)
}
}


/** @group Statements */
Expand Down
16 changes: 16 additions & 0 deletions modules/core/src/main/scala/doobie/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
// This software is licensed under the MIT License (MIT).
// For more information see LICENSE or https://opensource.org/licenses/MIT

import doobie.util.meta.{LegacyMeta, TimeMetaInstances}
// Copyright (c) 2013-2020 Rob Norris and Contributors
// This software is licensed under the MIT License (MIT).
// For more information see LICENSE or https://opensource.org/licenses/MIT

/**
* Top-level import, providing aliases for the most commonly used types and modules from
* doobie-free and doobie-core. A typical starting set of imports would be something like this.
Expand All @@ -21,13 +26,24 @@ package object doobie
object implicits
extends free.Instances
with generic.AutoDerivation
with LegacyMeta
with syntax.AllSyntax {

// re-export these instances so `Meta` takes priority, must be in the object
implicit def metaProjectionGet[A](implicit m: Meta[A]): Get[A] = Get.metaProjection
implicit def metaProjectionPut[A](implicit m: Meta[A]): Put[A] = Put.metaProjectionWrite
implicit def fromGetRead[A](implicit G: Get[A]): Read[A] = Read.fromGet
implicit def fromPutWrite[A](implicit P: Put[A]): Write[A] = Write.fromPut

/**
* Only use this import if:
* 1. You're NOT using one of the database doobie has direct java.time isntances for
* (PostgreSQL / MySQL). (They have more accurate column type checks)
* 2. Your driver natively supports java.time.* types
*
* If your driver doesn't support java.time.* types, use [[doobie.implicits.legacy.instant/localdate]] instead
*/
object javatimedrivernative extends TimeMetaInstances
}

}
53 changes: 22 additions & 31 deletions modules/core/src/main/scala/doobie/util/analysis.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,28 +20,9 @@ object analysis {

/** Metadata for the JDBC end of a column/parameter mapping. */
final case class ColumnMeta(jdbcType: JdbcType, vendorTypeName: String, nullability: Nullability, name: String)
object ColumnMeta {
def apply(jdbcType: JdbcType, vendorTypeName: String, nullability: Nullability, name: String): ColumnMeta = {
new ColumnMeta(tweakJdbcType(jdbcType, vendorTypeName), vendorTypeName, nullability, name)
}
}

/** Metadata for the JDBC end of a column/parameter mapping. */
final case class ParameterMeta(jdbcType: JdbcType, vendorTypeName: String, nullability: Nullability, mode: ParameterMode)
object ParameterMeta {
def apply(jdbcType: JdbcType, vendorTypeName: String, nullability: Nullability, mode: ParameterMode): ParameterMeta = {
new ParameterMeta(tweakJdbcType(jdbcType, vendorTypeName), vendorTypeName, nullability, mode)
}
}

private def tweakJdbcType(jdbcType: JdbcType, vendorTypeName: String) = jdbcType match {
// the Postgres driver does not return *WithTimezone types but they are pretty much required for proper analysis
// https://github.com/pgjdbc/pgjdbc/issues/2485
// https://github.com/pgjdbc/pgjdbc/issues/1766
case JdbcType.Time if vendorTypeName.compareToIgnoreCase("timetz") == 0 => JdbcType.TimeWithTimezone
case JdbcType.Timestamp if vendorTypeName.compareToIgnoreCase("timestamptz") == 0 => JdbcType.TimestampWithTimezone
case t => t
}

sealed trait AlignmentError extends Product with Serializable {
def tag: String
Expand Down Expand Up @@ -100,10 +81,10 @@ object analysis {
override val tag = "C"
override def msg =
s"""|${schema.jdbcType.show.toUpperCase} (${schema.vendorTypeName}) is not
|coercible to ${typeName(get.typeStack.last, n)} according to the JDBC specification or any defined
|coercible to ${typeName(get.typeStack.last, n)} (${get.vendorTypeNames.mkString(",")}) according to the JDBC specification or any defined
|mapping.
|Fix this by changing the schema type to
|${get.jdbcSources.toList.map(_.show.toUpperCase).toList.mkString(" or ") }; or the
|${get.jdbcSources.toList.map(_.show.toUpperCase).mkString(" or ") }; or the
|Scala type to an appropriate ${if (schema.jdbcType === JdbcType.Array) "array" else "object"}
|type.
|""".stripMargin.linesIterator.mkString(" ")
Expand All @@ -122,34 +103,46 @@ object analysis {

/** Compatibility analysis for the given statement and aligned mappings. */
final case class Analysis(
driver: String,
sql: String,
parameterAlignment: List[(Put[_], NullabilityKnown) Ior ParameterMeta],
columnAlignment: List[(Get[_], NullabilityKnown) Ior ColumnMeta]) {
columnAlignment: List[(Get[_], NullabilityKnown) Ior ColumnMeta]
) {

def parameterMisalignments: List[ParameterMisalignment] =
parameterAlignment.zipWithIndex.collect {
case (Ior.Left(_), n) => ParameterMisalignment(n + 1, None)
case (Ior.Right(p), n) => ParameterMisalignment(n + 1, Some(p))
}

private def hasParameterTypeErrors[A](put: Put[A], paramMeta: ParameterMeta): Boolean = {
val jdbcTypeMatches = put.jdbcTargets.contains_(paramMeta.jdbcType)
val vendorTypeMatches = put.vendorTypeNames.isEmpty || put.vendorTypeNames.contains_(paramMeta.vendorTypeName)

!jdbcTypeMatches || !vendorTypeMatches
}

def parameterTypeErrors: List[ParameterTypeError] =
parameterAlignment.zipWithIndex.collect {
case (Ior.Both((j, n1), p), n) if !j.jdbcTargets.contains_(p.jdbcType) =>
ParameterTypeError(n + 1, j, n1, p.jdbcType, p.vendorTypeName)
case (Ior.Both((put, n1), paramMeta), n) if hasParameterTypeErrors(put, paramMeta)=>
ParameterTypeError(n + 1, put, n1, paramMeta.jdbcType, paramMeta.vendorTypeName)
}

def columnMisalignments: List[ColumnMisalignment] =
columnAlignment.zipWithIndex.collect {
case (Ior.Left(j), n) => ColumnMisalignment(n + 1, Left(j))
case (Ior.Right(p), n) => ColumnMisalignment(n + 1, Right(p))
}


private def hasColumnTypeError[A](get: Get[A], columnMeta: ColumnMeta): Boolean = {
val jdbcTypeMatches = (get.jdbcSources.toList ++ get.jdbcSourceSecondary).contains_(columnMeta.jdbcType)
val vendorTypeMatches = get.vendorTypeNames.isEmpty || get.vendorTypeNames.contains_(columnMeta.vendorTypeName)
!jdbcTypeMatches || !vendorTypeMatches
}
def columnTypeErrors: List[ColumnTypeError] =
columnAlignment.zipWithIndex.collect {
case (Ior.Both((j, n1), p), n) if !(j.jdbcSources.toList ++ j.jdbcSourceSecondary).contains_(p.jdbcType) =>
ColumnTypeError(n + 1, j, n1, p)
case (Ior.Both((j, n1), p), n) if (p.jdbcType === JdbcType.JavaObject || p.jdbcType === JdbcType.Other) && !j.schemaTypes.headOption.contains_(p.vendorTypeName) =>
ColumnTypeError(n + 1, j, n1, p)
case (Ior.Both((get, n1), p), n) if hasColumnTypeError(get, p) =>
ColumnTypeError(n + 1, get, n1, p)
}

def columnTypeWarnings: List[ColumnTypeWarning] =
Expand Down Expand Up @@ -224,6 +217,4 @@ object analysis {
case Nullable => "NULL"
case NullableUnknown => "NULL?"
}


}
Loading