Skip to content

Commit

Permalink
Glue: Default CatalogId should be the AccountID (#6864)
Browse files Browse the repository at this point in the history
  • Loading branch information
bblommers authored Sep 29, 2023
1 parent 5563e62 commit 0698258
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 10 deletions.
23 changes: 18 additions & 5 deletions moto/glue/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,9 @@ def create_database(
if database_name in self.databases:
raise DatabaseAlreadyExistsException()

database = FakeDatabase(database_name, database_input)
database = FakeDatabase(
database_name, database_input, catalog_id=self.account_id
)
self.databases[database_name] = database
return database

Expand Down Expand Up @@ -165,7 +167,9 @@ def create_table(
if table_name in database.tables:
raise TableAlreadyExistsException()

table = FakeTable(database_name, table_name, table_input)
table = FakeTable(
database_name, table_name, table_input, catalog_id=self.account_id
)
database.tables[table_name] = table
return table

Expand Down Expand Up @@ -1041,9 +1045,12 @@ def batch_get_triggers(self, trigger_names: List[str]) -> List[Dict[str, Any]]:


class FakeDatabase(BaseModel):
def __init__(self, database_name: str, database_input: Dict[str, Any]):
def __init__(
self, database_name: str, database_input: Dict[str, Any], catalog_id: str
):
self.name = database_name
self.input = database_input
self.catalog_id = catalog_id
self.created_time = utcnow()
self.tables: Dict[str, FakeTable] = OrderedDict()

Expand All @@ -1058,16 +1065,21 @@ def as_dict(self) -> Dict[str, Any]:
"CreateTableDefaultPermissions"
),
"TargetDatabase": self.input.get("TargetDatabase"),
"CatalogId": self.input.get("CatalogId"),
"CatalogId": self.input.get("CatalogId") or self.catalog_id,
}


class FakeTable(BaseModel):
def __init__(
self, database_name: str, table_name: str, table_input: Dict[str, Any]
self,
database_name: str,
table_name: str,
table_input: Dict[str, Any],
catalog_id: str,
):
self.database_name = database_name
self.name = table_name
self.catalog_id = catalog_id
self.partitions: Dict[str, FakePartition] = OrderedDict()
self.created_time = utcnow()
self.updated_time: Optional[datetime] = None
Expand Down Expand Up @@ -1104,6 +1116,7 @@ def as_dict(self, version: Optional[str] = None) -> Dict[str, Any]:
**self.get_version(str(version)),
# Add VersionId after we get the version-details, just to make sure that it's a valid version (int)
"VersionId": str(version),
"CatalogId": self.catalog_id,
}
if self.updated_time is not None:
obj["UpdateTime"] = unix_time(self.updated_time)
Expand Down
12 changes: 7 additions & 5 deletions tests/test_glue/test_datacatalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def test_create_database():
response = helpers.get_database(client, database_name)
database = response["Database"]

assert database.get("Name") == database_name
assert database["Name"] == database_name
assert database["CatalogId"] == ACCOUNT_ID
assert database.get("Description") == database_input.get("Description")
assert database.get("LocationUri") == database_input.get("LocationUri")
assert database.get("Parameters") == database_input.get("Parameters")
Expand Down Expand Up @@ -67,14 +68,11 @@ def test_get_database_not_exits():


@mock_glue
def test_get_databases_empty():
def test_get_databases():
client = boto3.client("glue", region_name="us-east-1")
response = client.get_databases()
assert len(response["DatabaseList"]) == 0


@mock_glue
def test_get_databases_several_items():
client = boto3.client("glue", region_name="us-east-1")
database_name_1, database_name_2 = "firstdatabase", "seconddatabase"

Expand All @@ -86,7 +84,9 @@ def test_get_databases_several_items():
)
assert len(database_list) == 2
assert database_list[0]["Name"] == database_name_1
assert database_list[0]["CatalogId"] == ACCOUNT_ID
assert database_list[1]["Name"] == database_name_2
assert database_list[1]["CatalogId"] == ACCOUNT_ID


@mock_glue
Expand Down Expand Up @@ -222,6 +222,7 @@ def test_get_tables():
table["StorageDescriptor"] == table_inputs[table_name]["StorageDescriptor"]
)
assert table["PartitionKeys"] == table_inputs[table_name]["PartitionKeys"]
assert table["CatalogId"] == ACCOUNT_ID


@mock_glue
Expand Down Expand Up @@ -319,6 +320,7 @@ def test_get_table_versions():
table = client.get_table(DatabaseName=database_name, Name=table_name)["Table"]
assert table["StorageDescriptor"]["Columns"] == []
assert table["VersionId"] == "1"
assert table["CatalogId"] == ACCOUNT_ID

columns = [{"Name": "country", "Type": "string"}]
table_input = helpers.create_table_input(database_name, table_name, columns=columns)
Expand Down

0 comments on commit 0698258

Please sign in to comment.