Refactor CatalogFunction to use FunctionIdentifier
Andrew Or committed Mar 16, 2016
1 parent dd1fbae commit 2118212
Showing 5 changed files with 24 additions and 21 deletions.
@@ -40,14 +40,14 @@ private[sql] object TableIdentifier {
  * If `database` is not defined, the current database is used.
  */
 // TODO: reuse some code with TableIdentifier.
-private[sql] case class FunctionIdentifier(name: String, database: Option[String]) {
+private[sql] case class FunctionIdentifier(funcName: String, database: Option[String]) {
   def this(name: String) = this(name, None)
 
   override def toString: String = quotedString
 
-  def quotedString: String = database.map(db => s"`$db`.`$name`").getOrElse(s"`$name`")
+  def quotedString: String = database.map(db => s"`$db`.`$funcName`").getOrElse(s"`$funcName`")
 
-  def unquotedString: String = database.map(db => s"$db.$name").getOrElse(name)
+  def unquotedString: String = database.map(db => s"$db.$funcName").getOrElse(funcName)
 }
 
 private[sql] object FunctionIdentifier {
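A quick sketch of how the renamed field surfaces, following the definitions above (the function and database names here are made up):

    val fid = FunctionIdentifier("myUpper", Some("db1"))
    fid.funcName                                        // "myUpper"
    fid.quotedString                                    // "`db1`.`myUpper`"
    FunctionIdentifier("myUpper", None).unquotedString  // "myUpper"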
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.catalog
 import scala.collection.mutable
 
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
 
 
 /**
@@ -294,10 +294,10 @@ class InMemoryCatalog extends ExternalCatalog {
 
   override def createFunction(db: String, func: CatalogFunction): Unit = synchronized {
     requireDbExists(db)
-    if (existsFunction(db, func.name)) {
+    if (existsFunction(db, func.name.funcName)) {
       throw new AnalysisException(s"Function $func already exists in $db database")
     } else {
-      catalog(db).functions.put(func.name, func)
+      catalog(db).functions.put(func.name.funcName, func)
     }
   }
 
@@ -308,14 +308,14 @@ class InMemoryCatalog extends ExternalCatalog {
 
   override def renameFunction(db: String, oldName: String, newName: String): Unit = synchronized {
     requireFunctionExists(db, oldName)
-    val newFunc = getFunction(db, oldName).copy(name = newName)
+    val newFunc = getFunction(db, oldName).copy(name = FunctionIdentifier(newName, Some(db)))
     catalog(db).functions.remove(oldName)
     catalog(db).functions.put(newName, newFunc)
   }
 
   override def alterFunction(db: String, funcDefinition: CatalogFunction): Unit = synchronized {
-    requireFunctionExists(db, funcDefinition.name)
-    catalog(db).functions.put(funcDefinition.name, funcDefinition)
+    requireFunctionExists(db, funcDefinition.name.funcName)
+    catalog(db).functions.put(funcDefinition.name.funcName, funcDefinition)
   }
 
   override def getFunction(db: String, funcName: String): CatalogFunction = synchronized {
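The rename path above now rebuilds the identifier rather than copying a bare string. A minimal sketch, assuming the case-class shapes from this commit (values are hypothetical):

    val before = CatalogFunction(FunctionIdentifier("func1", Some("db2")), "org.apache.spark.util.MyFunc")
    val after  = before.copy(name = FunctionIdentifier("func2", Some("db2")))
    // after.name.funcName == "func2"; after.className carries over unchanged.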
@@ -168,8 +168,7 @@ abstract class ExternalCatalog {
  * @param name name of the function
  * @param className fully qualified class name, e.g. "org.apache.spark.util.MyFunc"
  */
-// TODO: use FunctionIdentifier here.
-case class CatalogFunction(name: String, className: String)
+case class CatalogFunction(name: FunctionIdentifier, className: String)
 
 
 /**
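With the new signature, callers construct the identifier explicitly; per the FunctionIdentifier scaladoc, a None database means the current database is used. A hypothetical example:

    CatalogFunction(FunctionIdentifier("myFunc", None), "org.apache.spark.util.MyFunc")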
@@ -21,7 +21,7 @@ import org.scalatest.BeforeAndAfterEach
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
 
 
 /**
@@ -82,7 +82,7 @@ abstract class CatalogTestCases extends SparkFunSuite with BeforeAndAfterEach {
     catalog
   }
 
-  private def newFunc(): CatalogFunction = CatalogFunction("funcname", funcClass)
+  private def newFunc(): CatalogFunction = newFunc("funcName")
 
   private def newDb(name: String): CatalogDatabase = {
     CatalogDatabase(name, name + " description", newUriForDatabase(), Map.empty)
@@ -97,7 +97,9 @@ abstract class CatalogTestCases extends SparkFunSuite with BeforeAndAfterEach {
       partitionColumns = Seq(CatalogColumn("a", "int"), CatalogColumn("b", "string")))
   }
 
-  private def newFunc(name: String): CatalogFunction = CatalogFunction(name, funcClass)
+  private def newFunc(name: String): CatalogFunction = {
+    CatalogFunction(FunctionIdentifier(name, database = None), funcClass)
+  }
 
   /**
    * Whether the catalog's table partitions equal the ones given.
@@ -498,7 +500,8 @@ abstract class CatalogTestCases extends SparkFunSuite with BeforeAndAfterEach {
 
   test("get function") {
     val catalog = newBasicCatalog()
-    assert(catalog.getFunction("db2", "func1") == newFunc("func1"))
+    assert(catalog.getFunction("db2", "func1") ==
+      CatalogFunction(FunctionIdentifier("func1", Some("db2")), funcClass))
     intercept[AnalysisException] {
       catalog.getFunction("db2", "does_not_exist")
     }
@@ -517,7 +520,7 @@ abstract class CatalogTestCases extends SparkFunSuite with BeforeAndAfterEach {
     assert(catalog.getFunction("db2", "func1").className == funcClass)
     catalog.renameFunction("db2", "func1", newName)
     intercept[AnalysisException] { catalog.getFunction("db2", "func1") }
-    assert(catalog.getFunction("db2", newName).name == newName)
+    assert(catalog.getFunction("db2", newName).name.funcName == newName)
     assert(catalog.getFunction("db2", newName).className == funcClass)
     intercept[AnalysisException] { catalog.renameFunction("db2", "does_not_exist", "me") }
   }
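The rewritten "get function" assertion reflects that case-class equality now compares the whole identifier, database included, so a helper built with database = None no longer matches a stored function qualified with "db2":

    FunctionIdentifier("func1", Some("db2")) == FunctionIdentifier("func1", None)  // false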
@@ -35,7 +35,7 @@ import org.apache.hadoop.hive.ql.session.SessionState
 import org.apache.hadoop.security.UserGroupInformation
 
 import org.apache.spark.{Logging, SparkConf, SparkException}
-import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
 import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchPartitionException}
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.expressions.Expression
@@ -545,13 +545,13 @@ private[hive] class HiveClientImpl(
   }
 
   override def renameFunction(db: String, oldName: String, newName: String): Unit = withHiveState {
-    val catalogFunc = getFunction(db, oldName).copy(name = newName)
+    val catalogFunc = getFunction(db, oldName).copy(name = FunctionIdentifier(newName, Some(db)))
     val hiveFunc = toHiveFunction(catalogFunc, db)
     client.alterFunction(db, oldName, hiveFunc)
   }
 
   override def alterFunction(db: String, func: CatalogFunction): Unit = withHiveState {
-    client.alterFunction(db, func.name, toHiveFunction(func, db))
+    client.alterFunction(db, func.name.funcName, toHiveFunction(func, db))
   }
 
   override def getFunctionOption(
@@ -612,7 +612,7 @@ private[hive] class HiveClientImpl(
 
   private def toHiveFunction(f: CatalogFunction, db: String): HiveFunction = {
     new HiveFunction(
-      f.name,
+      f.name.funcName,
       db,
       f.className,
       null,
@@ -623,7 +623,8 @@ private[hive] class HiveClientImpl(
   }
 
   private def fromHiveFunction(hf: HiveFunction): CatalogFunction = {
-    new CatalogFunction(hf.getFunctionName, hf.getClassName)
+    val name = FunctionIdentifier(hf.getFunctionName, Option(hf.getDbName))
+    new CatalogFunction(name, hf.getClassName)
   }
 
   private def toHiveColumn(c: CatalogColumn): FieldSchema = {
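One detail in `fromHiveFunction`: wrapping `hf.getDbName` in `Option(...)` is null-safe, since `Option(null)` collapses to None rather than Some(null):

    Option(null: String)  // None
    Option("db2")         // Some("db2")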
