Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#483 Fix count() queries with MS SQL and other SQL dialects. #484

Merged
merged 1 commit into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ trait SqlGenerator {
*/
def getCountQuery(tableName: String, infoDateBegin: LocalDate, infoDateEnd: LocalDate): String

/**
* Generates a query that returns the record count of an SQL query that is already formed.
*
* Implementations wrap `filteredSql` in parentheses and select `COUNT(*)` from it. Whether the
* subquery gets an alias, and whether the alias is introduced with `AS`, is decided by each
* dialect-specific generator.
*
* @param filteredSql the complete SQL query whose records should be counted.
* @return the SQL text of the count query.
*/
def getCountQueryForSql(filteredSql: String): String

/**
* Generates a query that returns data of a table that does not have the information date field.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,18 @@ class TableReaderJdbc(jdbcReaderConfig: TableReaderJdbcConfig,
)
}

/**
  * Builds the SQL statement that counts the records returned by a user-provided query.
  *
  * The query is first passed through `TableReaderJdbcNative.getFilteredSql` together with the
  * information date range, and the result is handed to the SQL generator so that the
  * dialect-specific form of the count query (for example, the subquery alias) is used
  * instead of a hard-coded `SELECT COUNT(*) FROM (...)`.
  *
  * @param sql           the original SQL query.
  * @param infoDateBegin the beginning of the information date range.
  * @param infoDateEnd   the end of the information date range.
  * @return the count query to execute.
  */
private[core] def getCountSqlQuery(sql: String, infoDateBegin: LocalDate, infoDateEnd: LocalDate): String = {
  val filteredSql = TableReaderJdbcNative.getFilteredSql(sql, infoDateBegin, infoDateEnd)

  sqlGen.getCountQueryForSql(filteredSql)
}

private[core] def getCountForSql(sql: String, infoDateBegin: LocalDate, infoDateEnd: LocalDate): Long = {
val countSql = getCountSqlQuery(sql, infoDateBegin, infoDateEnd)
var count = 0L

log.info(s"Executing: $countSql")

JdbcNativeUtils.withResultSet(jdbcUrlSelector, countSql, jdbcRetries) { rs =>
if (!rs.next())
throw new IllegalStateException(s"No rows returned by the count query: $countSql")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ class SqlGeneratorDb2(sqlConfig: SqlConfig) extends SqlGeneratorBase(sqlConfig)
s"SELECT ${getAliasExpression("COUNT(*)", "CNT")} FROM ${escape(tableName)} WHERE $where"
}

/**
  * Counts the records of an already formed query by selecting `COUNT(*)` from it as a
  * subquery aliased with `AS query`.
  */
override def getCountQueryForSql(filteredSql: String): String =
  s"SELECT COUNT(*) FROM ($filteredSql) AS query"

/**
  * Builds a query selecting the given columns from a table, with the limit expression
  * appended right after the table name.
  */
override def getDataQuery(tableName: String, columns: Seq[String], limit: Option[Int]): String = {
  val selectList = columnExpr(columns)
  val source = escape(tableName)
  val limitClause = getLimit(limit)

  s"SELECT $selectList FROM $source$limitClause"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ class SqlGeneratorDenodo(sqlConfig: SqlConfig) extends SqlGeneratorBase(sqlConfi
s"SELECT COUNT(*) FROM ${escape(tableName)} WHERE $where"
}

/**
  * Counts the records of an already formed query by selecting `COUNT(*)` from it as a
  * subquery aliased with `AS query`.
  */
override def getCountQueryForSql(filteredSql: String): String =
  s"SELECT COUNT(*) FROM ($filteredSql) AS query"

/**
  * Builds a query selecting the given columns from a table, with the limit expression
  * appended right after the table name.
  */
override def getDataQuery(tableName: String, columns: Seq[String], limit: Option[Int]): String = {
  val selectList = columnExpr(columns)
  val source = escape(tableName)
  val limitClause = getLimit(limit)

  s"SELECT $selectList FROM $source$limitClause"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ class SqlGeneratorGeneric(sqlConfig: SqlConfig) extends SqlGeneratorBase(sqlConf
s"SELECT COUNT(*) AS CNT FROM ${escape(tableName)} WHERE $where"
}

/**
  * Counts the records of an already formed query by selecting `COUNT(*)` from it as a
  * subquery aliased with `AS query`.
  */
override def getCountQueryForSql(filteredSql: String): String =
  s"SELECT COUNT(*) FROM ($filteredSql) AS query"

/**
  * Builds a query selecting the given columns from a table, with the limit expression
  * appended right after the table name.
  */
override def getDataQuery(tableName: String, columns: Seq[String], limit: Option[Int]): String = {
  val selectList = columnExpr(columns)
  val source = escape(tableName)
  val limitClause = getLimit(limit)

  s"SELECT $selectList FROM $source$limitClause"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ class SqlGeneratorHive(sqlConfig: SqlConfig) extends SqlGeneratorBase(sqlConfig)
s"SELECT COUNT(*) FROM ${escape(tableName)} WHERE $where"
}

/**
  * Counts the records of an already formed query by selecting `COUNT(*)` from it as a
  * subquery aliased with `AS query`.
  */
override def getCountQueryForSql(filteredSql: String): String =
  s"SELECT COUNT(*) FROM ($filteredSql) AS query"

/**
  * Builds a query selecting the given columns from a table, with the limit expression
  * appended right after the table name.
  */
override def getDataQuery(tableName: String, columns: Seq[String], limit: Option[Int]): String = {
  val selectList = columnExpr(columns)
  val source = escape(tableName)
  val limitClause = getLimit(limit)

  s"SELECT $selectList FROM $source$limitClause"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ class SqlGeneratorHsqlDb(sqlConfig: SqlConfig) extends SqlGeneratorBase(sqlConfi
s"SELECT ${getAliasExpression("COUNT(*)", "CNT")} FROM ${escape(tableName)} WHERE $where"
}

/**
  * Counts the records of an already formed query by selecting `COUNT(*)` from it as a
  * subquery. No alias is attached to the subquery in this dialect.
  */
override def getCountQueryForSql(filteredSql: String): String =
  s"SELECT COUNT(*) FROM ($filteredSql)"

/**
  * Builds a query selecting the given columns from a table, with the limit expression
  * appended right after the table name.
  */
override def getDataQuery(tableName: String, columns: Seq[String], limit: Option[Int]): String = {
  val selectList = columnExpr(columns)
  val source = escape(tableName)
  val limitClause = getLimit(limit)

  s"SELECT $selectList FROM $source$limitClause"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ class SqlGeneratorMicrosoft(sqlConfig: SqlConfig) extends SqlGenerator {
s"SELECT ${getAliasExpression("COUNT(*)", "CNT")} FROM ${escape(tableName)} WITH (NOLOCK) WHERE $where"
}

/**
  * Counts the records of an already formed query by selecting `COUNT(*)` from it as a
  * subquery aliased with `AS query`.
  */
override def getCountQueryForSql(filteredSql: String): String =
  s"SELECT COUNT(*) FROM ($filteredSql) AS query"

/**
  * Builds a query selecting the given columns from a table read `WITH (NOLOCK)`.
  * The limit expression is placed directly after `SELECT`, before the column list.
  */
override def getDataQuery(tableName: String, columns: Seq[String], limit: Option[Int]): String = {
  val limitPrefix = getLimit(limit)
  val selectList = columnExpr(columns)
  val source = escape(tableName)

  s"SELECT $limitPrefix$selectList FROM $source WITH (NOLOCK)"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ class SqlGeneratorMySQL(sqlConfig: SqlConfig) extends SqlGeneratorBase(sqlConfig
s"SELECT ${columnExpr(columns)} FROM ${escape(tableName)}${getLimit(limit)}"
}

/**
  * Counts the records of an already formed query by selecting `COUNT(*)` from it as a
  * subquery. The alias `query` is attached without the `AS` keyword.
  */
override def getCountQueryForSql(filteredSql: String): String =
  s"SELECT COUNT(*) FROM ($filteredSql) query"

override def getDataQuery(tableName: String, infoDateBegin: LocalDate, infoDateEnd: LocalDate, columns: Seq[String], limit: Option[Int]): String = {
val where = getWhere(infoDateBegin, infoDateEnd)
s"SELECT ${columnExpr(columns)} FROM ${escape(tableName)} WHERE $where${getLimit(limit)}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ class SqlGeneratorOracle(sqlConfig: SqlConfig) extends SqlGeneratorBase(sqlConfi
s"SELECT ${columnExpr(columns)} FROM ${escape(tableName)}${getLimit(limit, hasWhere = false)}"
}

/**
  * Counts the records of an already formed query by selecting `COUNT(*)` from it as a
  * subquery. The alias `query` is attached without the `AS` keyword.
  */
override def getCountQueryForSql(filteredSql: String): String =
  s"SELECT COUNT(*) FROM ($filteredSql) query"

def getDataQuery(tableName: String, infoDateBegin: LocalDate, infoDateEnd: LocalDate, columns: Seq[String], limit: Option[Int]): String = {
val where = getWhere(infoDateBegin, infoDateEnd)
s"SELECT ${columnExpr(columns)} FROM ${escape(tableName)} WHERE $where${getLimit(limit, hasWhere = true)}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ class SqlGeneratorPostgreSQL(sqlConfig: SqlConfig) extends SqlGeneratorBase(sqlC
s"SELECT ${columnExpr(columns)} FROM ${escape(tableName)}${getLimit(limit)}"
}

/**
  * Counts the records of an already formed query by selecting `COUNT(*)` from it as a
  * subquery. The alias `query` is attached without the `AS` keyword.
  */
override def getCountQueryForSql(filteredSql: String): String =
  s"SELECT COUNT(*) FROM ($filteredSql) query"

override def getDataQuery(tableName: String, infoDateBegin: LocalDate, infoDateEnd: LocalDate, columns: Seq[String], limit: Option[Int]): String = {
val where = getWhere(infoDateBegin, infoDateEnd)
s"SELECT ${columnExpr(columns)} FROM ${escape(tableName)} WHERE $where${getLimit(limit)}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ class SqlGeneratorSas(sqlConfig: SqlConfig) extends SqlGeneratorBase(sqlConfig)
s"SELECT ${getColumnSql(tableName, columns)} FROM ${escape(tableName)}${getLimit(limit, hasWhere = false)}"
}

/**
  * Counts the records of an already formed query by selecting `COUNT(*)` from it as a
  * subquery aliased with `AS query`.
  */
override def getCountQueryForSql(filteredSql: String): String =
  s"SELECT COUNT(*) FROM ($filteredSql) AS query"

def getDataQuery(tableName: String, infoDateBegin: LocalDate, infoDateEnd: LocalDate, columns: Seq[String], limit: Option[Int]): String = {
val where = getWhere(infoDateBegin, infoDateEnd)
s"SELECT ${getColumnSql(tableName, columns)} FROM ${escape(tableName)} WHERE $where${getLimit(limit, hasWhere = true)}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class SqlGeneratorDummy(sqlConfig: SqlConfig) extends SqlGenerator {

/** Stub: generates no SQL and always returns `null`. */
override def getCountQuery(tableName: String, infoDateBegin: LocalDate, infoDateEnd: LocalDate): String = null

/** Stub: generates no SQL and always returns `null`, whatever `filteredSql` contains. */
override def getCountQueryForSql(filteredSql: String): String = null

/** Stub: generates no SQL and always returns `null`. */
override def getDataQuery(tableName: String, columns: Seq[String], limit: Option[Int]): String = null

/** Stub: generates no SQL and always returns `null`. */
override def getDataQuery(tableName: String, infoDateBegin: LocalDate, infoDateEnd: LocalDate, columns: Seq[String], limit: Option[Int]): String = null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,46 @@ class TableReaderJdbcSuite extends AnyWordSpec with BeforeAndAfterAll with Spark
}
}

"getCountSqlQuery" should {
  "return a count query for a table snapshot-like query" in {
    // No information date settings: the query is wrapped as is, without a subquery alias.
    val readerConf = conf.getConfig("reader")
    val reader = TableReaderJdbc(readerConf, "reader")

    val expected = "SELECT COUNT(*) FROM (SELECT * FROM COMPANY)"
    val actual = reader.getCountSqlQuery("SELECT * FROM COMPANY", infoDate, infoDate)

    assert(actual == expected)
  }

  "return a count query for a table event-like query" in {
    // The date placeholders in the query are expected to be replaced with the information date.
    val readerConf = conf.getConfig("reader")
      .withValue("has.information.date.column", ConfigValueFactory.fromAnyRef(true))
      .withValue("information.date.column", ConfigValueFactory.fromAnyRef("info_date"))
      .withValue("information.date.type", ConfigValueFactory.fromAnyRef("string"))
      .withValue("information.date.format", ConfigValueFactory.fromAnyRef("yyyy-MM-dd"))
    val reader = TableReaderJdbc(readerConf, "reader")

    val query = "SELECT * FROM COMPANY WHERE info_date BETWEEN '@dateFrom' AND '@dateTo'"
    val expected = "SELECT COUNT(*) FROM (SELECT * FROM COMPANY WHERE info_date BETWEEN '2022-02-18' AND '2022-02-18')"

    val actual = reader.getCountSqlQuery(query, infoDate, infoDate)

    assert(actual == expected)
  }

  "return a count query for a complex event-like query" in {
    // With the jTDS driver configured, the count query is expected to alias the subquery with 'AS query'.
    val readerConf = conf.getConfig("reader")
      .withValue("jdbc.driver", ConfigValueFactory.fromAnyRef("net.sourceforge.jtds.jdbc.Driver"))
      .withValue("has.information.date.column", ConfigValueFactory.fromAnyRef(true))
      .withValue("information.date.column", ConfigValueFactory.fromAnyRef("info_date"))
      .withValue("information.date.type", ConfigValueFactory.fromAnyRef("string"))
      .withValue("information.date.format", ConfigValueFactory.fromAnyRef("yyyy-MM-dd"))
    val reader = TableReaderJdbc(readerConf, "reader")

    val query = "SELECT * FROM my_db.my_table WHERE info_date = CAST(REPLACE(CAST(CAST('@infoDate' AS DATE) AS VARCHAR(10)), '-', '') AS INTEGER)"
    val expected = "SELECT COUNT(*) FROM (SELECT * FROM my_db.my_table WHERE info_date = CAST(REPLACE(CAST(CAST('2022-02-18' AS DATE) AS VARCHAR(10)), '-', '') AS INTEGER)) AS query"

    val actual = reader.getCountSqlQuery(query, infoDate, infoDate)

    assert(actual == expected)
  }
}

"getData()" should {
"return data for a table snapshot-like query" in {
val testConfig = conf
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,12 @@ class SqlGeneratorLoaderSuite extends AnyWordSpec with RelationalDbFixture {
}
}

"getCountQueryForSql" should {
  "generate count queries for an SQL subquery" in {
    // The subquery alias is expected without the AS keyword.
    val expected = "SELECT COUNT(*) FROM (SELECT A FROM B) query"
    val actual = gen.getCountQueryForSql("SELECT A FROM B")

    assert(actual == expected)
  }
}

"getDtable" should {
"return the original table when a table is provided" in {
assert(gen.getDtable("A") == "A")
Expand Down Expand Up @@ -311,6 +317,12 @@ class SqlGeneratorLoaderSuite extends AnyWordSpec with RelationalDbFixture {
}
}

"getCountQueryForSql" should {
  "generate count queries for an SQL subquery" in {
    // The subquery is expected to be aliased with 'AS query'.
    val expected = "SELECT COUNT(*) FROM (SELECT A FROM B) AS query"
    val actual = genDate.getCountQueryForSql("SELECT A FROM B")

    assert(actual == expected)
  }
}

"getDtable" should {
"return the original table when a table is provided" in {
assert(genDate.getDtable("A") == "A")
Expand Down Expand Up @@ -540,6 +552,12 @@ class SqlGeneratorLoaderSuite extends AnyWordSpec with RelationalDbFixture {
}
}

"getCountQueryForSql" should {
  "generate count queries for an SQL subquery" in {
    // The subquery is expected to be aliased with 'AS query'.
    val expected = "SELECT COUNT(*) FROM (SELECT A FROM B) AS query"
    val actual = gen.getCountQueryForSql("SELECT A FROM B")

    assert(actual == expected)
  }
}

"getDtable" should {
"return the original table when a table is provided" in {
assert(gen.getDtable("A") == "A")
Expand Down Expand Up @@ -740,6 +758,12 @@ class SqlGeneratorLoaderSuite extends AnyWordSpec with RelationalDbFixture {
}
}

"getCountQueryForSql" should {
  "generate count queries for an SQL subquery" in {
    // The subquery is expected to be aliased with 'AS query'.
    val expected = "SELECT COUNT(*) FROM (SELECT A FROM B) AS query"
    val actual = gen.getCountQueryForSql("SELECT A FROM B")

    assert(actual == expected)
  }
}

"getDtable" should {
"return the original table when a table is provided" in {
assert(gen.getDtable("A") == "A")
Expand Down Expand Up @@ -856,6 +880,12 @@ class SqlGeneratorLoaderSuite extends AnyWordSpec with RelationalDbFixture {
}
}

"getCountQueryForSql" should {
  "generate count queries for an SQL subquery" in {
    // The subquery is expected to be aliased with 'AS query'.
    val expected = "SELECT COUNT(*) FROM (SELECT A FROM B) AS query"
    val actual = gen.getCountQueryForSql("SELECT A FROM B")

    assert(actual == expected)
  }
}

"getDtable" should {
"return the original table when a table is provided" in {
assert(gen.getDtable("A") == "(A) tbl")
Expand Down Expand Up @@ -972,6 +1002,12 @@ class SqlGeneratorLoaderSuite extends AnyWordSpec with RelationalDbFixture {
}
}

"getCountQueryForSql" should {
  "generate count queries for an SQL subquery" in {
    // The subquery alias is expected without the AS keyword.
    val expected = "SELECT COUNT(*) FROM (SELECT A FROM B) query"
    val actual = genDate.getCountQueryForSql("SELECT A FROM B")

    assert(actual == expected)
  }
}

"getDtable" should {
"return the original table when a table is provided" in {
assert(genDate.getDtable("A") == "A")
Expand Down Expand Up @@ -1088,6 +1124,12 @@ class SqlGeneratorLoaderSuite extends AnyWordSpec with RelationalDbFixture {
}
}

"getCountQueryForSql" should {
  "generate count queries for an SQL subquery" in {
    // The subquery alias is expected without the AS keyword.
    val expected = "SELECT COUNT(*) FROM (SELECT A FROM B) query"
    val actual = genDate.getCountQueryForSql("SELECT A FROM B")

    assert(actual == expected)
  }
}

"getDtable" should {
"return the original table when a table is provided" in {
assert(genDate.getDtable("A") == "A")
Expand Down Expand Up @@ -1204,6 +1246,12 @@ class SqlGeneratorLoaderSuite extends AnyWordSpec with RelationalDbFixture {
}
}

"getCountQueryForSql" should {
  "generate count queries for an SQL subquery" in {
    // The subquery is expected to be aliased with 'AS query'.
    val expected = "SELECT COUNT(*) FROM (SELECT A FROM B) AS query"
    val actual = genDate.getCountQueryForSql("SELECT A FROM B")

    assert(actual == expected)
  }
}

"getDtable" should {
"return the original table when a table is provided" in {
assert(genDate.getDtable("A") == "A")
Expand Down Expand Up @@ -1320,6 +1368,12 @@ class SqlGeneratorLoaderSuite extends AnyWordSpec with RelationalDbFixture {
}
}

"getCountQueryForSql" should {
  "generate count queries for an SQL subquery" in {
    // No alias is expected after the subquery.
    val expected = "SELECT COUNT(*) FROM (SELECT A FROM B)"
    val actual = genDate.getCountQueryForSql("SELECT A FROM B")

    assert(actual == expected)
  }
}

"getDtable" should {
"return the original table when a table is provided" in {
assert(genDate.getDtable("A") == "A")
Expand Down Expand Up @@ -1448,6 +1502,12 @@ class SqlGeneratorLoaderSuite extends AnyWordSpec with RelationalDbFixture {
}
}

"getCountQueryForSql" should {
  "generate count queries for an SQL subquery" in {
    // The subquery is expected to be aliased with 'AS query'.
    val expected = "SELECT COUNT(*) FROM (SELECT A FROM B) AS query"
    val actual = genDate.getCountQueryForSql("SELECT A FROM B")

    assert(actual == expected)
  }
}

"getDtable" should {
"return the original table when a table is provided" in {
assert(genDate.getDtable("A") == "A")
Expand Down
Loading