Skip to content

Commit

Permalink
#483 Fix count() queries with MS SQL and other SQL dialects.
Browse files Browse the repository at this point in the history
  • Loading branch information
yruslan committed Sep 3, 2024
1 parent b13dd12 commit 2d4083f
Show file tree
Hide file tree
Showing 15 changed files with 156 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ trait SqlGenerator {
*/
def getCountQuery(tableName: String, infoDateBegin: LocalDate, infoDateEnd: LocalDate): String

/**
* Generates a query that returns the record count of an SQL query that is already formed.
*/
def getCountQueryForSql(filteredSql: String): String

/**
* Generates a query that returns data of a table that does not have the information date field.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,18 @@ class TableReaderJdbc(jdbcReaderConfig: TableReaderJdbcConfig,
)
}

private[core] def getCountForSql(sql: String, infoDateBegin: LocalDate, infoDateEnd: LocalDate): Long = {
private[core] def getCountSqlQuery(sql: String, infoDateBegin: LocalDate, infoDateEnd: LocalDate): String = {
val filteredSql = TableReaderJdbcNative.getFilteredSql(sql, infoDateBegin, infoDateEnd)
val countSql = s"SELECT COUNT(*) FROM ($filteredSql)"

sqlGen.getCountQueryForSql(filteredSql)
}

private[core] def getCountForSql(sql: String, infoDateBegin: LocalDate, infoDateEnd: LocalDate): Long = {
val countSql = getCountSqlQuery(sql, infoDateBegin, infoDateEnd)
var count = 0L

log.info(s"Executing: $countSql")

JdbcNativeUtils.withResultSet(jdbcUrlSelector, countSql, jdbcRetries) { rs =>
if (!rs.next())
throw new IllegalStateException(s"No rows returned by the count query: $countSql")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ class SqlGeneratorDb2(sqlConfig: SqlConfig) extends SqlGeneratorBase(sqlConfig)
s"SELECT ${getAliasExpression("COUNT(*)", "CNT")} FROM ${escape(tableName)} WHERE $where"
}

override def getCountQueryForSql(filteredSql: String): String = {
s"SELECT COUNT(*) FROM ($filteredSql) AS query"
}

override def getDataQuery(tableName: String, columns: Seq[String], limit: Option[Int]): String = {
s"SELECT ${columnExpr(columns)} FROM ${escape(tableName)}${getLimit(limit)}"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ class SqlGeneratorDenodo(sqlConfig: SqlConfig) extends SqlGeneratorBase(sqlConfi
s"SELECT COUNT(*) FROM ${escape(tableName)} WHERE $where"
}

override def getCountQueryForSql(filteredSql: String): String = {
s"SELECT COUNT(*) FROM ($filteredSql) AS query"
}

override def getDataQuery(tableName: String, columns: Seq[String], limit: Option[Int]): String = {
s"SELECT ${columnExpr(columns)} FROM ${escape(tableName)}${getLimit(limit)}"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ class SqlGeneratorGeneric(sqlConfig: SqlConfig) extends SqlGeneratorBase(sqlConf
s"SELECT COUNT(*) AS CNT FROM ${escape(tableName)} WHERE $where"
}

override def getCountQueryForSql(filteredSql: String): String = {
s"SELECT COUNT(*) FROM ($filteredSql) AS query"
}

override def getDataQuery(tableName: String, columns: Seq[String], limit: Option[Int]): String = {
s"SELECT ${columnExpr(columns)} FROM ${escape(tableName)}${getLimit(limit)}"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ class SqlGeneratorHive(sqlConfig: SqlConfig) extends SqlGeneratorBase(sqlConfig)
s"SELECT COUNT(*) FROM ${escape(tableName)} WHERE $where"
}

override def getCountQueryForSql(filteredSql: String): String = {
s"SELECT COUNT(*) FROM ($filteredSql) AS query"
}

override def getDataQuery(tableName: String, columns: Seq[String], limit: Option[Int]): String = {
s"SELECT ${columnExpr(columns)} FROM ${escape(tableName)}${getLimit(limit)}"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ class SqlGeneratorHsqlDb(sqlConfig: SqlConfig) extends SqlGeneratorBase(sqlConfi
s"SELECT ${getAliasExpression("COUNT(*)", "CNT")} FROM ${escape(tableName)} WHERE $where"
}

override def getCountQueryForSql(filteredSql: String): String = {
s"SELECT COUNT(*) FROM ($filteredSql)"
}

override def getDataQuery(tableName: String, columns: Seq[String], limit: Option[Int]): String = {
s"SELECT ${columnExpr(columns)} FROM ${escape(tableName)}${getLimit(limit)}"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ class SqlGeneratorMicrosoft(sqlConfig: SqlConfig) extends SqlGenerator {
s"SELECT ${getAliasExpression("COUNT(*)", "CNT")} FROM ${escape(tableName)} WITH (NOLOCK) WHERE $where"
}

override def getCountQueryForSql(filteredSql: String): String = {
s"SELECT COUNT(*) FROM ($filteredSql) AS query"
}

override def getDataQuery(tableName: String, columns: Seq[String], limit: Option[Int]): String = {
s"SELECT ${getLimit(limit)}${columnExpr(columns)} FROM ${escape(tableName)} WITH (NOLOCK)"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ class SqlGeneratorMySQL(sqlConfig: SqlConfig) extends SqlGeneratorBase(sqlConfig
s"SELECT ${columnExpr(columns)} FROM ${escape(tableName)}${getLimit(limit)}"
}

override def getCountQueryForSql(filteredSql: String): String = {
s"SELECT COUNT(*) FROM ($filteredSql) query"
}

override def getDataQuery(tableName: String, infoDateBegin: LocalDate, infoDateEnd: LocalDate, columns: Seq[String], limit: Option[Int]): String = {
val where = getWhere(infoDateBegin, infoDateEnd)
s"SELECT ${columnExpr(columns)} FROM ${escape(tableName)} WHERE $where${getLimit(limit)}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ class SqlGeneratorOracle(sqlConfig: SqlConfig) extends SqlGeneratorBase(sqlConfi
s"SELECT ${columnExpr(columns)} FROM ${escape(tableName)}${getLimit(limit, hasWhere = false)}"
}

override def getCountQueryForSql(filteredSql: String): String = {
s"SELECT COUNT(*) FROM ($filteredSql) query"
}

def getDataQuery(tableName: String, infoDateBegin: LocalDate, infoDateEnd: LocalDate, columns: Seq[String], limit: Option[Int]): String = {
val where = getWhere(infoDateBegin, infoDateEnd)
s"SELECT ${columnExpr(columns)} FROM ${escape(tableName)} WHERE $where${getLimit(limit, hasWhere = true)}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ class SqlGeneratorPostgreSQL(sqlConfig: SqlConfig) extends SqlGeneratorBase(sqlC
s"SELECT ${columnExpr(columns)} FROM ${escape(tableName)}${getLimit(limit)}"
}

override def getCountQueryForSql(filteredSql: String): String = {
s"SELECT COUNT(*) FROM ($filteredSql) query"
}

override def getDataQuery(tableName: String, infoDateBegin: LocalDate, infoDateEnd: LocalDate, columns: Seq[String], limit: Option[Int]): String = {
val where = getWhere(infoDateBegin, infoDateEnd)
s"SELECT ${columnExpr(columns)} FROM ${escape(tableName)} WHERE $where${getLimit(limit)}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ class SqlGeneratorSas(sqlConfig: SqlConfig) extends SqlGeneratorBase(sqlConfig)
s"SELECT ${getColumnSql(tableName, columns)} FROM ${escape(tableName)}${getLimit(limit, hasWhere = false)}"
}

override def getCountQueryForSql(filteredSql: String): String = {
s"SELECT COUNT(*) FROM ($filteredSql) AS query"
}

def getDataQuery(tableName: String, infoDateBegin: LocalDate, infoDateEnd: LocalDate, columns: Seq[String], limit: Option[Int]): String = {
val where = getWhere(infoDateBegin, infoDateEnd)
s"SELECT ${getColumnSql(tableName, columns)} FROM ${escape(tableName)} WHERE $where${getLimit(limit, hasWhere = true)}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class SqlGeneratorDummy(sqlConfig: SqlConfig) extends SqlGenerator {

override def getCountQuery(tableName: String, infoDateBegin: LocalDate, infoDateEnd: LocalDate): String = null

override def getCountQueryForSql(filteredSql: String): String = null

override def getDataQuery(tableName: String, columns: Seq[String], limit: Option[Int]): String = null

override def getDataQuery(tableName: String, infoDateBegin: LocalDate, infoDateEnd: LocalDate, columns: Seq[String], limit: Option[Int]): String = null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,46 @@ class TableReaderJdbcSuite extends AnyWordSpec with BeforeAndAfterAll with Spark
}
}

"getCountSqlQuery" should {
"return a count query for a table snapshot-like query" in {
val testConfig = conf
val reader = TableReaderJdbc(testConfig.getConfig("reader"), "reader")

val sql = reader.getCountSqlQuery("SELECT * FROM COMPANY", infoDate, infoDate)

assert(sql == "SELECT COUNT(*) FROM (SELECT * FROM COMPANY)")
}

"return a count query for a table event-like query" in {
val testConfig = conf.getConfig("reader")
.withValue("has.information.date.column", ConfigValueFactory.fromAnyRef(true))
.withValue("information.date.column", ConfigValueFactory.fromAnyRef("info_date"))
.withValue("information.date.type", ConfigValueFactory.fromAnyRef("string"))
.withValue("information.date.format", ConfigValueFactory.fromAnyRef("yyyy-MM-dd"))

val reader = TableReaderJdbc(testConfig, "reader")

val sql = reader.getCountSqlQuery("SELECT * FROM COMPANY WHERE info_date BETWEEN '@dateFrom' AND '@dateTo'", infoDate, infoDate)

assert(sql == "SELECT COUNT(*) FROM (SELECT * FROM COMPANY WHERE info_date BETWEEN '2022-02-18' AND '2022-02-18')")
}

"return a count query for a complex event-like query" in {
val testConfig = conf.getConfig("reader")
.withValue("jdbc.driver", ConfigValueFactory.fromAnyRef("net.sourceforge.jtds.jdbc.Driver"))
.withValue("has.information.date.column", ConfigValueFactory.fromAnyRef(true))
.withValue("information.date.column", ConfigValueFactory.fromAnyRef("info_date"))
.withValue("information.date.type", ConfigValueFactory.fromAnyRef("string"))
.withValue("information.date.format", ConfigValueFactory.fromAnyRef("yyyy-MM-dd"))

val reader = TableReaderJdbc(testConfig, "reader")

val sql = reader.getCountSqlQuery("SELECT * FROM my_db.my_table WHERE info_date = CAST(REPLACE(CAST(CAST('@infoDate' AS DATE) AS VARCHAR(10)), '-', '') AS INTEGER)", infoDate, infoDate)

assert(sql == "SELECT COUNT(*) FROM (SELECT * FROM my_db.my_table WHERE info_date = CAST(REPLACE(CAST(CAST('2022-02-18' AS DATE) AS VARCHAR(10)), '-', '') AS INTEGER)) AS query")
}
}

"getData()" should {
"return data for a table snapshot-like query" in {
val testConfig = conf
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,12 @@ class SqlGeneratorLoaderSuite extends AnyWordSpec with RelationalDbFixture {
}
}

"getCountQueryForSql" should {
"generate count queries for an SQL subquery" in {
assert(gen.getCountQueryForSql("SELECT A FROM B") == "SELECT COUNT(*) FROM (SELECT A FROM B) query")
}
}

"getDtable" should {
"return the original table when a table is provided" in {
assert(gen.getDtable("A") == "A")
Expand Down Expand Up @@ -311,6 +317,12 @@ class SqlGeneratorLoaderSuite extends AnyWordSpec with RelationalDbFixture {
}
}

"getCountQueryForSql" should {
"generate count queries for an SQL subquery" in {
assert(genDate.getCountQueryForSql("SELECT A FROM B") == "SELECT COUNT(*) FROM (SELECT A FROM B) AS query")
}
}

"getDtable" should {
"return the original table when a table is provided" in {
assert(genDate.getDtable("A") == "A")
Expand Down Expand Up @@ -540,6 +552,12 @@ class SqlGeneratorLoaderSuite extends AnyWordSpec with RelationalDbFixture {
}
}

"getCountQueryForSql" should {
"generate count queries for an SQL subquery" in {
assert(gen.getCountQueryForSql("SELECT A FROM B") == "SELECT COUNT(*) FROM (SELECT A FROM B) AS query")
}
}

"getDtable" should {
"return the original table when a table is provided" in {
assert(gen.getDtable("A") == "A")
Expand Down Expand Up @@ -740,6 +758,12 @@ class SqlGeneratorLoaderSuite extends AnyWordSpec with RelationalDbFixture {
}
}

"getCountQueryForSql" should {
"generate count queries for an SQL subquery" in {
assert(gen.getCountQueryForSql("SELECT A FROM B") == "SELECT COUNT(*) FROM (SELECT A FROM B) AS query")
}
}

"getDtable" should {
"return the original table when a table is provided" in {
assert(gen.getDtable("A") == "A")
Expand Down Expand Up @@ -856,6 +880,12 @@ class SqlGeneratorLoaderSuite extends AnyWordSpec with RelationalDbFixture {
}
}

"getCountQueryForSql" should {
"generate count queries for an SQL subquery" in {
assert(gen.getCountQueryForSql("SELECT A FROM B") == "SELECT COUNT(*) FROM (SELECT A FROM B) AS query")
}
}

"getDtable" should {
"return the original table when a table is provided" in {
assert(gen.getDtable("A") == "(A) tbl")
Expand Down Expand Up @@ -972,6 +1002,12 @@ class SqlGeneratorLoaderSuite extends AnyWordSpec with RelationalDbFixture {
}
}

"getCountQueryForSql" should {
"generate count queries for an SQL subquery" in {
assert(genDate.getCountQueryForSql("SELECT A FROM B") == "SELECT COUNT(*) FROM (SELECT A FROM B) query")
}
}

"getDtable" should {
"return the original table when a table is provided" in {
assert(genDate.getDtable("A") == "A")
Expand Down Expand Up @@ -1088,6 +1124,12 @@ class SqlGeneratorLoaderSuite extends AnyWordSpec with RelationalDbFixture {
}
}

"getCountQueryForSql" should {
"generate count queries for an SQL subquery" in {
assert(genDate.getCountQueryForSql("SELECT A FROM B") == "SELECT COUNT(*) FROM (SELECT A FROM B) query")
}
}

"getDtable" should {
"return the original table when a table is provided" in {
assert(genDate.getDtable("A") == "A")
Expand Down Expand Up @@ -1204,6 +1246,12 @@ class SqlGeneratorLoaderSuite extends AnyWordSpec with RelationalDbFixture {
}
}

"getCountQueryForSql" should {
"generate count queries for an SQL subquery" in {
assert(genDate.getCountQueryForSql("SELECT A FROM B") == "SELECT COUNT(*) FROM (SELECT A FROM B) AS query")
}
}

"getDtable" should {
"return the original table when a table is provided" in {
assert(genDate.getDtable("A") == "A")
Expand Down Expand Up @@ -1320,6 +1368,12 @@ class SqlGeneratorLoaderSuite extends AnyWordSpec with RelationalDbFixture {
}
}

"getCountQueryForSql" should {
"generate count queries for an SQL subquery" in {
assert(genDate.getCountQueryForSql("SELECT A FROM B") == "SELECT COUNT(*) FROM (SELECT A FROM B)")
}
}

"getDtable" should {
"return the original table when a table is provided" in {
assert(genDate.getDtable("A") == "A")
Expand Down Expand Up @@ -1448,6 +1502,12 @@ class SqlGeneratorLoaderSuite extends AnyWordSpec with RelationalDbFixture {
}
}

"getCountQueryForSql" should {
"generate count queries for an SQL subquery" in {
assert(genDate.getCountQueryForSql("SELECT A FROM B") == "SELECT COUNT(*) FROM (SELECT A FROM B) AS query")
}
}

"getDtable" should {
"return the original table when a table is provided" in {
assert(genDate.getDtable("A") == "A")
Expand Down

0 comments on commit 2d4083f

Please sign in to comment.