diff --git a/airframe-sql/src/main/scala/wvlet/airframe/sql/catalog/Catalog.scala b/airframe-sql/src/main/scala/wvlet/airframe/sql/catalog/Catalog.scala index b56a158c75..bacb4a61aa 100644 --- a/airframe-sql/src/main/scala/wvlet/airframe/sql/catalog/Catalog.scala +++ b/airframe-sql/src/main/scala/wvlet/airframe/sql/catalog/Catalog.scala @@ -51,6 +51,10 @@ trait Catalog extends LogSupport { } def listFunctions: Seq[SQLFunction] + + def updateTableSchema(database: String, table: String, schema: Catalog.TableSchema): Unit + def updateTableProperties(database: String, table: String, properties: Map[String, Any]): Unit + def updateDatabaseProperties(database: String, properties: Map[String, Any]): Unit } //case class DatabaseIdentifier(database: String, catalog: Option[String]) @@ -115,6 +119,12 @@ class InMemoryCatalog(val catalogName: String, val namespace: Option[String], fu private case class DatabaseHolder(db: Catalog.Database) { // table name -> table holder val tables = collection.mutable.Map.empty[String, Catalog.Table] + + def updateDatabase(database: Catalog.Database): DatabaseHolder = { + val newDb = DatabaseHolder(database) + newDb.tables ++= tables + newDb + } } override def listDatabases: Seq[String] = { @@ -219,12 +229,36 @@ class InMemoryCatalog(val catalogName: String, val namespace: Option[String], fu } } - // def findTable(database: String, tableName: String): Option[CatalogTable] = { - // databases.find(x => x.db == database && x.name == tableName) - // } - // + override def listFunctions: Seq[SQLFunction] = functions + + private def updateTable(database: String, table: String)(updater: Catalog.Table => Catalog.Table): Unit = { + synchronized { + val d = getDatabaseHolder(database) + d.tables.get(table) match { + case Some(oldTbl) => + d.tables += table -> updater(oldTbl) + case None => + throw SQLErrorCode.TableNotFound.newException(s"table ${database}.${table} is not found", None) + } + } + } + + override def updateTableSchema(database: String, table: String, schema: Catalog.TableSchema): Unit = { + updateTable(database, table)(tbl => tbl.copy(schema = schema)) + } - // } + override def updateTableProperties(database: String, table: String, properties: Map[String, Any]): Unit = { + updateTable(database, table)(tbl => tbl.copy(properties = properties)) + } - override def listFunctions: Seq[SQLFunction] = functions + override def updateDatabaseProperties(database: String, properties: Map[String, Any]): Unit = { + synchronized { + databases.get(database) match { + case Some(db) => + databases += database -> db.updateDatabase(db.db.copy(properties = properties)) + case None => + throw SQLErrorCode.DatabaseNotFound.newException(s"database ${database} is not found", None) + } + } + } } diff --git a/airframe-sql/src/test/scala/wvlet/airframe/sql/catalog/InMemoryCatalogTest.scala b/airframe-sql/src/test/scala/wvlet/airframe/sql/catalog/InMemoryCatalogTest.scala new file mode 100644 index 0000000000..a67cc649a5 --- /dev/null +++ b/airframe-sql/src/test/scala/wvlet/airframe/sql/catalog/InMemoryCatalogTest.scala @@ -0,0 +1,82 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package wvlet.airframe.sql.catalog + +import wvlet.airframe.sql.SQLError +import wvlet.airframe.sql.catalog.Catalog.{CreateMode, TableSchema} +import wvlet.airspec.AirSpec + +class InMemoryCatalogTest extends AirSpec { + test("update database") { + val c = new InMemoryCatalog("global", None, Nil) + + c.createDatabase(Catalog.Database("default"), CreateMode.CREATE_IF_NOT_EXISTS) + c.updateDatabaseProperties("default", Map("owner" -> "xxx")) + val db = c.getDatabase("default") + + db.name shouldBe "default" + db.properties shouldBe Map("owner" -> "xxx") + + test("forbid updating missing database") { + intercept[SQLError] { + c.updateDatabaseProperties("default2", Map.empty) + } + } + } + + test("update table") { + val c = new InMemoryCatalog("global", None, Nil) + c.createDatabase(Catalog.Database("default"), CreateMode.CREATE_IF_NOT_EXISTS) + c.createTable( + Catalog.Table(Some("default"), "sample", TableSchema(Seq(Catalog.TableColumn("id", DataType.StringType)))), + CreateMode.CREATE_IF_NOT_EXISTS + ) + + test("update table properties") { + c.updateTableProperties("default", "sample", Map("table_type" -> "mapping")) + + val tbl = c.getTable("default", "sample") + tbl.name shouldBe "sample" + tbl.properties shouldBe Map("table_type" -> "mapping") + } + + test("forbid updating missing table properties") { + intercept[SQLError] { + c.updateTableProperties("default", "sample2", Map.empty) + } + } + + test("update table schema") { + val t = c.getTable("default", "sample") + c.updateTableSchema( + "default", + "sample", + TableSchema(Seq(Catalog.TableColumn("id", DataType.StringType, Map("tag" -> "pid")))) + ) + + val tbl = c.getTable("default", "sample") + tbl.name shouldBe "sample" + tbl.schema shouldBe TableSchema(Seq(Catalog.TableColumn("id", DataType.StringType, Map("tag" -> "pid")))) + } + test("forbid updating missing table schema") { + intercept[SQLError] { + c.updateTableSchema( + "default", + "sample2", + TableSchema(Seq(Catalog.TableColumn("id", DataType.StringType, Map("tag2" -> "pid")))) + ) + } + } + } +}