Skip to content

Commit

Permalink
KE-43358 logical view supports database
Browse files Browse the repository at this point in the history
  • Loading branch information
RolatZhang committed Jan 16, 2024
1 parent c43b45f commit def087c
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,5 @@ object GlobalTempView extends ViewType
* while the local temporary view exists, unless the view name is qualified by database.
*/
object PersistedView extends ViewType

object LogicalView extends ViewType
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,45 @@ class GlobalTempViewManager(val database: String) {

/** List of view definitions, mapping from view name to logical plan. */
@GuardedBy("this")
private val viewDefinitions = new mutable.HashMap[String, TemporaryViewRelation]
private val databaseViewDefinitions = new mutable.HashMap[String, TemporaryViewDefinition]

def isTempDatabase(database: String): Boolean = {
databaseViewDefinitions.contains(database)
}

private def getViewDefinitions(db: String): mutable.HashMap[String, TemporaryViewRelation] = {
val key = if (db == null) database else db
databaseViewDefinitions.getOrElseUpdate(key, TemporaryViewDefinition()).viewDefinitions
}

/**
* Returns the global view definition which matches the given name, or None if not found.
*/
def get(name: String): Option[TemporaryViewRelation] = synchronized {
viewDefinitions.get(name)
def get(name: String): Option[TemporaryViewRelation] = {
get(null, name)
}

def get(db: String, name: String): Option[TemporaryViewRelation] = synchronized {
getViewDefinitions(db).get(name)
}

/**
* Creates a global temp view, or issue an exception if the view already exists and
* `overrideIfExists` is false.
*/
def create(
name: String,
viewDefinition: TemporaryViewRelation,
overrideIfExists: Boolean): Unit = {
create(null, name, viewDefinition, overrideIfExists)
}

def create(
db: String,
name: String,
viewDefinition: TemporaryViewRelation,
overrideIfExists: Boolean): Unit = synchronized {
val viewDefinitions = getViewDefinitions(db)
if (!overrideIfExists && viewDefinitions.contains(name)) {
throw new TempTableAlreadyExistsException(name)
}
Expand All @@ -66,8 +88,16 @@ class GlobalTempViewManager(val database: String) {
* Updates the global temp view if it exists, returns true if updated, false otherwise.
*/
def update(
name: String,
viewDefinition: TemporaryViewRelation): Boolean = {
update(null, name, viewDefinition)
}

def update(
db: String,
name: String,
viewDefinition: TemporaryViewRelation): Boolean = synchronized {
val viewDefinitions = getViewDefinitions(db)
if (viewDefinitions.contains(name)) {
viewDefinitions.put(name, viewDefinition)
true
Expand All @@ -79,16 +109,25 @@ class GlobalTempViewManager(val database: String) {
/**
* Removes the global temp view if it exists, returns true if removed, false otherwise.
*/
def remove(name: String): Boolean = synchronized {
viewDefinitions.remove(name).isDefined
def remove(name: String): Boolean = {
remove(null, name)
}

def remove(db: String, name: String): Boolean = synchronized {
getViewDefinitions(db).remove(name).isDefined
}

/**
* Renames the global temp view if the source view exists and the destination view not exists, or
* issue an exception if the source view exists but the destination view already exists. Returns
* true if renamed, false otherwise.
*/
def rename(oldName: String, newName: String): Boolean = synchronized {
def rename(oldName: String, newName: String): Boolean = {
rename(null, oldName, newName)
}

def rename(db: String, oldName: String, newName: String): Boolean = synchronized {
val viewDefinitions = getViewDefinitions(db)
if (viewDefinitions.contains(oldName)) {
if (viewDefinitions.contains(newName)) {
throw QueryCompilationErrors.renameTempViewToExistingViewError(oldName, newName)
Expand All @@ -106,14 +145,27 @@ class GlobalTempViewManager(val database: String) {
/**
* Lists the names of all global temporary views.
*/
def listViewNames(pattern: String): Seq[String] = synchronized {
def listViewNames(pattern: String): Seq[String] = {
listViewNames(null, pattern)
}

def listViewNames(db: String, pattern: String): Seq[String] = synchronized {
val viewDefinitions = getViewDefinitions(db)
StringUtils.filterPattern(viewDefinitions.keys.toSeq, pattern)
}

def listDBNames(pattern: String): Seq[String] = synchronized {
StringUtils.filterPattern(databaseViewDefinitions.keys.toSeq, pattern)
}

/**
* Clears all the global temporary views.
*/
def clear(): Unit = synchronized {
viewDefinitions.clear()
databaseViewDefinitions.clear()
}
}

case class TemporaryViewDefinition() {
val viewDefinitions = new mutable.HashMap[String, TemporaryViewRelation]
}
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,18 @@ class SessionCatalog(
globalTempViewManager.create(formatTableName(name), viewDefinition, overrideIfExists)
}

/**
* Create a logical view.
*/
def createLogicalView(
db: String,
name: String,
viewDefinition: TemporaryViewRelation,
overrideIfExists: Boolean): Unit = {
globalTempViewManager.create(formatTableName(db),
formatTableName(name), viewDefinition, overrideIfExists)
}

/**
* Alter the definition of a local/global temp view matching the given name, returns true if a
* temp view is matched and altered, false otherwise.
Expand Down Expand Up @@ -672,6 +684,10 @@ class SessionCatalog(
getRawGlobalTempView(name).map(getTempViewPlan)
}

def getLogicalView(db: String, name: String): Option[View] = {
globalTempViewManager.get(db, formatTableName(name)).map(getTempViewPlan)
}

/**
* Drop a local temporary view.
*
Expand Down Expand Up @@ -707,11 +723,14 @@ class SessionCatalog(
val table = formatTableName(name.table)
if (name.database.isEmpty) {
tempViews.get(table).map(_.tableMeta).getOrElse(getTableMetadata(name))
} else if (formatDatabaseName(name.database.get) == globalTempViewManager.database) {
globalTempViewManager.get(table).map(_.tableMeta)
.getOrElse(throw new NoSuchTableException(globalTempViewManager.database, table))
} else {
getTableMetadata(name)
val db = formatDatabaseName(name.database.get)
if (globalTempViewManager.isTempDatabase(db)) {
globalTempViewManager.get(db, table).map(_.tableMeta)
.getOrElse(throw new NoSuchTableException(globalTempViewManager.database, table))
} else {
getTableMetadata(name)
}
}
}

Expand Down Expand Up @@ -772,8 +791,8 @@ class SessionCatalog(
purge: Boolean): Unit = synchronized {
val db = formatDatabaseName(name.database.getOrElse(currentDb))
val table = formatTableName(name.table)
if (db == globalTempViewManager.database) {
val viewExists = globalTempViewManager.remove(table)
if (globalTempViewManager.isTempDatabase(db)) {
val viewExists = globalTempViewManager.remove(db, table)
if (!viewExists && !ignoreIfNotExists) {
throw new NoSuchTableException(globalTempViewManager.database, table)
}
Expand Down Expand Up @@ -813,8 +832,8 @@ class SessionCatalog(
synchronized {
val db = formatDatabaseName(name.database.getOrElse(currentDb))
val table = formatTableName(name.table)
if (db == globalTempViewManager.database) {
globalTempViewManager.get(table).map { viewDef =>
if (globalTempViewManager.isTempDatabase(db)) {
globalTempViewManager.get(db, table).map { viewDef =>
SubqueryAlias(table, db, getTempViewPlan(viewDef))
}.getOrElse(throw new NoSuchTableException(db, table))
} else if (name.database.isDefined || !tempViews.contains(table)) {
Expand Down Expand Up @@ -950,9 +969,9 @@ class SessionCatalog(

def lookupGlobalTempView(db: String, table: String): Option[SubqueryAlias] = {
val formattedDB = formatDatabaseName(db)
if (formattedDB == globalTempViewManager.database) {
if (globalTempViewManager.isTempDatabase(formattedDB)) {
val formattedTable = formatTableName(table)
getGlobalTempView(formattedTable).map { view =>
getLogicalView(formattedDB, formattedTable).map { view =>
SubqueryAlias(formattedTable, formattedDB, view)
}
} else {
Expand All @@ -973,10 +992,13 @@ class SessionCatalog(
val tableName = formatTableName(name.table)
if (name.database.isEmpty) {
tempViews.get(tableName).map(getTempViewPlan)
} else if (formatDatabaseName(name.database.get) == globalTempViewManager.database) {
globalTempViewManager.get(tableName).map(getTempViewPlan)
} else {
None
val db = formatDatabaseName(name.database.get)
if (globalTempViewManager.isTempDatabase(db)) {
globalTempViewManager.get(db, tableName).map(getTempViewPlan)
} else {
None
}
}
}

Expand Down Expand Up @@ -1032,9 +1054,9 @@ class SessionCatalog(
pattern: String,
includeLocalTempViews: Boolean): Seq[TableIdentifier] = {
val dbName = formatDatabaseName(db)
val dbTables = if (dbName == globalTempViewManager.database) {
globalTempViewManager.listViewNames(pattern).map { name =>
TableIdentifier(name, Some(globalTempViewManager.database))
val dbTables = if (globalTempViewManager.isTempDatabase(dbName)) {
globalTempViewManager.listViewNames(dbName, pattern).map { name =>
TableIdentifier(name, Some(dbName))
}
} else {
requireDbExists(dbName)
Expand All @@ -1055,9 +1077,9 @@ class SessionCatalog(
*/
def listViews(db: String, pattern: String): Seq[TableIdentifier] = {
val dbName = formatDatabaseName(db)
val dbViews = if (dbName == globalTempViewManager.database) {
globalTempViewManager.listViewNames(pattern).map { name =>
TableIdentifier(name, Some(globalTempViewManager.database))
val dbViews = if (globalTempViewManager.isTempDatabase(dbName)) {
globalTempViewManager.listViewNames(dbName, pattern).map { name =>
TableIdentifier(name, Some(dbName))
}
} else {
requireDbExists(dbName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.antlr.v4.runtime.{ParserRuleContext, Token}
import org.antlr.v4.runtime.tree.TerminalNode

import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, PersistedView, UnresolvedDBObjectName, UnresolvedFunc}
import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, LogicalView, PersistedView, UnresolvedDBObjectName, UnresolvedFunc}
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.parser._
Expand Down Expand Up @@ -550,7 +550,7 @@ class SparkSqlAstBuilder extends AstBuilder {
plan(ctx.query),
false,
false,
GlobalTempView)
LogicalView)

}

Expand Down Expand Up @@ -582,7 +582,7 @@ class SparkSqlAstBuilder extends AstBuilder {
plan(ctx.query),
false,
true,
GlobalTempView)
LogicalView)

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.json4s.jackson.JsonMethods._
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.{SQLConfHelper, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, GlobalTempView, LocalTempView, ViewType}
import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, GlobalTempView, LocalTempView, LogicalView, ViewType}
import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType, TemporaryViewRelation}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, SubqueryExpression}
import org.apache.spark.sql.catalyst.plans.logical.{AnalysisOnlyCommand, LogicalPlan, Project, View}
Expand Down Expand Up @@ -122,8 +122,12 @@ case class CreateViewCommand(
aliasedPlan,
referredTempFunctions)
catalog.createTempView(name.table, tableDefinition, overrideIfExists = replace)
} else if (viewType == GlobalTempView) {
val db = sparkSession.sessionState.conf.getConf(StaticSQLConf.GLOBAL_TEMP_DATABASE)
} else if (viewType == GlobalTempView || viewType==LogicalView) {
val db = if (name.database.isDefined && viewType == LogicalView) {
name.database.get
} else {
sparkSession.sessionState.conf.getConf(StaticSQLConf.GLOBAL_TEMP_DATABASE)
}
val viewIdent = TableIdentifier(name.table, Option(db))
val aliasedPlan = aliasPlan(sparkSession, analyzedPlan)
val tableDefinition = createTemporaryViewRelation(
Expand All @@ -135,7 +139,7 @@ case class CreateViewCommand(
analyzedPlan,
aliasedPlan,
referredTempFunctions)
catalog.createGlobalTempView(name.table, tableDefinition, overrideIfExists = replace)
catalog.createLogicalView(db, name.table, tableDefinition, overrideIfExists = replace)
} else if (catalog.tableExists(name)) {
val tableMetadata = catalog.getTableMetadata(name)
if (allowExisting) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,18 @@ class GlobalTempViewSuite extends QueryTest with SharedSparkSession {
}
}

test("logical VIEW") {
sql("CREATE logical VIEW logicalDb1.table AS SELECT 1")
checkAnswer(spark.table("logicalDb1.table"), Row(1))
sql("drop logical VIEW logicalDb1.table")
sql("CREATE logical VIEW logicalDb1.table AS SELECT 1, 2")
checkAnswer(spark.table("logicalDb1.table"), Row(1, 2))
sql("drop logical VIEW logicalDb1.table")
sql("CREATE logical VIEW logicalDb2.table AS SELECT 1, 2, 3")
checkAnswer(spark.table("logicalDb2.table"), Row(1, 2, 3))
sql("drop logical VIEW logicalDb2.table")
}

test("global temp view database should be preserved") {
val e = intercept[AnalysisException](sql(s"CREATE DATABASE $globalTempDB"))
assert(e.message.contains("system preserved database"))
Expand Down

0 comments on commit def087c

Please sign in to comment.