From f6b4c37279dc4e44005c1b03d6f83ab4e1207f56 Mon Sep 17 00:00:00 2001 From: william-conti Date: Mon, 8 Jan 2024 20:04:57 +0100 Subject: [PATCH 01/25] draft --- src/databricks/labs/ucx/account.py | 38 +++++++++++++++++++++++++- src/databricks/labs/ucx/cli.py | 7 +++++ tests/unit/test_account.py | 43 +++++++++++++++++++++++++++++- 3 files changed, 86 insertions(+), 2 deletions(-) diff --git a/src/databricks/labs/ucx/account.py b/src/databricks/labs/ucx/account.py index 77f2ab9337..2e03b41391 100644 --- a/src/databricks/labs/ucx/account.py +++ b/src/databricks/labs/ucx/account.py @@ -88,6 +88,42 @@ def sync_workspace_info(self): path = f"{installation.path}/{self.SYNC_FILE_NAME}" ws.workspace.upload(path, info, overwrite=True, format=ImportFormat.AUTO) + def create_account_level_groups(self): + """ + Crawl all workspaces, and create account level groups if a WS local group is not present in the account + """ + acc_groups = {} + ac_grp_ids = self._ac.groups.list(attributes="id") + for acc_grp_id in ac_grp_ids: + full_account_group = self._ac.groups.get(acc_grp_id.id) + acc_groups[full_account_group.display_name] = full_account_group.members + + workspace_clients = self.workspace_clients() + all_workspace_groups = {} + for client in workspace_clients: + ws_groups_ids = client.groups.list(attributes="id") + for grp_id in ws_groups_ids: + full_workpace_group = client.groups.get(grp_id.id) + if full_workpace_group.display_name in all_workspace_groups: + if len(all_workspace_groups[full_workpace_group.display_name]) == full_workpace_group.members: + logger.debug(f"Group {full_workpace_group.display_name} already found in another workspace") + else: + logger.warning(f"Workspace local group {full_workpace_group.display_name} does not have same members") + # What to do in this situation ? + else: + all_workspace_groups[full_workpace_group.display_name] = full_workpace_group.members + + for group_name, members in all_workspace_groups.items(): + if group_name in acc_groups: + if len(acc_groups[group_name]) == members: + logger.info(f"Group {group_name} exist at account level") + else: + logger.warning(f"Group {group_name} exist at account level but does not have same members") + # What to do in this situation ? + else: + self._ac.groups.create(display_name=group_name, members=members) + logger.info(f"Group {group_name} created at the account") + class WorkspaceInfo: def __init__(self, ws: WorkspaceClient, folder: str | None = None, new_installation_manager=InstallationManager): @@ -156,4 +192,4 @@ def manual_workspace_info(self, prompts: Prompts): path = f"{installation.path}/{AccountWorkspaces.SYNC_FILE_NAME}" logger.info(f"Overwriting {path}") self._ws.workspace.upload(path, info, overwrite=True, format=ImportFormat.AUTO) # type: ignore[arg-type] - logger.info("Synchronised workspace id mapping for installations on current workspace") + logger.info("Synchronised workspace id mapping for installations on current workspace") \ No newline at end of file diff --git a/src/databricks/labs/ucx/cli.py b/src/databricks/labs/ucx/cli.py index 0f3207cd8a..a59178db91 100644 --- a/src/databricks/labs/ucx/cli.py +++ b/src/databricks/labs/ucx/cli.py @@ -79,6 +79,13 @@ def sync_workspace_info(a: AccountClient): workspaces.sync_workspace_info() +@ucx.command(is_account=True) +def create_account_level_groups(a: AccountClient): + """upload workspace config to all workspaces in the account where ucx is installed""" + logger.info(f"Account ID: {a.config.account_id}") + workspaces = AccountWorkspaces(AccountConfig(connect=ConnectConfig())) + workspaces.create_account_level_groups() + @ucx.command def manual_workspace_info(w: WorkspaceClient): """only supposed to be run if cannot get admins to run `databricks labs ucx sync-workspace-info`""" diff --git a/tests/unit/test_account.py b/tests/unit/test_account.py index 9d050752d3..02da756d0a 100644 --- a/tests/unit/test_account.py +++ b/tests/unit/test_account.py @@ -4,7 +4,7 @@ from databricks.labs.blueprint.tui import MockPrompts from databricks.sdk import WorkspaceClient -from databricks.sdk.service.iam import User +from databricks.sdk.service.iam import User, Group, ComplexValue from databricks.sdk.service.provisioning import Workspace from databricks.sdk.service.workspace import ImportFormat @@ -98,3 +98,44 @@ def test_manual_workspace_info(mocker): overwrite=True, format=ImportFormat.AUTO, ) + +def test_acc_groups(mocker): + account_config = AccountConfig( + connect=ConnectConfig(host="https://accounts.cloud.databricks.com", account_id="123", token="abc") + ) + # TODO: https://github.com/databricks/databricks-sdk-py/pull/480 + acc_client = mocker.patch("databricks.sdk.AccountClient.__init__") + acc_client.config = account_config.to_databricks_config() + + account_config.to_account_client = lambda: acc_client + # test for workspace filtering + account_config.include_workspace_names = ["foo"] + + acc_client.workspaces.list.return_value = [ + Workspace(workspace_name="foo", workspace_id=123, workspace_status_message="Running", deployment_name="abc"), + Workspace(workspace_name="bar", workspace_id=456, workspace_status_message="Running", deployment_name="def"), + ] + + ws = mocker.patch("databricks.sdk.WorkspaceClient.__init__") + + def workspace_client(host, product, **kwargs) -> WorkspaceClient: + assert host == "https://abc.cloud.databricks.com" + assert product == "ucx" + return ws + + im = create_autospec(InstallationManager) + im.user_installations.return_value = [ + Installation(config=WorkspaceConfig(inventory_database="ucx"), user=User(display_name="foo"), path="/Users/foo") + ] + + group = Group( + display_name="de", + members=[ComplexValue(display="test-user-1", value="20"), ComplexValue(display="test-user-2", value="21")], + ) + + account_workspaces = AccountWorkspaces(account_config, workspace_client, lambda _: im) + account_workspaces.create_account_level_groups() + + #TODO: do the tests + + From 43e223d1e73001a6381ba990b83cc3a126e8ebc0 Mon Sep 17 00:00:00 2001 From: william-conti Date: Fri, 12 Jan 2024 20:58:33 +0100 Subject: [PATCH 02/25] Adding test and changing logic --- src/databricks/labs/ucx/account.py | 70 ++++++++++++++++++------------ tests/unit/test_account.py | 27 ++++++++---- 2 files changed, 61 insertions(+), 36 deletions(-) diff --git a/src/databricks/labs/ucx/account.py b/src/databricks/labs/ucx/account.py index 2e03b41391..bb88f52832 100644 --- a/src/databricks/labs/ucx/account.py +++ b/src/databricks/labs/ucx/account.py @@ -6,6 +6,7 @@ from databricks.labs.blueprint.tui import Prompts from databricks.sdk import WorkspaceClient from databricks.sdk.errors import NotFound +from databricks.sdk.service.iam import ComplexValue, Group from databricks.sdk.service.provisioning import Workspace from databricks.sdk.service.workspace import ImportFormat @@ -92,37 +93,52 @@ def create_account_level_groups(self): """ Crawl all workspaces, and create account level groups if a WS local group is not present in the account """ - acc_groups = {} - ac_grp_ids = self._ac.groups.list(attributes="id") - for acc_grp_id in ac_grp_ids: - full_account_group = self._ac.groups.get(acc_grp_id.id) - acc_groups[full_account_group.display_name] = full_account_group.members + acc_groups = self.get_account_groups() - workspace_clients = self.workspace_clients() - all_workspace_groups = {} - for client in workspace_clients: - ws_groups_ids = client.groups.list(attributes="id") - for grp_id in ws_groups_ids: - full_workpace_group = client.groups.get(grp_id.id) - if full_workpace_group.display_name in all_workspace_groups: - if len(all_workspace_groups[full_workpace_group.display_name]) == full_workpace_group.members: - logger.debug(f"Group {full_workpace_group.display_name} already found in another workspace") - else: - logger.warning(f"Workspace local group {full_workpace_group.display_name} does not have same members") - # What to do in this situation ? - else: - all_workspace_groups[full_workpace_group.display_name] = full_workpace_group.members + all_valid_workspace_groups = self.get_valid_workspaces_groups() - for group_name, members in all_workspace_groups.items(): + for group_name, group in all_valid_workspace_groups.items(): if group_name in acc_groups: - if len(acc_groups[group_name]) == members: - logger.info(f"Group {group_name} exist at account level") - else: - logger.warning(f"Group {group_name} exist at account level but does not have same members") - # What to do in this situation ? + logger.info(f"Group {group_name} already exist in the account, ignoring") else: - self._ac.groups.create(display_name=group_name, members=members) - logger.info(f"Group {group_name} created at the account") + self._ac.groups.create(display_name=group_name, members=group.members) + logger.info(f"Group {group_name} created in the account") + + def get_valid_workspaces_groups(self) -> dict[str:Group]: + all_workspaces_groups = {} + inconsistent_groups = [] + for client in self.workspace_clients(): + ws_group_ids = client.groups.list(attributes="id") + for group_id in ws_group_ids: + full_workspace_group = client.groups.get(group_id.id) + group_name = full_workspace_group.display_name + + if group_name in inconsistent_groups: + logger.info(f"Group {group_name} has been found earlier and it didn't had same members, ignoring") + elif group_name in all_workspaces_groups: + if not self.has_same_members(all_workspaces_groups[group_name], full_workspace_group): + logger.warning(f"Group {full_workspace_group.display_name} does not have same amount of members in workspace {client.config.host}, it won't be migrated to the account") + inconsistent_groups.append(group_name) + all_workspaces_groups.pop(group_name) + else: + logger.info(f"Workspace group {group_name} already found, ignoring") + else: + logger.info(f"Found new group {group_name}") + all_workspaces_groups[group_name] = full_workspace_group + return all_workspaces_groups + + def has_same_members(self, group_1:Group, group_2:Group) -> []: + ws_members_set = set([m.display for m in group_1.members] if group_1.members else []) + ws_members_set_2 = set([m.display for m in group_2.members] if group_2.members else []) + return (ws_members_set - ws_members_set_2).union(ws_members_set_2 - ws_members_set) + + + def get_account_groups(self) -> dict[str:Group]: + acc_groups = {} + for acc_grp_id in self._ac.groups.list(attributes="id"): + full_account_group = self._ac.groups.get(acc_grp_id.id) + acc_groups[full_account_group.display_name] = full_account_group.members + return acc_groups class WorkspaceInfo: diff --git a/tests/unit/test_account.py b/tests/unit/test_account.py index 02da756d0a..68343ee4c1 100644 --- a/tests/unit/test_account.py +++ b/tests/unit/test_account.py @@ -1,6 +1,6 @@ import io import json -from unittest.mock import create_autospec, patch +from unittest.mock import create_autospec, patch, MagicMock from databricks.labs.blueprint.tui import MockPrompts from databricks.sdk import WorkspaceClient @@ -99,7 +99,7 @@ def test_manual_workspace_info(mocker): format=ImportFormat.AUTO, ) -def test_acc_groups(mocker): +def test_create_acc_groups_should_create_acc_group_if_no_group_found(mocker): account_config = AccountConfig( connect=ConnectConfig(host="https://accounts.cloud.databricks.com", account_id="123", token="abc") ) @@ -111,17 +111,14 @@ def test_acc_groups(mocker): # test for workspace filtering account_config.include_workspace_names = ["foo"] + mock1 = MagicMock() acc_client.workspaces.list.return_value = [ Workspace(workspace_name="foo", workspace_id=123, workspace_status_message="Running", deployment_name="abc"), Workspace(workspace_name="bar", workspace_id=456, workspace_status_message="Running", deployment_name="def"), ] - ws = mocker.patch("databricks.sdk.WorkspaceClient.__init__") - - def workspace_client(host, product, **kwargs) -> WorkspaceClient: - assert host == "https://abc.cloud.databricks.com" - assert product == "ucx" - return ws + def workspace_client(**kwargs) -> WorkspaceClient: + return mock1 im = create_autospec(InstallationManager) im.user_installations.return_value = [ @@ -129,13 +126,25 @@ def workspace_client(host, product, **kwargs) -> WorkspaceClient: ] group = Group( + id= "12", display_name="de", members=[ComplexValue(display="test-user-1", value="20"), ComplexValue(display="test-user-2", value="21")], ) + mock1.groups.list.return_value = [group] + mock1.groups.get.return_value = group + account_workspaces = AccountWorkspaces(account_config, workspace_client, lambda _: im) account_workspaces.create_account_level_groups() - #TODO: do the tests + acc_client.groups.create.assert_called_with( + display_name='de', + members=[ + ComplexValue(display='test-user-1', primary=None, type=None, value='20'), + ComplexValue(display='test-user-2', primary=None, type=None, value='21') + ] + ) + + From 6ab9ef66144d406a8507bb08a2ae394c41086929 Mon Sep 17 00:00:00 2001 From: william-conti Date: Fri, 12 Jan 2024 21:17:12 +0100 Subject: [PATCH 03/25] adding few tests --- src/databricks/labs/ucx/account.py | 6 ++-- tests/unit/test_account.py | 57 +++++++++++++++++++++++++++--- 2 files changed, 55 insertions(+), 8 deletions(-) diff --git a/src/databricks/labs/ucx/account.py b/src/databricks/labs/ucx/account.py index bb88f52832..3d6fa1fbf2 100644 --- a/src/databricks/labs/ucx/account.py +++ b/src/databricks/labs/ucx/account.py @@ -116,7 +116,7 @@ def get_valid_workspaces_groups(self) -> dict[str:Group]: if group_name in inconsistent_groups: logger.info(f"Group {group_name} has been found earlier and it didn't had same members, ignoring") elif group_name in all_workspaces_groups: - if not self.has_same_members(all_workspaces_groups[group_name], full_workspace_group): + if self.has_not_same_members(all_workspaces_groups[group_name], full_workspace_group): logger.warning(f"Group {full_workspace_group.display_name} does not have same amount of members in workspace {client.config.host}, it won't be migrated to the account") inconsistent_groups.append(group_name) all_workspaces_groups.pop(group_name) @@ -127,10 +127,10 @@ def get_valid_workspaces_groups(self) -> dict[str:Group]: all_workspaces_groups[group_name] = full_workspace_group return all_workspaces_groups - def has_same_members(self, group_1:Group, group_2:Group) -> []: + def has_not_same_members(self, group_1:Group, group_2:Group) -> []: ws_members_set = set([m.display for m in group_1.members] if group_1.members else []) ws_members_set_2 = set([m.display for m in group_2.members] if group_2.members else []) - return (ws_members_set - ws_members_set_2).union(ws_members_set_2 - ws_members_set) + return bool((ws_members_set - ws_members_set_2).union(ws_members_set_2 - ws_members_set)) def get_account_groups(self) -> dict[str:Group]: diff --git a/tests/unit/test_account.py b/tests/unit/test_account.py index 68343ee4c1..13d4f203ab 100644 --- a/tests/unit/test_account.py +++ b/tests/unit/test_account.py @@ -108,17 +108,62 @@ def test_create_acc_groups_should_create_acc_group_if_no_group_found(mocker): acc_client.config = account_config.to_databricks_config() account_config.to_account_client = lambda: acc_client - # test for workspace filtering account_config.include_workspace_names = ["foo"] + acc_client.workspaces.list.return_value = [ + Workspace(workspace_name="foo", workspace_id=123, workspace_status_message="Running", deployment_name="abc") + ] + mock1 = MagicMock() + + def workspace_client(**kwargs) -> WorkspaceClient: + return mock1 + + group = Group( + id= "12", + display_name="de", + members=[ComplexValue(display="test-user-1", value="20"), ComplexValue(display="test-user-2", value="21")], + ) + + mock1.groups.list.return_value = [group] + mock1.groups.get.return_value = group + + account_workspaces = AccountWorkspaces(account_config, workspace_client) + account_workspaces.create_account_level_groups() + + acc_client.groups.create.assert_called_with( + display_name='de', + members=[ + ComplexValue(display='test-user-1', primary=None, type=None, value='20'), + ComplexValue(display='test-user-2', primary=None, type=None, value='21') + ] + ) + + +def test_create_acc_groups_should_filter_groups_in_other_workspaces(mocker): + account_config = AccountConfig( + connect=ConnectConfig(host="https://accounts.cloud.databricks.com", account_id="123", token="abc") + ) + # TODO: https://github.com/databricks/databricks-sdk-py/pull/480 + acc_client = mocker.patch("databricks.sdk.AccountClient.__init__") + acc_client.config = account_config.to_databricks_config() + + account_config.to_account_client = lambda: acc_client + account_config.include_workspace_names = ["foo", "bar"] + acc_client.workspaces.list.return_value = [ Workspace(workspace_name="foo", workspace_id=123, workspace_status_message="Running", deployment_name="abc"), Workspace(workspace_name="bar", workspace_id=456, workspace_status_message="Running", deployment_name="def"), ] - def workspace_client(**kwargs) -> WorkspaceClient: - return mock1 + mock1 = MagicMock() + mock2 = MagicMock() + + def workspace_client(host, product, **kwargs) -> WorkspaceClient: + if host == "https://abc.cloud.databricks.com": + return mock1 + else: + return mock2 im = create_autospec(InstallationManager) im.user_installations.return_value = [ @@ -134,10 +179,13 @@ def workspace_client(**kwargs) -> WorkspaceClient: mock1.groups.list.return_value = [group] mock1.groups.get.return_value = group + mock2.groups.list.return_value = [group] + mock2.groups.get.return_value = group + account_workspaces = AccountWorkspaces(account_config, workspace_client, lambda _: im) account_workspaces.create_account_level_groups() - acc_client.groups.create.assert_called_with( + acc_client.groups.create.assert_called_once_with( display_name='de', members=[ ComplexValue(display='test-user-1', primary=None, type=None, value='20'), @@ -147,4 +195,3 @@ def workspace_client(**kwargs) -> WorkspaceClient: - From e3c6632ba8ed5bc028877f5324c8eb30061b5dd1 Mon Sep 17 00:00:00 2001 From: william-conti Date: Fri, 12 Jan 2024 21:27:36 +0100 Subject: [PATCH 04/25] Adding tests, changing logic --- src/databricks/labs/ucx/account.py | 18 ++++++++++-------- src/databricks/labs/ucx/cli.py | 1 + tests/unit/test_account.py | 28 +++++++++++++--------------- 3 files changed, 24 insertions(+), 23 deletions(-) diff --git a/src/databricks/labs/ucx/account.py b/src/databricks/labs/ucx/account.py index 3d6fa1fbf2..c789130b8d 100644 --- a/src/databricks/labs/ucx/account.py +++ b/src/databricks/labs/ucx/account.py @@ -104,9 +104,9 @@ def create_account_level_groups(self): self._ac.groups.create(display_name=group_name, members=group.members) logger.info(f"Group {group_name} created in the account") - def get_valid_workspaces_groups(self) -> dict[str:Group]: - all_workspaces_groups = {} - inconsistent_groups = [] + def get_valid_workspaces_groups(self) -> dict[str, Group]: + all_workspaces_groups: dict[str, Group] = {} + inconsistent_groups: list[str] = [] for client in self.workspace_clients(): ws_group_ids = client.groups.list(attributes="id") for group_id in ws_group_ids: @@ -117,7 +117,10 @@ def get_valid_workspaces_groups(self) -> dict[str:Group]: logger.info(f"Group {group_name} has been found earlier and it didn't had same members, ignoring") elif group_name in all_workspaces_groups: if self.has_not_same_members(all_workspaces_groups[group_name], full_workspace_group): - logger.warning(f"Group {full_workspace_group.display_name} does not have same amount of members in workspace {client.config.host}, it won't be migrated to the account") + logger.warning( + f"Group {full_workspace_group.display_name} does not have same amount of members " + f"in workspace {client.config.host}, it won't be migrated to the account" + ) inconsistent_groups.append(group_name) all_workspaces_groups.pop(group_name) else: @@ -127,13 +130,12 @@ def get_valid_workspaces_groups(self) -> dict[str:Group]: all_workspaces_groups[group_name] = full_workspace_group return all_workspaces_groups - def has_not_same_members(self, group_1:Group, group_2:Group) -> []: + def has_not_same_members(self, group_1: Group, group_2: Group) -> bool: ws_members_set = set([m.display for m in group_1.members] if group_1.members else []) ws_members_set_2 = set([m.display for m in group_2.members] if group_2.members else []) return bool((ws_members_set - ws_members_set_2).union(ws_members_set_2 - ws_members_set)) - - def get_account_groups(self) -> dict[str:Group]: + def get_account_groups(self) -> dict[str | None, list[ComplexValue] | None]: acc_groups = {} for acc_grp_id in self._ac.groups.list(attributes="id"): full_account_group = self._ac.groups.get(acc_grp_id.id) @@ -208,4 +210,4 @@ def manual_workspace_info(self, prompts: Prompts): path = f"{installation.path}/{AccountWorkspaces.SYNC_FILE_NAME}" logger.info(f"Overwriting {path}") self._ws.workspace.upload(path, info, overwrite=True, format=ImportFormat.AUTO) # type: ignore[arg-type] - logger.info("Synchronised workspace id mapping for installations on current workspace") \ No newline at end of file + logger.info("Synchronised workspace id mapping for installations on current workspace") diff --git a/src/databricks/labs/ucx/cli.py b/src/databricks/labs/ucx/cli.py index a59178db91..f651855346 100644 --- a/src/databricks/labs/ucx/cli.py +++ b/src/databricks/labs/ucx/cli.py @@ -86,6 +86,7 @@ def create_account_level_groups(a: AccountClient): workspaces = AccountWorkspaces(AccountConfig(connect=ConnectConfig())) workspaces.create_account_level_groups() + @ucx.command def manual_workspace_info(w: WorkspaceClient): """only supposed to be run if cannot get admins to run `databricks labs ucx sync-workspace-info`""" diff --git a/tests/unit/test_account.py b/tests/unit/test_account.py index 13d4f203ab..dbc9fab479 100644 --- a/tests/unit/test_account.py +++ b/tests/unit/test_account.py @@ -1,10 +1,10 @@ import io import json -from unittest.mock import create_autospec, patch, MagicMock +from unittest.mock import MagicMock, create_autospec, patch from databricks.labs.blueprint.tui import MockPrompts from databricks.sdk import WorkspaceClient -from databricks.sdk.service.iam import User, Group, ComplexValue +from databricks.sdk.service.iam import ComplexValue, Group, User from databricks.sdk.service.provisioning import Workspace from databricks.sdk.service.workspace import ImportFormat @@ -99,6 +99,7 @@ def test_manual_workspace_info(mocker): format=ImportFormat.AUTO, ) + def test_create_acc_groups_should_create_acc_group_if_no_group_found(mocker): account_config = AccountConfig( connect=ConnectConfig(host="https://accounts.cloud.databricks.com", account_id="123", token="abc") @@ -120,7 +121,7 @@ def workspace_client(**kwargs) -> WorkspaceClient: return mock1 group = Group( - id= "12", + id="12", display_name="de", members=[ComplexValue(display="test-user-1", value="20"), ComplexValue(display="test-user-2", value="21")], ) @@ -132,11 +133,11 @@ def workspace_client(**kwargs) -> WorkspaceClient: account_workspaces.create_account_level_groups() acc_client.groups.create.assert_called_with( - display_name='de', + display_name="de", members=[ - ComplexValue(display='test-user-1', primary=None, type=None, value='20'), - ComplexValue(display='test-user-2', primary=None, type=None, value='21') - ] + ComplexValue(display="test-user-1", primary=None, type=None, value="20"), + ComplexValue(display="test-user-2", primary=None, type=None, value="21"), + ], ) @@ -171,7 +172,7 @@ def workspace_client(host, product, **kwargs) -> WorkspaceClient: ] group = Group( - id= "12", + id="12", display_name="de", members=[ComplexValue(display="test-user-1", value="20"), ComplexValue(display="test-user-2", value="21")], ) @@ -186,12 +187,9 @@ def workspace_client(host, product, **kwargs) -> WorkspaceClient: account_workspaces.create_account_level_groups() acc_client.groups.create.assert_called_once_with( - display_name='de', + display_name="de", members=[ - ComplexValue(display='test-user-1', primary=None, type=None, value='20'), - ComplexValue(display='test-user-2', primary=None, type=None, value='21') - ] + ComplexValue(display="test-user-1", primary=None, type=None, value="20"), + ComplexValue(display="test-user-2", primary=None, type=None, value="21"), + ], ) - - - From 814f6ff5ad89958a77caf49f33eb62d93ef29e83 Mon Sep 17 00:00:00 2001 From: william-conti Date: Fri, 19 Jan 2024 16:57:04 +0100 Subject: [PATCH 05/25] fixing logic --- src/databricks/labs/ucx/account.py | 28 ++- .../labs/ucx/framework/dashboards.py | 2 +- tests/unit/test_account.py | 181 +++++++++++++++++- 3 files changed, 191 insertions(+), 20 deletions(-) diff --git a/src/databricks/labs/ucx/account.py b/src/databricks/labs/ucx/account.py index c789130b8d..a45d94a783 100644 --- a/src/databricks/labs/ucx/account.py +++ b/src/databricks/labs/ucx/account.py @@ -91,11 +91,12 @@ def sync_workspace_info(self): def create_account_level_groups(self): """ - Crawl all workspaces, and create account level groups if a WS local group is not present in the account + Crawl all workspaces, and create account level groups if a WS local group is not present in the account. + The feature is not configurable, meaning that it fetches all workspaces groups and all account groups. """ - acc_groups = self.get_account_groups() + acc_groups = self._get_account_groups() - all_valid_workspace_groups = self.get_valid_workspaces_groups() + all_valid_workspace_groups = self._get_valid_workspaces_groups() for group_name, group in all_valid_workspace_groups.items(): if group_name in acc_groups: @@ -104,25 +105,22 @@ def create_account_level_groups(self): self._ac.groups.create(display_name=group_name, members=group.members) logger.info(f"Group {group_name} created in the account") - def get_valid_workspaces_groups(self) -> dict[str, Group]: + def _get_valid_workspaces_groups(self) -> dict[str, Group]: all_workspaces_groups: dict[str, Group] = {} - inconsistent_groups: list[str] = [] for client in self.workspace_clients(): ws_group_ids = client.groups.list(attributes="id") for group_id in ws_group_ids: full_workspace_group = client.groups.get(group_id.id) group_name = full_workspace_group.display_name - if group_name in inconsistent_groups: - logger.info(f"Group {group_name} has been found earlier and it didn't had same members, ignoring") - elif group_name in all_workspaces_groups: - if self.has_not_same_members(all_workspaces_groups[group_name], full_workspace_group): + if group_name in all_workspaces_groups: + if self._has_not_same_members(all_workspaces_groups[group_name], full_workspace_group): logger.warning( - f"Group {full_workspace_group.display_name} does not have same amount of members " - f"in workspace {client.config.host}, it won't be migrated to the account" + f"Group {group_name} does not have same amount of members " + f"in workspace {client.config.host}, it will be created with account " + f"name {client.config.host}_{group_name}" ) - inconsistent_groups.append(group_name) - all_workspaces_groups.pop(group_name) + all_workspaces_groups[f"{client.config.host}_{group_name}"] = full_workspace_group else: logger.info(f"Workspace group {group_name} already found, ignoring") else: @@ -130,12 +128,12 @@ def get_valid_workspaces_groups(self) -> dict[str, Group]: all_workspaces_groups[group_name] = full_workspace_group return all_workspaces_groups - def has_not_same_members(self, group_1: Group, group_2: Group) -> bool: + def _has_not_same_members(self, group_1: Group, group_2: Group) -> bool: ws_members_set = set([m.display for m in group_1.members] if group_1.members else []) ws_members_set_2 = set([m.display for m in group_2.members] if group_2.members else []) return bool((ws_members_set - ws_members_set_2).union(ws_members_set_2 - ws_members_set)) - def get_account_groups(self) -> dict[str | None, list[ComplexValue] | None]: + def _get_account_groups(self) -> dict[str | None, list[ComplexValue] | None]: acc_groups = {} for acc_grp_id in self._ac.groups.list(attributes="id"): full_account_group = self._ac.groups.get(acc_grp_id.id) diff --git a/src/databricks/labs/ucx/framework/dashboards.py b/src/databricks/labs/ucx/framework/dashboards.py index 56c085e9df..5c7702a301 100644 --- a/src/databricks/labs/ucx/framework/dashboards.py +++ b/src/databricks/labs/ucx/framework/dashboards.py @@ -47,7 +47,7 @@ def viz_args(self) -> dict: class VizColumn: name: str title: str - type: str = "string" + type: str = "string" # noqa: A003 imageUrlTemplate: str = "{{ @ }}" # noqa: N815 imageTitleTemplate: str = "{{ @ }}" # noqa: N815 linkUrlTemplate: str = "{{ @ }}" # noqa: N815 diff --git a/tests/unit/test_account.py b/tests/unit/test_account.py index dbc9fab479..76be91c612 100644 --- a/tests/unit/test_account.py +++ b/tests/unit/test_account.py @@ -1,6 +1,6 @@ import io import json -from unittest.mock import MagicMock, create_autospec, patch +from unittest.mock import MagicMock, call, create_autospec, patch from databricks.labs.blueprint.tui import MockPrompts from databricks.sdk import WorkspaceClient @@ -141,7 +141,43 @@ def workspace_client(**kwargs) -> WorkspaceClient: ) -def test_create_acc_groups_should_filter_groups_in_other_workspaces(mocker): +def test_create_acc_groups_should_not_create_group_if_exists_in_acc(mocker): + account_config = AccountConfig( + connect=ConnectConfig(host="https://accounts.cloud.databricks.com", account_id="123", token="abc") + ) + # TODO: https://github.com/databricks/databricks-sdk-py/pull/480 + acc_client = mocker.patch("databricks.sdk.AccountClient.__init__") + acc_client.config = account_config.to_databricks_config() + + group = Group( + id="12", + display_name="de", + members=[ComplexValue(display="test-user-1", value="20"), ComplexValue(display="test-user-2", value="21")], + ) + acc_client.groups.list.return_value = [group] + acc_client.groups.get.return_value = group + account_config.to_account_client = lambda: acc_client + account_config.include_workspace_names = ["foo"] + + acc_client.workspaces.list.return_value = [ + Workspace(workspace_name="foo", workspace_id=123, workspace_status_message="Running", deployment_name="abc") + ] + + mock1 = create_autospec(WorkspaceClient) + + def workspace_client(**kwargs) -> WorkspaceClient: + return mock1 + + mock1.groups.list.return_value = [group] + mock1.groups.get.return_value = group + + account_workspaces = AccountWorkspaces(account_config, workspace_client) + account_workspaces.create_account_level_groups() + + acc_client.groups.create.assert_not_called() + + +def test_create_acc_groups_should_create_groups_accross_workspaces(mocker): account_config = AccountConfig( connect=ConnectConfig(host="https://accounts.cloud.databricks.com", account_id="123", token="abc") ) @@ -157,8 +193,75 @@ def test_create_acc_groups_should_filter_groups_in_other_workspaces(mocker): Workspace(workspace_name="bar", workspace_id=456, workspace_status_message="Running", deployment_name="def"), ] - mock1 = MagicMock() - mock2 = MagicMock() + mock1 = create_autospec(WorkspaceClient) + mock2 = create_autospec(WorkspaceClient) + + def workspace_client(host, product, **kwargs) -> WorkspaceClient: + if host == "https://abc.cloud.databricks.com": + return mock1 + else: + return mock2 + + im = create_autospec(InstallationManager) + im.user_installations.return_value = [ + Installation(config=WorkspaceConfig(inventory_database="ucx"), user=User(display_name="foo"), path="/Users/foo") + ] + + group = Group( + id="12", + display_name="de", + members=[ComplexValue(display="test-user-1", value="20"), ComplexValue(display="test-user-2", value="21")], + ) + + group2 = Group( + id="12", + display_name="security_grp", + members=[ComplexValue(display="John", value="20"), ComplexValue(display="pat", value="21")], + ) + + mock1.groups.list.return_value = [group] + mock1.groups.get.return_value = group + + mock2.groups.list.return_value = [group2] + mock2.groups.get.return_value = group2 + + account_workspaces = AccountWorkspaces(account_config, workspace_client, lambda _: im) + account_workspaces.create_account_level_groups() + + calls = [ + call( + display_name="de", + members=[ + ComplexValue(display="test-user-1", primary=None, type=None, value="20"), + ComplexValue(display="test-user-2", primary=None, type=None, value="21"), + ], + ), + call( + display_name="security_grp", + members=[ComplexValue(display="John", value="20"), ComplexValue(display="pat", value="21")], + ), + ] + acc_client.groups.create.assert_has_calls(calls) + + +def test_create_acc_groups_should_filter_groups_accross_workspaces(mocker): + account_config = AccountConfig( + connect=ConnectConfig(host="https://accounts.cloud.databricks.com", account_id="123", token="abc") + ) + # TODO: https://github.com/databricks/databricks-sdk-py/pull/480 + acc_client = mocker.patch("databricks.sdk.AccountClient.__init__") + acc_client.config = account_config.to_databricks_config() + + account_config.to_account_client = lambda: acc_client + account_config.include_workspace_names = ["foo", "bar"] + + acc_client.workspaces.list.return_value = [ + Workspace(workspace_name="foo", workspace_id=123, workspace_status_message="Running", deployment_name="abc"), + Workspace(workspace_name="bar", workspace_id=456, workspace_status_message="Running", deployment_name="def"), + ] + + mock1 = create_autospec(WorkspaceClient) + mock2 = create_autospec(WorkspaceClient) def workspace_client(host, product, **kwargs) -> WorkspaceClient: if host == "https://abc.cloud.databricks.com": @@ -193,3 +296,73 @@ def workspace_client(host, product, **kwargs) -> WorkspaceClient: ComplexValue(display="test-user-2", primary=None, type=None, value="21"), ], ) + + +def test_create_acc_groups_should_create_acc_group_if_exist_in_other_workspaces_but_not_same_members(mocker): + account_config = AccountConfig( + connect=ConnectConfig(host="https://accounts.cloud.databricks.com", account_id="123", token="abc") + ) + # TODO: https://github.com/databricks/databricks-sdk-py/pull/480 + acc_client = mocker.patch("databricks.sdk.AccountClient.__init__") + acc_client.config = account_config.to_databricks_config() + + account_config.to_account_client = lambda: acc_client + account_config.include_workspace_names = ["foo", "bar"] + + acc_client.workspaces.list.return_value = [ + Workspace(workspace_name="foo", workspace_id=123, workspace_status_message="Running", deployment_name="abc"), + Workspace(workspace_name="bar", workspace_id=456, workspace_status_message="Running", deployment_name="def"), + ] + + mock1 = MagicMock() + mock2 = MagicMock() + + def workspace_client(host, product, **kwargs) -> WorkspaceClient: + if host == "https://abc.cloud.databricks.com": + return mock1 + else: + return mock2 + + im = create_autospec(InstallationManager) + im.user_installations.return_value = [ + Installation(config=WorkspaceConfig(inventory_database="ucx"), user=User(display_name="foo"), path="/Users/foo") + ] + + group = Group( + id="12", + display_name="de", + members=[ComplexValue(display="test-user-1", value="20"), ComplexValue(display="test-user-2", value="21")], + ) + group_2 = Group( + id="12", + display_name="de", + members=[ComplexValue(display="test-user-1", value="20")], + ) + + mock1.groups.list.return_value = [group] + mock1.groups.get.return_value = group + mock1.config.host = "https://host_1" + + mock2.groups.list.return_value = [group_2] + mock2.groups.get.return_value = group_2 + mock2.config.host = "https://host_2" + + account_workspaces = AccountWorkspaces(account_config, workspace_client, lambda _: im) + account_workspaces.create_account_level_groups() + + calls = [ + call( + display_name="de", + members=[ + ComplexValue(display="test-user-1", primary=None, type=None, value="20"), + ComplexValue(display="test-user-2", primary=None, type=None, value="21"), + ], + ), + call( + display_name="https://host_2_de", + members=[ + ComplexValue(display="test-user-1", primary=None, type=None, value="20"), + ], + ), + ] + acc_client.groups.create.assert_has_calls(calls) From 6cfee043e41144c67ffdc7b811d2f64756d71428 Mon Sep 17 00:00:00 2001 From: william-conti Date: Fri, 2 Feb 2024 10:56:52 +0100 Subject: [PATCH 06/25] reviving PR --- src/databricks/labs/ucx/account.py | 32 ++- src/databricks/labs/ucx/cli.py | 13 +- tests/integration/assessment/test_azure.py | 6 +- tests/unit/test_account.py | 241 +++++++-------------- tests/unit/test_cli.py | 7 + 5 files changed, 130 insertions(+), 169 deletions(-) diff --git a/src/databricks/labs/ucx/account.py b/src/databricks/labs/ucx/account.py index 21ff4ee6b1..84f77774fe 100644 --- a/src/databricks/labs/ucx/account.py +++ b/src/databricks/labs/ucx/account.py @@ -6,6 +6,7 @@ from databricks.labs.blueprint.tui import Prompts from databricks.sdk import AccountClient, WorkspaceClient from databricks.sdk.errors import NotFound +from databricks.sdk.service import iam from databricks.sdk.service.iam import ComplexValue, Group from databricks.sdk.service.provisioning import Workspace @@ -71,26 +72,41 @@ def sync_workspace_info(self): installation.save(workspaces, filename=self.SYNC_FILE_NAME) def create_account_level_groups(self): - """ - Crawl all workspaces, and create account level groups if a WS local group is not present in the account. - The feature is not configurable, meaning that it fetches all workspaces groups and all account groups. - """ acc_groups = self._get_account_groups() - all_valid_workspace_groups = self._get_valid_workspaces_groups() - for group_name, group in all_valid_workspace_groups.items(): + for group_name, valid_group in all_valid_workspace_groups.items(): if group_name in acc_groups: logger.info(f"Group {group_name} already exist in the account, ignoring") else: - self._ac.groups.create(display_name=group_name, members=group.members) + acc_group = self._ac.groups.create(display_name=group_name) + + if len(acc_group.members) > 1: + self._add_members_to_acc_group(acc_group.id, group_name, valid_group) + logger.info(f"Group {group_name} created in the account") + def _add_members_to_acc_group(self, acc_group_id: str, group_name: str, valid_group: Group): + for chunk in self._chunks(valid_group.members, 20): + logger.debug(f"Adding 20 members to acc group {group_name}") + self._ac.groups.patch( + acc_group_id, + operations=[iam.Patch(op=iam.PatchOp.ADD, path="members", value=[x.as_dict() for x in chunk])], + schemas=[iam.PatchSchema.URN_IETF_PARAMS_SCIM_API_MESSAGES_2_0_PATCH_OP], + ) + + def _chunks(self, lst, n): + """Yield successive n-sized chunks from lst.""" + for i in range(0, len(lst), n): + yield lst[i : i + n] + def _get_valid_workspaces_groups(self) -> dict[str, Group]: all_workspaces_groups: dict[str, Group] = {} + for client in self.workspace_clients(): ws_group_ids = client.groups.list(attributes="id") for group_id in ws_group_ids: + assert group_id.id is not None full_workspace_group = client.groups.get(group_id.id) group_name = full_workspace_group.display_name @@ -106,6 +122,7 @@ def _get_valid_workspaces_groups(self) -> dict[str, Group]: logger.info(f"Workspace group {group_name} already found, ignoring") else: logger.info(f"Found new group {group_name}") + assert group_name is not None all_workspaces_groups[group_name] = full_workspace_group return all_workspaces_groups @@ -117,6 +134,7 @@ def _has_not_same_members(self, group_1: Group, group_2: Group) -> bool: def _get_account_groups(self) -> dict[str | None, list[ComplexValue] | None]: acc_groups = {} for acc_grp_id in self._ac.groups.list(attributes="id"): + assert acc_grp_id.id is not None full_account_group = self._ac.groups.get(acc_grp_id.id) acc_groups[full_account_group.display_name] = full_account_group.members return acc_groups diff --git a/src/databricks/labs/ucx/cli.py b/src/databricks/labs/ucx/cli.py index 7f627cedad..e8db87552e 100644 --- a/src/databricks/labs/ucx/cli.py +++ b/src/databricks/labs/ucx/cli.py @@ -91,9 +91,18 @@ def sync_workspace_info(a: AccountClient): @ucx.command(is_account=True) def create_account_level_groups(a: AccountClient): - """upload workspace config to all workspaces in the account where ucx is installed""" + """ + Crawl all workspaces, and create account level groups if a WS local group is not present in the account. + The feature is not configurable, meaning that it fetches all workspaces groups and all account groups. + + The following scenarios are supported, if a group X: + - Exist in workspaces A,B,C and it has same members in there, it will be created in the account + - Exist in workspaces A,B but not in C, it will be created in the account + - Exist in workspaces A,B,C and it has same members in A,B, but not in C, then, X and C_X will be created in the + account + """ logger.info(f"Account ID: {a.config.account_id}") - workspaces = AccountWorkspaces(AccountConfig(connect=ConnectConfig())) + workspaces = AccountWorkspaces(a) workspaces.create_account_level_groups() diff --git a/tests/integration/assessment/test_azure.py b/tests/integration/assessment/test_azure.py index 5cc56bc939..4afe03c0a8 100644 --- a/tests/integration/assessment/test_azure.py +++ b/tests/integration/assessment/test_azure.py @@ -110,9 +110,9 @@ def test_spn_crawler_with_available_secrets( _pipeline_conf_with_avlbl_secret["fs.azure.account.oauth2.client.id.SA1.dfs.core.windows.net"] = ( "{" + (f"{{secrets/{secret_scope}/{secret_key}}}") + "}" ) - _pipeline_conf_with_avlbl_secret["fs.azure.account.oauth2.client.endpoint.SA1.dfs.core.windows.net"] = ( - "https://login.microsoftonline.com/dummy_tenant/oauth2/token" - ) + _pipeline_conf_with_avlbl_secret[ + "fs.azure.account.oauth2.client.endpoint.SA1.dfs.core.windows.net" + ] = "https://login.microsoftonline.com/dummy_tenant/oauth2/token" make_job() make_pipeline(configuration=_pipeline_conf_with_avlbl_secret) spn_crawler = AzureServicePrincipalCrawler(ws=ws, sbe=sql_backend, schema=inventory_schema) diff --git a/tests/unit/test_account.py b/tests/unit/test_account.py index d59808280f..6f2eca0f47 100644 --- a/tests/unit/test_account.py +++ b/tests/unit/test_account.py @@ -1,12 +1,13 @@ import io import json -from unittest.mock import MagicMock, call, create_autospec, patch +from unittest.mock import MagicMock, create_autospec, patch from databricks.labs.blueprint.installation import Installation, MockInstallation from databricks.labs.blueprint.tui import MockPrompts -from databricks.sdk.service.iam import ComplexValue, Group, User from databricks.sdk import AccountClient, WorkspaceClient from databricks.sdk.config import Config +from databricks.sdk.service import iam +from databricks.sdk.service.iam import ComplexValue, Group, User from databricks.sdk.service.provisioning import Workspace from databricks.labs.ucx.account import AccountWorkspaces, WorkspaceInfo @@ -67,32 +68,23 @@ def test_manual_workspace_info(mocker): wir.manual_workspace_info(prompts) - ws.workspace.upload.assert_called_with( - "/Users/foo/workspaces.json", - b'[\n {\n "workspace_id": 123,\n "workspace_name": "some-name"\n }\n]', - overwrite=True, - format=ImportFormat.AUTO, - ) + ws.workspace.upload.assert_called() -def test_create_acc_groups_should_create_acc_group_if_no_group_found(mocker): - account_config = AccountConfig( - connect=ConnectConfig(host="https://accounts.cloud.databricks.com", account_id="123", token="abc") - ) - # TODO: https://github.com/databricks/databricks-sdk-py/pull/480 - acc_client = mocker.patch("databricks.sdk.AccountClient.__init__") - acc_client.config = account_config.to_databricks_config() - account_config.to_account_client = lambda: acc_client - account_config.include_workspace_names = ["foo"] +def test_create_acc_groups_should_create_acc_group_if_no_group_found_in_account(mocker): + acc_client = create_autospec(AccountClient) + acc_client.config = Config(host="https://accounts.cloud.databricks.com", account_id="123", token="123") acc_client.workspaces.list.return_value = [ Workspace(workspace_name="foo", workspace_id=123, workspace_status_message="Running", deployment_name="abc") ] - mock1 = MagicMock() + ws = create_autospec(WorkspaceClient) def workspace_client(**kwargs) -> WorkspaceClient: - return mock1 + return ws + + account_workspaces = AccountWorkspaces(acc_client, workspace_client) group = Group( id="12", @@ -100,28 +92,31 @@ def workspace_client(**kwargs) -> WorkspaceClient: members=[ComplexValue(display="test-user-1", value="20"), ComplexValue(display="test-user-2", value="21")], ) - mock1.groups.list.return_value = [group] - mock1.groups.get.return_value = group + ws.groups.list.return_value = [group] + ws.groups.get.return_value = group + acc_client.groups.create.return_value = group - account_workspaces = AccountWorkspaces(account_config, workspace_client) account_workspaces.create_account_level_groups() acc_client.groups.create.assert_called_with( display_name="de", - members=[ - ComplexValue(display="test-user-1", primary=None, type=None, value="20"), - ComplexValue(display="test-user-2", primary=None, type=None, value="21"), + ) + acc_client.groups.patch.assert_called_with( + "12", + operations=[ + iam.Patch( + op=iam.PatchOp.ADD, + path='members', + value=[{'display': 'test-user-1', 'value': '20'}, {'display': 'test-user-2', 'value': '21'}], + ) ], + schemas=[iam.PatchSchema.URN_IETF_PARAMS_SCIM_API_MESSAGES_2_0_PATCH_OP], ) -def test_create_acc_groups_should_not_create_group_if_exists_in_acc(mocker): - account_config = AccountConfig( - connect=ConnectConfig(host="https://accounts.cloud.databricks.com", account_id="123", token="abc") - ) - # TODO: https://github.com/databricks/databricks-sdk-py/pull/480 - acc_client = mocker.patch("databricks.sdk.AccountClient.__init__") - acc_client.config = account_config.to_databricks_config() +def test_create_acc_groups_should_not_create_group_if_exists_in_account(mocker): + acc_client = create_autospec(AccountClient) + acc_client.config = Config(host="https://accounts.cloud.databricks.com", account_id="123", token="123") group = Group( id="12", @@ -130,123 +125,75 @@ def test_create_acc_groups_should_not_create_group_if_exists_in_acc(mocker): ) acc_client.groups.list.return_value = [group] acc_client.groups.get.return_value = group - account_config.to_account_client = lambda: acc_client - account_config.include_workspace_names = ["foo"] - acc_client.workspaces.list.return_value = [ Workspace(workspace_name="foo", workspace_id=123, workspace_status_message="Running", deployment_name="abc") ] - mock1 = create_autospec(WorkspaceClient) + ws = create_autospec(WorkspaceClient) def workspace_client(**kwargs) -> WorkspaceClient: - return mock1 + return ws - mock1.groups.list.return_value = [group] - mock1.groups.get.return_value = group + ws.groups.list.return_value = [group] + ws.groups.get.return_value = group - account_workspaces = AccountWorkspaces(account_config, workspace_client) + account_workspaces = AccountWorkspaces(acc_client, workspace_client) account_workspaces.create_account_level_groups() acc_client.groups.create.assert_not_called() def test_create_acc_groups_should_create_groups_accross_workspaces(mocker): - account_config = AccountConfig( - connect=ConnectConfig(host="https://accounts.cloud.databricks.com", account_id="123", token="abc") - ) - # TODO: https://github.com/databricks/databricks-sdk-py/pull/480 - acc_client = mocker.patch("databricks.sdk.AccountClient.__init__") - acc_client.config = account_config.to_databricks_config() - - account_config.to_account_client = lambda: acc_client - account_config.include_workspace_names = ["foo", "bar"] + acc_client = create_autospec(AccountClient) + acc_client.config = Config(host="https://accounts.cloud.databricks.com", account_id="123", token="123") acc_client.workspaces.list.return_value = [ Workspace(workspace_name="foo", workspace_id=123, workspace_status_message="Running", deployment_name="abc"), Workspace(workspace_name="bar", workspace_id=456, workspace_status_message="Running", deployment_name="def"), ] - mock1 = create_autospec(WorkspaceClient) - mock2 = create_autospec(WorkspaceClient) + ws1 = create_autospec(WorkspaceClient) + ws2 = create_autospec(WorkspaceClient) def workspace_client(host, product, **kwargs) -> WorkspaceClient: if host == "https://abc.cloud.databricks.com": - return mock1 + return ws1 else: - return mock2 + return ws2 - im = create_autospec(InstallationManager) - im.user_installations.return_value = [ - Installation(config=WorkspaceConfig(inventory_database="ucx"), user=User(display_name="foo"), path="/Users/foo") - ] + group = Group(id="12", display_name="de") + group2 = Group(id="12", display_name="security_grp") - group = Group( - id="12", - display_name="de", - members=[ComplexValue(display="test-user-1", value="20"), ComplexValue(display="test-user-2", value="21")], - ) - - group2 = Group( - id="12", - display_name="security_grp", - members=[ComplexValue(display="John", value="20"), ComplexValue(display="pat", value="21")], - ) - - mock1.groups.list.return_value = [group] - mock1.groups.get.return_value = group + ws1.groups.list.return_value = [group] + ws1.groups.get.return_value = group - mock2.groups.list.return_value = [group2] - mock2.groups.get.return_value = group2 + ws2.groups.list.return_value = [group2] + ws2.groups.get.return_value = group2 - account_workspaces = AccountWorkspaces(account_config, workspace_client, lambda _: im) + account_workspaces = AccountWorkspaces(acc_client, workspace_client) account_workspaces.create_account_level_groups() - calls = [ - call( - display_name="de", - members=[ - ComplexValue(display="test-user-1", primary=None, type=None, value="20"), - ComplexValue(display="test-user-2", primary=None, type=None, value="21"), - ], - ), - call( - display_name="security_grp", - members=[ComplexValue(display="John", value="20"), ComplexValue(display="pat", value="21")], - ), - ] - acc_client.groups.create.assert_has_calls(calls) + acc_client.groups.create.assert_any_call(display_name="de") + acc_client.groups.create.assert_any_call(display_name="security_grp") def test_create_acc_groups_should_filter_groups_accross_workspaces(mocker): - account_config = AccountConfig( - connect=ConnectConfig(host="https://accounts.cloud.databricks.com", account_id="123", token="abc") - ) - # TODO: https://github.com/databricks/databricks-sdk-py/pull/480 - acc_client = mocker.patch("databricks.sdk.AccountClient.__init__") - acc_client.config = account_config.to_databricks_config() - - account_config.to_account_client = lambda: acc_client - account_config.include_workspace_names = ["foo", "bar"] + acc_client = create_autospec(AccountClient) + acc_client.config = Config(host="https://accounts.cloud.databricks.com", account_id="123", token="123") acc_client.workspaces.list.return_value = [ Workspace(workspace_name="foo", workspace_id=123, workspace_status_message="Running", deployment_name="abc"), Workspace(workspace_name="bar", workspace_id=456, workspace_status_message="Running", deployment_name="def"), ] - mock1 = create_autospec(WorkspaceClient) - mock2 = create_autospec(WorkspaceClient) + ws1 = create_autospec(WorkspaceClient) + ws2 = create_autospec(WorkspaceClient) def workspace_client(host, product, **kwargs) -> WorkspaceClient: if host == "https://abc.cloud.databricks.com": - return mock1 + return ws1 else: - return mock2 - - im = create_autospec(InstallationManager) - im.user_installations.return_value = [ - Installation(config=WorkspaceConfig(inventory_database="ucx"), user=User(display_name="foo"), path="/Users/foo") - ] + return ws2 group = Group( id="12", @@ -254,53 +201,47 @@ def workspace_client(host, product, **kwargs) -> WorkspaceClient: members=[ComplexValue(display="test-user-1", value="20"), ComplexValue(display="test-user-2", value="21")], ) - mock1.groups.list.return_value = [group] - mock1.groups.get.return_value = group + ws1.groups.list.return_value = [group] + ws1.groups.get.return_value = group - mock2.groups.list.return_value = [group] - mock2.groups.get.return_value = group + ws2.groups.list.return_value = [group] + ws2.groups.get.return_value = group + acc_client.groups.create.return_value = group - account_workspaces = AccountWorkspaces(account_config, workspace_client, lambda _: im) + account_workspaces = AccountWorkspaces(acc_client, workspace_client) account_workspaces.create_account_level_groups() - acc_client.groups.create.assert_called_once_with( - display_name="de", - members=[ - ComplexValue(display="test-user-1", primary=None, type=None, value="20"), - ComplexValue(display="test-user-2", primary=None, type=None, value="21"), + acc_client.groups.create.assert_called_once_with(display_name="de") + acc_client.groups.patch.assert_called_once_with( + "12", + operations=[ + iam.Patch( + op=iam.PatchOp.ADD, + path='members', + value=[{'display': 'test-user-1', 'value': '20'}, {'display': 'test-user-2', 'value': '21'}], + ) ], + schemas=[iam.PatchSchema.URN_IETF_PARAMS_SCIM_API_MESSAGES_2_0_PATCH_OP], ) def test_create_acc_groups_should_create_acc_group_if_exist_in_other_workspaces_but_not_same_members(mocker): - account_config = AccountConfig( - connect=ConnectConfig(host="https://accounts.cloud.databricks.com", account_id="123", token="abc") - ) - # TODO: https://github.com/databricks/databricks-sdk-py/pull/480 - acc_client = mocker.patch("databricks.sdk.AccountClient.__init__") - acc_client.config = account_config.to_databricks_config() - - account_config.to_account_client = lambda: acc_client - account_config.include_workspace_names = ["foo", "bar"] + acc_client = create_autospec(AccountClient) + acc_client.config = Config(host="https://accounts.cloud.databricks.com", account_id="123", token="123") acc_client.workspaces.list.return_value = [ Workspace(workspace_name="foo", workspace_id=123, workspace_status_message="Running", deployment_name="abc"), Workspace(workspace_name="bar", workspace_id=456, workspace_status_message="Running", deployment_name="def"), ] - mock1 = MagicMock() - mock2 = MagicMock() + ws1 = MagicMock() + ws2 = MagicMock() def workspace_client(host, product, **kwargs) -> WorkspaceClient: if host == "https://abc.cloud.databricks.com": - return mock1 + return ws1 else: - return mock2 - - im = create_autospec(InstallationManager) - im.user_installations.return_value = [ - Installation(config=WorkspaceConfig(inventory_database="ucx"), user=User(display_name="foo"), path="/Users/foo") - ] + return ws2 group = Group( id="12", @@ -313,30 +254,16 @@ def workspace_client(host, product, **kwargs) -> WorkspaceClient: members=[ComplexValue(display="test-user-1", value="20")], ) - mock1.groups.list.return_value = [group] - mock1.groups.get.return_value = group - mock1.config.host = "https://host_1" + ws1.groups.list.return_value = [group] + ws1.groups.get.return_value = group + ws1.config.host = "https://host_1" - mock2.groups.list.return_value = [group_2] - mock2.groups.get.return_value = group_2 - mock2.config.host = "https://host_2" + ws2.groups.list.return_value = [group_2] + ws2.groups.get.return_value = group_2 + ws2.config.host = "https://host_2" - account_workspaces = AccountWorkspaces(account_config, workspace_client, lambda _: im) + account_workspaces = AccountWorkspaces(acc_client, workspace_client) account_workspaces.create_account_level_groups() - calls = [ - call( - display_name="de", - members=[ - ComplexValue(display="test-user-1", primary=None, type=None, value="20"), - ComplexValue(display="test-user-2", primary=None, type=None, value="21"), - ], - ), - call( - display_name="https://host_2_de", - members=[ - ComplexValue(display="test-user-1", primary=None, type=None, value="20"), - ], - ), - ] - acc_client.groups.create.assert_has_calls(calls) \ No newline at end of file + acc_client.groups.create.assert_any_call(display_name="de") + acc_client.groups.create.assert_any_call(display_name="https://host_2_de") diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index 8b30e4bbe7..55dc9ba300 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -11,6 +11,7 @@ from databricks.labs.ucx.cli import ( alias, + create_account_level_groups, create_table_mapping, ensure_assessment_run, installations, @@ -127,6 +128,12 @@ def test_sync_workspace_info(): s.assert_called_once() +def test_create_account_groups(): + a = create_autospec(AccountClient) + create_account_level_groups(a) + a.groups.list.assert_called_with(attributes="id") + + def test_manual_workspace_info(ws): with patch("databricks.labs.ucx.account.WorkspaceInfo.manual_workspace_info", return_value=None) as m: manual_workspace_info(ws) From df0c230539373dbe15fce372a4a0eb6962af79da Mon Sep 17 00:00:00 2001 From: william-conti Date: Fri, 2 Feb 2024 11:05:42 +0100 Subject: [PATCH 07/25] reformatting issues --- tests/integration/assessment/test_azure.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/integration/assessment/test_azure.py b/tests/integration/assessment/test_azure.py index 4afe03c0a8..5cc56bc939 100644 --- a/tests/integration/assessment/test_azure.py +++ b/tests/integration/assessment/test_azure.py @@ -110,9 +110,9 @@ def test_spn_crawler_with_available_secrets( _pipeline_conf_with_avlbl_secret["fs.azure.account.oauth2.client.id.SA1.dfs.core.windows.net"] = ( "{" + (f"{{secrets/{secret_scope}/{secret_key}}}") + "}" ) - _pipeline_conf_with_avlbl_secret[ - "fs.azure.account.oauth2.client.endpoint.SA1.dfs.core.windows.net" - ] = "https://login.microsoftonline.com/dummy_tenant/oauth2/token" + _pipeline_conf_with_avlbl_secret["fs.azure.account.oauth2.client.endpoint.SA1.dfs.core.windows.net"] = ( + "https://login.microsoftonline.com/dummy_tenant/oauth2/token" + ) make_job() make_pipeline(configuration=_pipeline_conf_with_avlbl_secret) spn_crawler = AzureServicePrincipalCrawler(ws=ws, sbe=sql_backend, schema=inventory_schema) From d6398d41b5f1002afc2b133a0e0098e7fc4a4809 Mon Sep 17 00:00:00 2001 From: william-conti Date: Fri, 2 Feb 2024 11:10:11 +0100 Subject: [PATCH 08/25] removing MagicMock references --- tests/unit/test_account.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit/test_account.py b/tests/unit/test_account.py index 6f2eca0f47..656c752ca5 100644 --- a/tests/unit/test_account.py +++ b/tests/unit/test_account.py @@ -1,6 +1,6 @@ import io import json -from unittest.mock import MagicMock, create_autospec, patch +from unittest.mock import create_autospec, patch from databricks.labs.blueprint.installation import Installation, MockInstallation from databricks.labs.blueprint.tui import MockPrompts @@ -234,8 +234,8 @@ def test_create_acc_groups_should_create_acc_group_if_exist_in_other_workspaces_ Workspace(workspace_name="bar", workspace_id=456, workspace_status_message="Running", deployment_name="def"), ] - ws1 = MagicMock() - ws2 = MagicMock() + ws1 = create_autospec(WorkspaceClient) + ws2 = create_autospec(WorkspaceClient) def workspace_client(host, product, **kwargs) -> WorkspaceClient: if host == "https://abc.cloud.databricks.com": From 69a845dc1f496f4200ec09215ff0e57e17c2d5cb Mon Sep 17 00:00:00 2001 From: william-conti Date: Fri, 2 Feb 2024 11:43:36 +0100 Subject: [PATCH 09/25] adding command to cli --- labs.yml | 3 +++ src/databricks/labs/ucx/account.py | 9 ++++++--- src/databricks/labs/ucx/cli.py | 2 +- tests/unit/test_cli.py | 4 ++-- 4 files changed, 12 insertions(+), 6 deletions(-) diff --git a/labs.yml b/labs.yml index c1f40c29f6..8791a98f5e 100644 --- a/labs.yml +++ b/labs.yml @@ -44,6 +44,9 @@ commands: - name: create-table-mapping description: create initial table mapping for review + - name: create-account-groups + description: creates account level groups for all groups in all workspaces. + - name: ensure-assessment-run description: ensure the assessment job was run on a workspace diff --git a/src/databricks/labs/ucx/account.py b/src/databricks/labs/ucx/account.py index 84f77774fe..72402bc0e0 100644 --- a/src/databricks/labs/ucx/account.py +++ b/src/databricks/labs/ucx/account.py @@ -106,7 +106,8 @@ def _get_valid_workspaces_groups(self) -> dict[str, Group]: for client in self.workspace_clients(): ws_group_ids = client.groups.list(attributes="id") for group_id in ws_group_ids: - assert group_id.id is not None + if not group_id.id: + continue full_workspace_group = client.groups.get(group_id.id) group_name = full_workspace_group.display_name @@ -122,7 +123,8 @@ def _get_valid_workspaces_groups(self) -> dict[str, Group]: logger.info(f"Workspace group {group_name} already found, ignoring") else: logger.info(f"Found new group {group_name}") - assert group_name is not None + if not group_name: + continue all_workspaces_groups[group_name] = full_workspace_group return all_workspaces_groups @@ -134,7 +136,8 @@ def _has_not_same_members(self, group_1: Group, group_2: Group) -> bool: def _get_account_groups(self) -> dict[str | None, list[ComplexValue] | None]: acc_groups = {} for acc_grp_id in self._ac.groups.list(attributes="id"): - assert acc_grp_id.id is not None + if not acc_grp_id.id: + continue full_account_group = self._ac.groups.get(acc_grp_id.id) acc_groups[full_account_group.display_name] = full_account_group.members return acc_groups diff --git a/src/databricks/labs/ucx/cli.py b/src/databricks/labs/ucx/cli.py index e8db87552e..c5d3974109 100644 --- a/src/databricks/labs/ucx/cli.py +++ b/src/databricks/labs/ucx/cli.py @@ -90,7 +90,7 @@ def sync_workspace_info(a: AccountClient): @ucx.command(is_account=True) -def create_account_level_groups(a: AccountClient): +def create_account_groups(a: AccountClient): """ Crawl all workspaces, and create account level groups if a WS local group is not present in the account. The feature is not configurable, meaning that it fetches all workspaces groups and all account groups. diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index 55dc9ba300..1128dfa769 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -11,7 +11,7 @@ from databricks.labs.ucx.cli import ( alias, - create_account_level_groups, + create_account_groups, create_table_mapping, ensure_assessment_run, installations, @@ -130,7 +130,7 @@ def test_sync_workspace_info(): def test_create_account_groups(): a = create_autospec(AccountClient) - create_account_level_groups(a) + create_account_groups(a) a.groups.list.assert_called_with(attributes="id") From a608e58399d0ea828f1a68fe7016526ef6ce3d6e Mon Sep 17 00:00:00 2001 From: william-conti Date: Fri, 2 Feb 2024 11:51:32 +0100 Subject: [PATCH 10/25] fix --- labs.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/labs.yml b/labs.yml index 8791a98f5e..a611fb58c4 100644 --- a/labs.yml +++ b/labs.yml @@ -44,9 +44,6 @@ commands: - name: create-table-mapping description: create initial table mapping for review - - name: create-account-groups - description: creates account level groups for all groups in all workspaces. - - name: ensure-assessment-run description: ensure the assessment job was run on a workspace @@ -121,3 +118,6 @@ commands: flags: - name: aws-profile description: AWS Profile to use for authentication + + - name: create-account-groups + description: Creates account level groups for all groups in all workspaces. From 809b6e30fc5471c4d973e85e401d008823d39c2d Mon Sep 17 00:00:00 2001 From: william-conti Date: Fri, 2 Feb 2024 22:42:04 +0100 Subject: [PATCH 11/25] pr returns --- src/databricks/labs/ucx/account.py | 56 +++++++++++++++++------------- tests/unit/test_account.py | 8 ++--- 2 files changed, 34 insertions(+), 30 deletions(-) diff --git a/src/databricks/labs/ucx/account.py b/src/databricks/labs/ucx/account.py index 72402bc0e0..eae5d21c05 100644 --- a/src/databricks/labs/ucx/account.py +++ b/src/databricks/labs/ucx/account.py @@ -6,8 +6,7 @@ from databricks.labs.blueprint.tui import Prompts from databricks.sdk import AccountClient, WorkspaceClient from databricks.sdk.errors import NotFound -from databricks.sdk.service import iam -from databricks.sdk.service.iam import ComplexValue, Group +from databricks.sdk.service.iam import ComplexValue, Group, Patch, PatchOp, PatchSchema from databricks.sdk.service.provisioning import Workspace from databricks.labs.ucx.__about__ import __version__ @@ -78,21 +77,22 @@ def create_account_level_groups(self): for group_name, valid_group in all_valid_workspace_groups.items(): if group_name in acc_groups: logger.info(f"Group {group_name} already exist in the account, ignoring") - else: - acc_group = self._ac.groups.create(display_name=group_name) + continue + + acc_group = self._ac.groups.create(display_name=group_name) - if len(acc_group.members) > 1: - self._add_members_to_acc_group(acc_group.id, group_name, valid_group) + if len(acc_group.members) > 0: + self._add_members_to_acc_group(acc_group.id, group_name, valid_group) - logger.info(f"Group {group_name} created in the account") + logger.info(f"Group {group_name} created in the account") def _add_members_to_acc_group(self, acc_group_id: str, group_name: str, valid_group: Group): for chunk in self._chunks(valid_group.members, 20): logger.debug(f"Adding 20 members to acc group {group_name}") self._ac.groups.patch( acc_group_id, - operations=[iam.Patch(op=iam.PatchOp.ADD, path="members", value=[x.as_dict() for x in chunk])], - schemas=[iam.PatchSchema.URN_IETF_PARAMS_SCIM_API_MESSAGES_2_0_PATCH_OP], + operations=[Patch(op=PatchOp.ADD, path="members", value=[x.as_dict() for x in chunk])], + schemas=[PatchSchema.URN_IETF_PARAMS_SCIM_API_MESSAGES_2_0_PATCH_OP], ) def _chunks(self, lst, n): @@ -103,35 +103,41 @@ def _chunks(self, lst, n): def _get_valid_workspaces_groups(self) -> dict[str, Group]: all_workspaces_groups: dict[str, Group] = {} - for client in self.workspace_clients(): + for workspace in self._workspaces(): + client = self.client_for(workspace) ws_group_ids = client.groups.list(attributes="id") for group_id in ws_group_ids: if not group_id.id: continue + full_workspace_group = client.groups.get(group_id.id) group_name = full_workspace_group.display_name if group_name in all_workspaces_groups: - if self._has_not_same_members(all_workspaces_groups[group_name], full_workspace_group): - logger.warning( - f"Group {group_name} does not have same amount of members " - f"in workspace {client.config.host}, it will be created with account " - f"name {client.config.host}_{group_name}" - ) - all_workspaces_groups[f"{client.config.host}_{group_name}"] = full_workspace_group - else: + if self._has_same_members(all_workspaces_groups[group_name], full_workspace_group): logger.info(f"Workspace group {group_name} already found, ignoring") - else: - logger.info(f"Found new group {group_name}") - if not group_name: continue - all_workspaces_groups[group_name] = full_workspace_group + + logger.warning( + f"Group {group_name} does not have the same amount of members " + f"in workspace {client.config.host}, it will be created with account " + f"name {workspace.workspace_name}_{group_name}" + ) + all_workspaces_groups[f"{workspace.workspace_name}_{group_name}"] = full_workspace_group + continue + + if not group_name: + continue + + logger.info(f"Found new group {group_name}") + all_workspaces_groups[group_name] = full_workspace_group + return all_workspaces_groups - def _has_not_same_members(self, group_1: Group, group_2: Group) -> bool: - ws_members_set = set([m.display for m in group_1.members] if group_1.members else []) + def _has_same_members(self, group_1: Group, group_2: Group) -> bool: + ws_members_set_1 = set([m.display for m in group_1.members] if group_1.members else []) ws_members_set_2 = set([m.display for m in group_2.members] if group_2.members else []) - return bool((ws_members_set - ws_members_set_2).union(ws_members_set_2 - ws_members_set)) + return not bool((ws_members_set_1 - ws_members_set_2).union(ws_members_set_2 - ws_members_set_1)) def _get_account_groups(self) -> dict[str | None, list[ComplexValue] | None]: acc_groups = {} diff --git a/tests/unit/test_account.py b/tests/unit/test_account.py index 656c752ca5..b39b45aa17 100644 --- a/tests/unit/test_account.py +++ b/tests/unit/test_account.py @@ -230,8 +230,8 @@ def test_create_acc_groups_should_create_acc_group_if_exist_in_other_workspaces_ acc_client.config = Config(host="https://accounts.cloud.databricks.com", account_id="123", token="123") acc_client.workspaces.list.return_value = [ - Workspace(workspace_name="foo", workspace_id=123, workspace_status_message="Running", deployment_name="abc"), - Workspace(workspace_name="bar", workspace_id=456, workspace_status_message="Running", deployment_name="def"), + Workspace(workspace_name="ws1", workspace_id=123, workspace_status_message="Running", deployment_name="abc"), + Workspace(workspace_name="ws2", workspace_id=456, workspace_status_message="Running", deployment_name="def"), ] ws1 = create_autospec(WorkspaceClient) @@ -256,14 +256,12 @@ def workspace_client(host, product, **kwargs) -> WorkspaceClient: ws1.groups.list.return_value = [group] ws1.groups.get.return_value = group - ws1.config.host = "https://host_1" ws2.groups.list.return_value = [group_2] ws2.groups.get.return_value = group_2 - ws2.config.host = "https://host_2" account_workspaces = AccountWorkspaces(acc_client, workspace_client) account_workspaces.create_account_level_groups() acc_client.groups.create.assert_any_call(display_name="de") - acc_client.groups.create.assert_any_call(display_name="https://host_2_de") + acc_client.groups.create.assert_any_call(display_name="ws2_de") From 213a63c2edeb92c7c172e1abb01f826517d717b5 Mon Sep 17 00:00:00 2001 From: william-conti Date: Wed, 14 Feb 2024 13:57:03 +0100 Subject: [PATCH 12/25] Adding integration test, fixing logic and PR comments --- src/databricks/labs/ucx/account.py | 64 +++++++---- src/databricks/labs/ucx/cli.py | 3 +- tests/integration/assessment/test_azure.py | 6 +- tests/integration/test_account.py | 23 ++++ tests/unit/test_account.py | 127 +++++++++++++++++++-- 5 files changed, 190 insertions(+), 33 deletions(-) create mode 100644 tests/integration/test_account.py diff --git a/src/databricks/labs/ucx/account.py b/src/databricks/labs/ucx/account.py index eae5d21c05..abf5ca4124 100644 --- a/src/databricks/labs/ucx/account.py +++ b/src/databricks/labs/ucx/account.py @@ -38,13 +38,7 @@ def _get_cloud(self) -> str: return "aws" def client_for(self, workspace: Workspace) -> WorkspaceClient: - config = self._ac.config.as_dict() - if "databricks_cli_path" in config: - del config["databricks_cli_path"] - cloud = self._get_cloud() - # copy current config and swap with a host relevant to a workspace - config["host"] = f"https://{workspace.deployment_name}.{self._tlds[cloud]}" - return self._new_workspace_client(**config, product="ucx", product_version=__version__) + return self._ac.get_workspace_client(workspace) def workspace_clients(self) -> list[WorkspaceClient]: """ @@ -70,9 +64,9 @@ def sync_workspace_info(self): for installation in Installation.existing(ws, "ucx"): installation.save(workspaces, filename=self.SYNC_FILE_NAME) - def create_account_level_groups(self): + def create_account_level_groups(self, prompts: Prompts): acc_groups = self._get_account_groups() - all_valid_workspace_groups = self._get_valid_workspaces_groups() + all_valid_workspace_groups = self._get_valid_workspaces_groups(prompts) for group_name, valid_group in all_valid_workspace_groups.items(): if group_name in acc_groups: @@ -81,15 +75,18 @@ def create_account_level_groups(self): acc_group = self._ac.groups.create(display_name=group_name) - if len(acc_group.members) > 0: - self._add_members_to_acc_group(acc_group.id, group_name, valid_group) - + if not acc_group.id: + continue + if len(valid_group.members) > 0: + self._add_members_to_acc_group(self._ac, acc_group.id, group_name, valid_group) logger.info(f"Group {group_name} created in the account") - def _add_members_to_acc_group(self, acc_group_id: str, group_name: str, valid_group: Group): + def _add_members_to_acc_group( + self, acc_client: AccountClient, acc_group_id: str, group_name: str, valid_group: Group + ): for chunk in self._chunks(valid_group.members, 20): - logger.debug(f"Adding 20 members to acc group {group_name}") - self._ac.groups.patch( + logger.debug(f"Adding {len(chunk)} members to acc group {group_name}") + acc_client.groups.patch( acc_group_id, operations=[Patch(op=PatchOp.ADD, path="members", value=[x.as_dict() for x in chunk])], schemas=[PatchSchema.URN_IETF_PARAMS_SCIM_API_MESSAGES_2_0_PATCH_OP], @@ -100,11 +97,13 @@ def _chunks(self, lst, n): for i in range(0, len(lst), n): yield lst[i : i + n] - def _get_valid_workspaces_groups(self) -> dict[str, Group]: + def _get_valid_workspaces_groups(self, prompts: Prompts) -> dict[str, Group]: all_workspaces_groups: dict[str, Group] = {} for workspace in self._workspaces(): client = self.client_for(workspace) + logger.info(f"Crawling groups in workspace {client.config.host}") + ws_group_ids = client.groups.list(attributes="id") for group_id in ws_group_ids: if not group_id.id: @@ -113,18 +112,21 @@ def _get_valid_workspaces_groups(self) -> dict[str, Group]: full_workspace_group = client.groups.get(group_id.id) group_name = full_workspace_group.display_name + if self._is_group_out_of_scope(full_workspace_group): + continue + if group_name in all_workspaces_groups: if self._has_same_members(all_workspaces_groups[group_name], full_workspace_group): logger.info(f"Workspace group {group_name} already found, ignoring") continue - logger.warning( + if prompts.confirm( f"Group {group_name} does not have the same amount of members " - f"in workspace {client.config.host}, it will be created with account " - f"name {workspace.workspace_name}_{group_name}" - ) - all_workspaces_groups[f"{workspace.workspace_name}_{group_name}"] = full_workspace_group - continue + f"in workspace {client.config.host} than previous workspaces which contains the same group name," + f"it will be created at the account with name : {workspace.workspace_name}_{group_name}" + ): + all_workspaces_groups[f"{workspace.workspace_name}_{group_name}"] = full_workspace_group + continue if not group_name: continue @@ -132,20 +134,38 @@ def _get_valid_workspaces_groups(self) -> dict[str, Group]: logger.info(f"Found new group {group_name}") all_workspaces_groups[group_name] = full_workspace_group + logger.info(f"Found a total of {len(all_workspaces_groups)} groups to migrate to the account") + return all_workspaces_groups + def _is_group_out_of_scope(self, group: Group) -> bool: + if group.display_name in ["users", "admins", "account users"]: + logger.debug(f"Group {group.display_name} is a system group, ignoring") + return True + meta = group.meta + if not meta: + return False + if meta.resource_type != "WorkspaceGroup": + logger.debug(f"Group {group.display_name} is an account group, ignoring") + return True + return False + def _has_same_members(self, group_1: Group, group_2: Group) -> bool: ws_members_set_1 = set([m.display for m in group_1.members] if group_1.members else []) ws_members_set_2 = set([m.display for m in group_2.members] if group_2.members else []) return not bool((ws_members_set_1 - ws_members_set_2).union(ws_members_set_2 - ws_members_set_1)) def _get_account_groups(self) -> dict[str | None, list[ComplexValue] | None]: + logger.debug("Listing groups in account") acc_groups = {} for acc_grp_id in self._ac.groups.list(attributes="id"): if not acc_grp_id.id: continue full_account_group = self._ac.groups.get(acc_grp_id.id) + logger.debug(f"Found account group {acc_grp_id.display_name}") acc_groups[full_account_group.display_name] = full_account_group.members + + logger.info(f"{len(acc_groups)} account groups found") return acc_groups diff --git a/src/databricks/labs/ucx/cli.py b/src/databricks/labs/ucx/cli.py index 67ff37c88f..7ac2ba7492 100644 --- a/src/databricks/labs/ucx/cli.py +++ b/src/databricks/labs/ucx/cli.py @@ -102,8 +102,9 @@ def create_account_groups(a: AccountClient): account """ logger.info(f"Account ID: {a.config.account_id}") + prompts = Prompts() workspaces = AccountWorkspaces(a) - workspaces.create_account_level_groups() + workspaces.create_account_level_groups(prompts) @ucx.command diff --git a/tests/integration/assessment/test_azure.py b/tests/integration/assessment/test_azure.py index 5cc56bc939..4afe03c0a8 100644 --- a/tests/integration/assessment/test_azure.py +++ b/tests/integration/assessment/test_azure.py @@ -110,9 +110,9 @@ def test_spn_crawler_with_available_secrets( _pipeline_conf_with_avlbl_secret["fs.azure.account.oauth2.client.id.SA1.dfs.core.windows.net"] = ( "{" + (f"{{secrets/{secret_scope}/{secret_key}}}") + "}" ) - _pipeline_conf_with_avlbl_secret["fs.azure.account.oauth2.client.endpoint.SA1.dfs.core.windows.net"] = ( - "https://login.microsoftonline.com/dummy_tenant/oauth2/token" - ) + _pipeline_conf_with_avlbl_secret[ + "fs.azure.account.oauth2.client.endpoint.SA1.dfs.core.windows.net" + ] = "https://login.microsoftonline.com/dummy_tenant/oauth2/token" make_job() make_pipeline(configuration=_pipeline_conf_with_avlbl_secret) spn_crawler = AzureServicePrincipalCrawler(ws=ws, sbe=sql_backend, schema=inventory_schema) diff --git a/tests/integration/test_account.py b/tests/integration/test_account.py new file mode 100644 index 0000000000..2e873bc341 --- /dev/null +++ b/tests/integration/test_account.py @@ -0,0 +1,23 @@ +from databricks.labs.blueprint.tui import MockPrompts + +from databricks.labs.ucx.account import AccountWorkspaces + + +def test_create_account_level_groups(make_ucx_group, make_group, make_user, acc): + make_ucx_group("test_ucx_migrate_invalid", "test_ucx_migrate_invalid") + + members = [] + for i in range(10): + user = make_user() + members.append(user.id) + + make_group(display_name="test_ucx_migrate_valid", members=members, entitlements=["allow-cluster-create"]) + AccountWorkspaces(acc).create_account_level_groups(MockPrompts({})) + + results = [] + for grp in acc.groups.list(): + if grp.display_name == "test_ucx_migrate_valid": + results.append(grp) + + assert len(results) == 1 + diff --git a/tests/unit/test_account.py b/tests/unit/test_account.py index b39b45aa17..f83fd0f80f 100644 --- a/tests/unit/test_account.py +++ b/tests/unit/test_account.py @@ -96,7 +96,7 @@ def workspace_client(**kwargs) -> WorkspaceClient: ws.groups.get.return_value = group acc_client.groups.create.return_value = group - account_workspaces.create_account_level_groups() + account_workspaces.create_account_level_groups(MockPrompts({})) acc_client.groups.create.assert_called_with( display_name="de", @@ -114,6 +114,113 @@ def workspace_client(**kwargs) -> WorkspaceClient: ) +def test_create_acc_groups_should_create_acc_group_with_appropriate_members(mocker): + acc_client = create_autospec(AccountClient) + acc_client.config = Config(host="https://accounts.cloud.databricks.com", account_id="123", token="123") + + acc_client.workspaces.list.return_value = [ + Workspace(workspace_name="foo", workspace_id=123, workspace_status_message="Running", deployment_name="abc") + ] + + ws = create_autospec(WorkspaceClient) + + def workspace_client(**kwargs) -> WorkspaceClient: + return ws + + account_workspaces = AccountWorkspaces(acc_client, workspace_client) + + group = Group( + id="12", + display_name="de", + members=[ + ComplexValue(display="test-user-1", value="1"), + ComplexValue(display="test-user-2", value="2"), + ComplexValue(display="test-user-3", value="3"), + ComplexValue(display="test-user-4", value="4"), + ComplexValue(display="test-user-5", value="5"), + ComplexValue(display="test-user-6", value="6"), + ComplexValue(display="test-user-7", value="7"), + ComplexValue(display="test-user-8", value="8"), + ComplexValue(display="test-user-9", value="9"), + ComplexValue(display="test-user-10", value="10"), + ComplexValue(display="test-user-11", value="11"), + ComplexValue(display="test-user-12", value="12"), + ComplexValue(display="test-user-13", value="13"), + ComplexValue(display="test-user-14", value="14"), + ComplexValue(display="test-user-15", value="15"), + ComplexValue(display="test-user-16", value="16"), + ComplexValue(display="test-user-17", value="17"), + ComplexValue(display="test-user-18", value="18"), + ComplexValue(display="test-user-19", value="19"), + ComplexValue(display="test-user-20", value="20"), + ComplexValue(display="test-user-21", value="21"), + ComplexValue(display="test-user-22", value="22"), + ComplexValue(display="test-user-23", value="23"), + ComplexValue(display="test-user-24", value="24"), + ComplexValue(display="test-user-25", value="25"), + ], + ) + + ws.groups.list.return_value = [group] + ws.groups.get.return_value = group + acc_client.groups.create.return_value = group + + account_workspaces.create_account_level_groups(MockPrompts({})) + + acc_client.groups.create.assert_called_with( + display_name="de", + ) + acc_client.groups.patch.assert_any_call( + "12", + operations=[ + iam.Patch( + op=iam.PatchOp.ADD, + path='members', + value=[ + {'display': 'test-user-1', 'value': '1'}, + {'display': 'test-user-2', 'value': '2'}, + {'display': 'test-user-3', 'value': '3'}, + {'display': 'test-user-4', 'value': '4'}, + {'display': 'test-user-5', 'value': '5'}, + {'display': 'test-user-6', 'value': '6'}, + {'display': 'test-user-7', 'value': '7'}, + {'display': 'test-user-8', 'value': '8'}, + {'display': 'test-user-9', 'value': '9'}, + {'display': 'test-user-10', 'value': '10'}, + {'display': 'test-user-11', 'value': '11'}, + {'display': 'test-user-12', 'value': '12'}, + {'display': 'test-user-13', 'value': '13'}, + {'display': 'test-user-14', 'value': '14'}, + {'display': 'test-user-15', 'value': '15'}, + {'display': 'test-user-16', 'value': '16'}, + {'display': 'test-user-17', 'value': '17'}, + {'display': 'test-user-18', 'value': '18'}, + {'display': 'test-user-19', 'value': '19'}, + {'display': 'test-user-20', 'value': '20'}, + ], + ) + ], + schemas=[iam.PatchSchema.URN_IETF_PARAMS_SCIM_API_MESSAGES_2_0_PATCH_OP], + ) + acc_client.groups.patch.assert_any_call( + "12", + operations=[ + iam.Patch( + op=iam.PatchOp.ADD, + path='members', + value=[ + {'display': 'test-user-21', 'value': '21'}, + {'display': 'test-user-22', 'value': '22'}, + {'display': 'test-user-23', 'value': '23'}, + {'display': 'test-user-24', 'value': '24'}, + {'display': 'test-user-25', 'value': '25'}, + ], + ) + ], + schemas=[iam.PatchSchema.URN_IETF_PARAMS_SCIM_API_MESSAGES_2_0_PATCH_OP], + ) + + def test_create_acc_groups_should_not_create_group_if_exists_in_account(mocker): acc_client = create_autospec(AccountClient) acc_client.config = Config(host="https://accounts.cloud.databricks.com", account_id="123", token="123") @@ -138,7 +245,7 @@ def workspace_client(**kwargs) -> WorkspaceClient: ws.groups.get.return_value = group account_workspaces = AccountWorkspaces(acc_client, workspace_client) - account_workspaces.create_account_level_groups() + account_workspaces.create_account_level_groups(MockPrompts({})) acc_client.groups.create.assert_not_called() @@ -161,8 +268,8 @@ def workspace_client(host, product, **kwargs) -> WorkspaceClient: else: return ws2 - group = Group(id="12", display_name="de") - group2 = Group(id="12", display_name="security_grp") + group = Group(id="12", display_name="de", members=[]) + group2 = Group(id="12", display_name="security_grp", members=[]) ws1.groups.list.return_value = [group] ws1.groups.get.return_value = group @@ -171,7 +278,7 @@ def workspace_client(host, product, **kwargs) -> WorkspaceClient: ws2.groups.get.return_value = group2 account_workspaces = AccountWorkspaces(acc_client, workspace_client) - account_workspaces.create_account_level_groups() + account_workspaces.create_account_level_groups(MockPrompts({})) acc_client.groups.create.assert_any_call(display_name="de") acc_client.groups.create.assert_any_call(display_name="security_grp") @@ -209,7 +316,7 @@ def workspace_client(host, product, **kwargs) -> WorkspaceClient: acc_client.groups.create.return_value = group account_workspaces = AccountWorkspaces(acc_client, workspace_client) - account_workspaces.create_account_level_groups() + account_workspaces.create_account_level_groups(MockPrompts({})) acc_client.groups.create.assert_called_once_with(display_name="de") acc_client.groups.patch.assert_called_once_with( @@ -261,7 +368,13 @@ def workspace_client(host, product, **kwargs) -> WorkspaceClient: ws2.groups.get.return_value = group_2 account_workspaces = AccountWorkspaces(acc_client, workspace_client) - account_workspaces.create_account_level_groups() + account_workspaces.create_account_level_groups( + MockPrompts( + { + r'Group de does not have the same amount of members in workspace ': 'yes', + } + ) + ) acc_client.groups.create.assert_any_call(display_name="de") acc_client.groups.create.assert_any_call(display_name="ws2_de") From d0b7889ae4893d85d72c4f55fdd835b45b44cf6e Mon Sep 17 00:00:00 2001 From: william-conti Date: Wed, 14 Feb 2024 15:37:55 +0100 Subject: [PATCH 13/25] formatting --- src/databricks/labs/ucx/account.py | 6 ++---- tests/integration/test_account.py | 1 - tests/unit/test_account.py | 30 +++++++++++++++++++++++++++--- 3 files changed, 29 insertions(+), 8 deletions(-) diff --git a/src/databricks/labs/ucx/account.py b/src/databricks/labs/ucx/account.py index abf5ca4124..38723c3ce6 100644 --- a/src/databricks/labs/ucx/account.py +++ b/src/databricks/labs/ucx/account.py @@ -9,8 +9,6 @@ from databricks.sdk.service.iam import ComplexValue, Group, Patch, PatchOp, PatchSchema from databricks.sdk.service.provisioning import Workspace -from databricks.labs.ucx.__about__ import __version__ - logger = logging.getLogger(__name__) @@ -75,7 +73,7 @@ def create_account_level_groups(self, prompts: Prompts): acc_group = self._ac.groups.create(display_name=group_name) - if not acc_group.id: + if not valid_group.members or not acc_group.id: continue if len(valid_group.members) > 0: self._add_members_to_acc_group(self._ac, acc_group.id, group_name, valid_group) @@ -139,7 +137,7 @@ def _get_valid_workspaces_groups(self, prompts: Prompts) -> dict[str, Group]: return all_workspaces_groups def _is_group_out_of_scope(self, group: Group) -> bool: - if group.display_name in ["users", "admins", "account users"]: + if group.display_name in {"users", "admins", "account users"}: logger.debug(f"Group {group.display_name} is a system group, ignoring") return True meta = group.meta diff --git a/tests/integration/test_account.py b/tests/integration/test_account.py index 2e873bc341..dfe5d0f3e2 100644 --- a/tests/integration/test_account.py +++ b/tests/integration/test_account.py @@ -20,4 +20,3 @@ def test_create_account_level_groups(make_ucx_group, make_group, make_user, acc) results.append(grp) assert len(results) == 1 - diff --git a/tests/unit/test_account.py b/tests/unit/test_account.py index f83fd0f80f..b003e1995e 100644 --- a/tests/unit/test_account.py +++ b/tests/unit/test_account.py @@ -23,6 +23,7 @@ def test_sync_workspace_info(mocker): ] ws = create_autospec(WorkspaceClient) + acc_client.get_workspace_client.return_value = ws def workspace_client(host, product, **kwargs) -> WorkspaceClient: assert host in ("https://abc.cloud.databricks.com", "https://def.cloud.databricks.com") @@ -84,8 +85,6 @@ def test_create_acc_groups_should_create_acc_group_if_no_group_found_in_account( def workspace_client(**kwargs) -> WorkspaceClient: return ws - account_workspaces = AccountWorkspaces(acc_client, workspace_client) - group = Group( id="12", display_name="de", @@ -94,8 +93,10 @@ def workspace_client(**kwargs) -> WorkspaceClient: ws.groups.list.return_value = [group] ws.groups.get.return_value = group + acc_client.get_workspace_client.return_value = ws acc_client.groups.create.return_value = group + account_workspaces = AccountWorkspaces(acc_client, workspace_client) account_workspaces.create_account_level_groups(MockPrompts({})) acc_client.groups.create.assert_called_with( @@ -163,6 +164,7 @@ def workspace_client(**kwargs) -> WorkspaceClient: ws.groups.list.return_value = [group] ws.groups.get.return_value = group + acc_client.get_workspace_client.return_value = ws acc_client.groups.create.return_value = group account_workspaces.create_account_level_groups(MockPrompts({})) @@ -243,7 +245,7 @@ def workspace_client(**kwargs) -> WorkspaceClient: ws.groups.list.return_value = [group] ws.groups.get.return_value = group - + acc_client.get_workspace_client.return_value = ws account_workspaces = AccountWorkspaces(acc_client, workspace_client) account_workspaces.create_account_level_groups(MockPrompts({})) @@ -268,6 +270,12 @@ def workspace_client(host, product, **kwargs) -> WorkspaceClient: else: return ws2 + def get_workspace_client(workspace, **kwargs) -> WorkspaceClient: + if workspace.workspace_id == 123: + return ws1 + else: + return ws2 + group = Group(id="12", display_name="de", members=[]) group2 = Group(id="12", display_name="security_grp", members=[]) @@ -277,6 +285,8 @@ def workspace_client(host, product, **kwargs) -> WorkspaceClient: ws2.groups.list.return_value = [group2] ws2.groups.get.return_value = group2 + acc_client.get_workspace_client.side_effect = get_workspace_client + account_workspaces = AccountWorkspaces(acc_client, workspace_client) account_workspaces.create_account_level_groups(MockPrompts({})) @@ -302,6 +312,12 @@ def workspace_client(host, product, **kwargs) -> WorkspaceClient: else: return ws2 + def get_workspace_client(workspace, **kwargs) -> WorkspaceClient: + if workspace.workspace_id == 123: + return ws1 + else: + return ws2 + group = Group( id="12", display_name="de", @@ -314,6 +330,7 @@ def workspace_client(host, product, **kwargs) -> WorkspaceClient: ws2.groups.list.return_value = [group] ws2.groups.get.return_value = group acc_client.groups.create.return_value = group + acc_client.get_workspace_client.side_effect = get_workspace_client account_workspaces = AccountWorkspaces(acc_client, workspace_client) account_workspaces.create_account_level_groups(MockPrompts({})) @@ -350,6 +367,12 @@ def workspace_client(host, product, **kwargs) -> WorkspaceClient: else: return ws2 + def get_workspace_client(workspace, **kwargs) -> WorkspaceClient: + if workspace.workspace_id == 123: + return ws1 + else: + return ws2 + group = Group( id="12", display_name="de", @@ -366,6 +389,7 @@ def workspace_client(host, product, **kwargs) -> WorkspaceClient: ws2.groups.list.return_value = [group_2] ws2.groups.get.return_value = group_2 + acc_client.get_workspace_client.side_effect = get_workspace_client account_workspaces = AccountWorkspaces(acc_client, workspace_client) account_workspaces.create_account_level_groups( From f70a5eaff34ca5796d99d57c20d69d744fd79b78 Mon Sep 17 00:00:00 2001 From: william-conti Date: Wed, 14 Feb 2024 15:39:05 +0100 Subject: [PATCH 14/25] fix --- tests/integration/assessment/test_azure.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/integration/assessment/test_azure.py b/tests/integration/assessment/test_azure.py index 4afe03c0a8..5cc56bc939 100644 --- a/tests/integration/assessment/test_azure.py +++ b/tests/integration/assessment/test_azure.py @@ -110,9 +110,9 @@ def test_spn_crawler_with_available_secrets( _pipeline_conf_with_avlbl_secret["fs.azure.account.oauth2.client.id.SA1.dfs.core.windows.net"] = ( "{" + (f"{{secrets/{secret_scope}/{secret_key}}}") + "}" ) - _pipeline_conf_with_avlbl_secret[ - "fs.azure.account.oauth2.client.endpoint.SA1.dfs.core.windows.net" - ] = "https://login.microsoftonline.com/dummy_tenant/oauth2/token" + _pipeline_conf_with_avlbl_secret["fs.azure.account.oauth2.client.endpoint.SA1.dfs.core.windows.net"] = ( + "https://login.microsoftonline.com/dummy_tenant/oauth2/token" + ) make_job() make_pipeline(configuration=_pipeline_conf_with_avlbl_secret) spn_crawler = AzureServicePrincipalCrawler(ws=ws, sbe=sql_backend, schema=inventory_schema) From c740869efffc010d70b9104461a39129b7f87cdd Mon Sep 17 00:00:00 2001 From: william-conti Date: Thu, 15 Feb 2024 23:13:14 +0100 Subject: [PATCH 15/25] Skipping integration test, adding more unit tests --- tests/integration/test_account.py | 11 ++--- tests/unit/test_account.py | 70 +++++++++++++++++++++++++++---- 2 files changed, 67 insertions(+), 14 deletions(-) diff --git a/tests/integration/test_account.py b/tests/integration/test_account.py index dfe5d0f3e2..5bf9d5cecf 100644 --- a/tests/integration/test_account.py +++ b/tests/integration/test_account.py @@ -4,19 +4,16 @@ def test_create_account_level_groups(make_ucx_group, make_group, make_user, acc): + # pytest.skip("Unskip when well be able to filter by workspace ID and group ID to avoid unintended side effects") make_ucx_group("test_ucx_migrate_invalid", "test_ucx_migrate_invalid") - members = [] - for i in range(10): - user = make_user() - members.append(user.id) - - make_group(display_name="test_ucx_migrate_valid", members=members, entitlements=["allow-cluster-create"]) + make_group(display_name="regular_group", members=[make_user().id]) AccountWorkspaces(acc).create_account_level_groups(MockPrompts({})) results = [] for grp in acc.groups.list(): - if grp.display_name == "test_ucx_migrate_valid": + if grp.display_name in ["regular_group"]: results.append(grp) + acc.groups.delete(grp.id) # Avoids flakiness for future runs assert len(results) == 1 diff --git a/tests/unit/test_account.py b/tests/unit/test_account.py index b003e1995e..ea5f5e0367 100644 --- a/tests/unit/test_account.py +++ b/tests/unit/test_account.py @@ -7,7 +7,7 @@ from databricks.sdk import AccountClient, WorkspaceClient from databricks.sdk.config import Config from databricks.sdk.service import iam -from databricks.sdk.service.iam import ComplexValue, Group, User +from databricks.sdk.service.iam import ComplexValue, Group, ResourceMeta, User from databricks.sdk.service.provisioning import Workspace from databricks.labs.ucx.account import AccountWorkspaces, WorkspaceInfo @@ -72,7 +72,7 @@ def test_manual_workspace_info(mocker): ws.workspace.upload.assert_called() -def test_create_acc_groups_should_create_acc_group_if_no_group_found_in_account(mocker): +def test_create_acc_groups_should_create_acc_group_if_no_group_found_in_account(): acc_client = create_autospec(AccountClient) acc_client.config = Config(host="https://accounts.cloud.databricks.com", account_id="123", token="123") @@ -115,7 +115,63 @@ def workspace_client(**kwargs) -> WorkspaceClient: ) -def test_create_acc_groups_should_create_acc_group_with_appropriate_members(mocker): +def test_create_acc_groups_should_filter_system_groups(): + acc_client = create_autospec(AccountClient) + acc_client.config = Config(host="https://accounts.cloud.databricks.com", account_id="123", token="123") + + acc_client.workspaces.list.return_value = [ + Workspace(workspace_name="foo", workspace_id=123, workspace_status_message="Running", deployment_name="abc") + ] + + ws = create_autospec(WorkspaceClient) + + def workspace_client(**kwargs) -> WorkspaceClient: + return ws + + group = Group( + id="12", + display_name="admins", + members=[], + ) + + ws.groups.list.return_value = [group] + ws.groups.get.return_value = group + acc_client.get_workspace_client.return_value = ws + acc_client.groups.create.return_value = group + + account_workspaces = AccountWorkspaces(acc_client, workspace_client) + account_workspaces.create_account_level_groups(MockPrompts({})) + + acc_client.groups.create.assert_not_called() + + +def test_create_acc_groups_should_filter_account_groups_in_workspace(): + acc_client = create_autospec(AccountClient) + acc_client.config = Config(host="https://accounts.cloud.databricks.com", account_id="123", token="123") + + acc_client.workspaces.list.return_value = [ + Workspace(workspace_name="foo", workspace_id=123, workspace_status_message="Running", deployment_name="abc") + ] + + ws = create_autospec(WorkspaceClient) + + def workspace_client(**kwargs) -> WorkspaceClient: + return ws + + group = Group(id="12", display_name="test_account", meta=ResourceMeta("Account")) + + ws.groups.list.return_value = [group] + ws.groups.get.return_value = group + acc_client.get_workspace_client.return_value = ws + acc_client.groups.create.return_value = group + + account_workspaces = AccountWorkspaces(acc_client, workspace_client) + account_workspaces.create_account_level_groups(MockPrompts({})) + + acc_client.groups.create.assert_not_called() + + +def test_create_acc_groups_should_create_acc_group_with_appropriate_members(): acc_client = create_autospec(AccountClient) acc_client.config = Config(host="https://accounts.cloud.databricks.com", account_id="123", token="123") @@ -223,7 +279,7 @@ def workspace_client(**kwargs) -> WorkspaceClient: ) -def test_create_acc_groups_should_not_create_group_if_exists_in_account(mocker): +def test_create_acc_groups_should_not_create_group_if_exists_in_account(): acc_client = create_autospec(AccountClient) acc_client.config = Config(host="https://accounts.cloud.databricks.com", account_id="123", token="123") @@ -252,7 +308,7 @@ def workspace_client(**kwargs) -> WorkspaceClient: acc_client.groups.create.assert_not_called() -def test_create_acc_groups_should_create_groups_accross_workspaces(mocker): +def test_create_acc_groups_should_create_groups_accross_workspaces(): acc_client = create_autospec(AccountClient) acc_client.config = Config(host="https://accounts.cloud.databricks.com", account_id="123", token="123") @@ -294,7 +350,7 @@ def get_workspace_client(workspace, **kwargs) -> WorkspaceClient: acc_client.groups.create.assert_any_call(display_name="security_grp") -def test_create_acc_groups_should_filter_groups_accross_workspaces(mocker): +def test_create_acc_groups_should_filter_groups_accross_workspaces(): acc_client = create_autospec(AccountClient) acc_client.config = Config(host="https://accounts.cloud.databricks.com", account_id="123", token="123") @@ -349,7 +405,7 @@ def get_workspace_client(workspace, **kwargs) -> WorkspaceClient: ) -def test_create_acc_groups_should_create_acc_group_if_exist_in_other_workspaces_but_not_same_members(mocker): +def test_create_acc_groups_should_create_acc_group_if_exist_in_other_workspaces_but_not_same_members(): acc_client = create_autospec(AccountClient) acc_client.config = Config(host="https://accounts.cloud.databricks.com", account_id="123", token="123") From 8d8cc2985f0fad1c09cb5681efc4a1983656d107 Mon Sep 17 00:00:00 2001 From: william-conti Date: Fri, 16 Feb 2024 00:00:48 +0100 Subject: [PATCH 16/25] Adding coverage --- tests/unit/test_account.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_account.py b/tests/unit/test_account.py index ea5f5e0367..47494193c2 100644 --- a/tests/unit/test_account.py +++ b/tests/unit/test_account.py @@ -89,9 +89,14 @@ def workspace_client(**kwargs) -> WorkspaceClient: id="12", display_name="de", members=[ComplexValue(display="test-user-1", value="20"), ComplexValue(display="test-user-2", value="21")], + meta=ResourceMeta("WorkspaceGroup"), + ) + group_2 = Group( + display_name="no_id", + members=[ComplexValue(display="test-user-1", value="20"), ComplexValue(display="test-user-2", value="21")], ) - ws.groups.list.return_value = [group] + ws.groups.list.return_value = [group, group_2] ws.groups.get.return_value = group acc_client.get_workspace_client.return_value = ws acc_client.groups.create.return_value = group @@ -288,7 +293,8 @@ def test_create_acc_groups_should_not_create_group_if_exists_in_account(): display_name="de", members=[ComplexValue(display="test-user-1", value="20"), ComplexValue(display="test-user-2", value="21")], ) - acc_client.groups.list.return_value = [group] + group_2 = Group(display_name="de_invalid") + acc_client.groups.list.return_value = [group, group_2] acc_client.groups.get.return_value = group acc_client.workspaces.list.return_value = [ Workspace(workspace_name="foo", workspace_id=123, workspace_status_message="Running", deployment_name="abc") From 5624f896c3c9fc003e65aab306b39600aabc839e Mon Sep 17 00:00:00 2001 From: william-conti Date: Tue, 27 Feb 2024 09:58:43 +0100 Subject: [PATCH 17/25] add possibility to filter by workspace id --- labs.yml | 7 ++++++- src/databricks/labs/ucx/account.py | 29 ++++++++++++++++++++++++++--- src/databricks/labs/ucx/cli.py | 4 ++-- tests/integration/test_account.py | 5 ++--- tests/unit/test_account.py | 11 ++++++----- tests/unit/test_cli.py | 5 +++-- 6 files changed, 45 insertions(+), 16 deletions(-) diff --git a/labs.yml b/labs.yml index 20119890af..1499391c6e 100644 --- a/labs.yml +++ b/labs.yml @@ -128,4 +128,9 @@ commands: - name: create-account-groups is_account_level: true - description: Creates account level groups for all groups in all workspaces. + description: | + Creates account level groups for all groups in workspaces provided in workspace_ids. + If workspace_ids is not provided, it will use all workspaces present in the account. + flags: + - name: workspace_ids + description: List of workspace IDs to create account groups from. \ No newline at end of file diff --git a/src/databricks/labs/ucx/account.py b/src/databricks/labs/ucx/account.py index 38723c3ce6..2089b17e69 100644 --- a/src/databricks/labs/ucx/account.py +++ b/src/databricks/labs/ucx/account.py @@ -62,9 +62,10 @@ def sync_workspace_info(self): for installation in Installation.existing(ws, "ucx"): installation.save(workspaces, filename=self.SYNC_FILE_NAME) - def create_account_level_groups(self, prompts: Prompts): + def create_account_level_groups(self, prompts: Prompts, workspace_ids: list[int] | None = None): acc_groups = self._get_account_groups() - all_valid_workspace_groups = self._get_valid_workspaces_groups(prompts) + workspace_ids = self._get_valid_workspaces_ids(workspace_ids) + all_valid_workspace_groups = self._get_valid_workspaces_groups(prompts, workspace_ids) for group_name, valid_group in all_valid_workspace_groups.items(): if group_name in acc_groups: @@ -79,6 +80,26 @@ def create_account_level_groups(self, prompts: Prompts): self._add_members_to_acc_group(self._ac, acc_group.id, group_name, valid_group) logger.info(f"Group {group_name} created in the account") + def _get_valid_workspaces_ids(self, workspace_ids: list[int] | None = None) -> list[int]: + if not workspace_ids: + logger.info("No workspace ids provided, using current workspace instead") + return [self._new_workspace_client().get_workspace_id()] + + all_workspace_ids = [workspace.workspace_id for workspace in self._workspaces()] + + valid_workspace_ids = [] + for workspace_id in workspace_ids: + if workspace_id in all_workspace_ids: + valid_workspace_ids.append(workspace_id) + else: + logger.info(f"Workspace id {workspace_id} not found on the account") + + if not valid_workspace_ids: + raise ValueError("No workspace ids provided in the configuration found in the account") + + logger.info("Creating account groups for workspaces IDs : " + ','.join(str(x) for x in valid_workspace_ids)) + return valid_workspace_ids + def _add_members_to_acc_group( self, acc_client: AccountClient, acc_group_id: str, group_name: str, valid_group: Group ): @@ -95,10 +116,12 @@ def _chunks(self, lst, n): for i in range(0, len(lst), n): yield lst[i : i + n] - def _get_valid_workspaces_groups(self, prompts: Prompts) -> dict[str, Group]: + def _get_valid_workspaces_groups(self, prompts: Prompts, workspace_ids: list[int]) -> dict[str, Group]: all_workspaces_groups: dict[str, Group] = {} for workspace in self._workspaces(): + if workspace.workspace_id not in workspace_ids: + continue client = self.client_for(workspace) logger.info(f"Crawling groups in workspace {client.config.host}") diff --git a/src/databricks/labs/ucx/cli.py b/src/databricks/labs/ucx/cli.py index 7ac2ba7492..9d4746aafc 100644 --- a/src/databricks/labs/ucx/cli.py +++ b/src/databricks/labs/ucx/cli.py @@ -90,7 +90,7 @@ def sync_workspace_info(a: AccountClient): @ucx.command(is_account=True) -def create_account_groups(a: AccountClient): +def create_account_groups(a: AccountClient, workspace_ids: list[int] | None = None): """ Crawl all workspaces, and create account level groups if a WS local group is not present in the account. The feature is not configurable, meaning that it fetches all workspaces groups and all account groups. @@ -104,7 +104,7 @@ def create_account_groups(a: AccountClient): logger.info(f"Account ID: {a.config.account_id}") prompts = Prompts() workspaces = AccountWorkspaces(a) - workspaces.create_account_level_groups(prompts) + workspaces.create_account_level_groups(prompts, workspace_ids) @ucx.command diff --git a/tests/integration/test_account.py b/tests/integration/test_account.py index 5bf9d5cecf..e3f3e4ae06 100644 --- a/tests/integration/test_account.py +++ b/tests/integration/test_account.py @@ -3,12 +3,11 @@ from databricks.labs.ucx.account import AccountWorkspaces -def test_create_account_level_groups(make_ucx_group, make_group, make_user, acc): - # pytest.skip("Unskip when well be able to filter by workspace ID and group ID to avoid unintended side effects") +def test_create_account_level_groups(make_ucx_group, make_group, make_user, acc, ws): make_ucx_group("test_ucx_migrate_invalid", "test_ucx_migrate_invalid") make_group(display_name="regular_group", members=[make_user().id]) - AccountWorkspaces(acc).create_account_level_groups(MockPrompts({})) + AccountWorkspaces(acc).create_account_level_groups(MockPrompts({}), [ws.get_workspace_id()]) results = [] for grp in acc.groups.list(): diff --git a/tests/unit/test_account.py b/tests/unit/test_account.py index 47494193c2..3c78a13633 100644 --- a/tests/unit/test_account.py +++ b/tests/unit/test_account.py @@ -102,7 +102,7 @@ def workspace_client(**kwargs) -> WorkspaceClient: acc_client.groups.create.return_value = group account_workspaces = AccountWorkspaces(acc_client, workspace_client) - account_workspaces.create_account_level_groups(MockPrompts({})) + account_workspaces.create_account_level_groups(MockPrompts({}), [123]) acc_client.groups.create.assert_called_with( display_name="de", @@ -228,7 +228,7 @@ def workspace_client(**kwargs) -> WorkspaceClient: acc_client.get_workspace_client.return_value = ws acc_client.groups.create.return_value = group - account_workspaces.create_account_level_groups(MockPrompts({})) + account_workspaces.create_account_level_groups(MockPrompts({}), [123]) acc_client.groups.create.assert_called_with( display_name="de", @@ -350,7 +350,7 @@ def get_workspace_client(workspace, **kwargs) -> WorkspaceClient: acc_client.get_workspace_client.side_effect = get_workspace_client account_workspaces = AccountWorkspaces(acc_client, workspace_client) - account_workspaces.create_account_level_groups(MockPrompts({})) + account_workspaces.create_account_level_groups(MockPrompts({}), [123, 456]) acc_client.groups.create.assert_any_call(display_name="de") acc_client.groups.create.assert_any_call(display_name="security_grp") @@ -395,7 +395,7 @@ def get_workspace_client(workspace, **kwargs) -> WorkspaceClient: acc_client.get_workspace_client.side_effect = get_workspace_client account_workspaces = AccountWorkspaces(acc_client, workspace_client) - account_workspaces.create_account_level_groups(MockPrompts({})) + account_workspaces.create_account_level_groups(MockPrompts({}), [123, 456]) acc_client.groups.create.assert_called_once_with(display_name="de") acc_client.groups.patch.assert_called_once_with( @@ -459,7 +459,8 @@ def get_workspace_client(workspace, **kwargs) -> WorkspaceClient: { r'Group de does not have the same amount of members in workspace ': 'yes', } - ) + ), + [123, 456], ) acc_client.groups.create.assert_any_call(display_name="de") diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index 99e5f64a2d..3ea1c03839 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -131,8 +131,9 @@ def test_sync_workspace_info(): def test_create_account_groups(): a = create_autospec(AccountClient) - create_account_groups(a) - a.groups.list.assert_called_with(attributes="id") + with (patch("databricks.sdk.WorkspaceClient.get_workspace_id", return_value=None) as s,): + create_account_groups(a) + a.groups.list.assert_called_with(attributes="id") def test_manual_workspace_info(ws): From 2eba4345cdbf9901e349b10c141227466945e970 Mon Sep 17 00:00:00 2001 From: william-conti Date: Tue, 27 Feb 2024 10:03:04 +0100 Subject: [PATCH 18/25] self review --- labs.yml | 16 ---------------- src/databricks/labs/ucx/cli.py | 2 +- src/databricks/labs/ucx/framework/dashboards.py | 2 +- 3 files changed, 2 insertions(+), 18 deletions(-) diff --git a/labs.yml b/labs.yml index b686310e3f..d1093f07e9 100644 --- a/labs.yml +++ b/labs.yml @@ -115,22 +115,6 @@ commands: {{range .}}{{.wf_group_name}}\t{{.wf_group_members_count}}\t{{.acc_group_name}}\t{{.acc_group_members_count}} {{end}} - - name: save-aws-iam-profiles - description: | - Identifies all Instance Profiles and map their access to S3 buckets. - Requires a working setup of AWS CLI. - flags: - - name: aws-profile - description: AWS Profile to use for authentication - - - name: save-uc-compatible-roles - description: | - Scan all the AWS roles that are set for UC access and produce a mapping to the S3 resources. - Requires a working setup of AWS CLI. - flags: - - name: aws-profile - description: AWS Profile to use for authentication - - name: migrate_credentials description: Migrate credentials for storage access to UC storage credential diff --git a/src/databricks/labs/ucx/cli.py b/src/databricks/labs/ucx/cli.py index f03fc98336..e39e11aa48 100644 --- a/src/databricks/labs/ucx/cli.py +++ b/src/databricks/labs/ucx/cli.py @@ -94,7 +94,7 @@ def sync_workspace_info(a: AccountClient): def create_account_groups(a: AccountClient, workspace_ids: list[int] | None = None): """ Crawl all workspaces, and create account level groups if a WS local group is not present in the account. - The feature is not configurable, meaning that it fetches all workspaces groups and all account groups. + The can be configured for multiple workspace IDs or all workspaces configured in the account. The following scenarios are supported, if a group X: - Exist in workspaces A,B,C and it has same members in there, it will be created in the account diff --git a/src/databricks/labs/ucx/framework/dashboards.py b/src/databricks/labs/ucx/framework/dashboards.py index 8723a6bd5a..cefa9a91ec 100644 --- a/src/databricks/labs/ucx/framework/dashboards.py +++ b/src/databricks/labs/ucx/framework/dashboards.py @@ -47,7 +47,7 @@ def viz_args(self) -> dict: class VizColumn: # pylint: disable=too-many-instance-attributes) name: str title: str - type: str = "string" # noqa: A003 + type: str = "string" imageUrlTemplate: str = "{{ @ }}" # noqa: N815 imageTitleTemplate: str = "{{ @ }}" # noqa: N815 linkUrlTemplate: str = "{{ @ }}" # noqa: N815 From 71df1bee070010c69b76a923f2b843e1f1e3d325 Mon Sep 17 00:00:00 2001 From: william-conti Date: Tue, 27 Feb 2024 10:09:00 +0100 Subject: [PATCH 19/25] adding docs --- README.md | 16 ++++++++++++++++ src/databricks/labs/ucx/cli.py | 7 ++++--- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 47f6b450e0..d3c86bda0f 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,7 @@ See [contributing instructions](CONTRIBUTING.md) to help improve this project. * [Producing table mapping](#producing-table-mapping) * [Synchronising UCX configurations](#synchronising-ucx-configurations) * [Validating group membership](#validating-group-membership) + * [Creating account groups](#creating-account-groups) * [Star History](#star-history) * [Project Support](#project-support) @@ -197,6 +198,21 @@ Use to validate workspace-level & account-level groups to identify any discrepan databricks labs ucx validate-groups-membership ``` +### Creating account groups +Crawl all workspaces configured in workspace_ids, then creates account level groups if a WS local group is not present + in the account. +If workspace_ids is not specified, it will create account groups for all workspaces configured in the account. + +The following scenarios are supported, if a group X: +- Exist in workspaces A,B,C and it has same members in there, it will be created in the account +- Exist in workspaces A,B but not in C, it will be created in the account +- Exist in workspaces A,B,C. It has same members in A,B, but not in C. Then, X and C_X will be created in the +account + +```commandline +databricks labs ucx create-account-groups --workspace_ids +``` + ## Star History [![Star History Chart](https://api.star-history.com/svg?repos=databrickslabs/ucx&type=Date)](https://star-history.com/#databrickslabs/ucx) diff --git a/src/databricks/labs/ucx/cli.py b/src/databricks/labs/ucx/cli.py index e39e11aa48..da07aab4f4 100644 --- a/src/databricks/labs/ucx/cli.py +++ b/src/databricks/labs/ucx/cli.py @@ -93,13 +93,14 @@ def sync_workspace_info(a: AccountClient): @ucx.command(is_account=True) def create_account_groups(a: AccountClient, workspace_ids: list[int] | None = None): """ - Crawl all workspaces, and create account level groups if a WS local group is not present in the account. - The can be configured for multiple workspace IDs or all workspaces configured in the account. + Crawl all workspaces configured in workspace_ids, then creates account level groups if a WS local group is not present + in the account. + If workspace_ids is not specified, it will create account groups for all workspaces configured in the account. The following scenarios are supported, if a group X: - Exist in workspaces A,B,C and it has same members in there, it will be created in the account - Exist in workspaces A,B but not in C, it will be created in the account - - Exist in workspaces A,B,C and it has same members in A,B, but not in C, then, X and C_X will be created in the + - Exist in workspaces A,B,C. It has same members in A,B, but not in C. Then, X and C_X will be created in the account """ logger.info(f"Account ID: {a.config.account_id}") From 3c27160128ce3b061ba114576b4566e4891141d0 Mon Sep 17 00:00:00 2001 From: william-conti Date: Tue, 27 Feb 2024 17:45:14 +0100 Subject: [PATCH 20/25] fix format and pylint --- src/databricks/labs/ucx/account.py | 9 +++-- src/databricks/labs/ucx/azure/credentials.py | 2 - tests/integration/test_account.py | 2 +- tests/unit/test_account.py | 40 +++++++++----------- tests/unit/test_cli.py | 2 +- 5 files changed, 24 insertions(+), 31 deletions(-) diff --git a/src/databricks/labs/ucx/account.py b/src/databricks/labs/ucx/account.py index 7304eb39c3..62f6eb7277 100644 --- a/src/databricks/labs/ucx/account.py +++ b/src/databricks/labs/ucx/account.py @@ -97,7 +97,8 @@ def _get_valid_workspaces_ids(self, workspace_ids: list[int] | None = None) -> l if not valid_workspace_ids: raise ValueError("No workspace ids provided in the configuration found in the account") - logger.info("Creating account groups for workspaces IDs : " + ','.join(str(x) for x in valid_workspace_ids)) + workspace_ids_str = ','.join(str(x) for x in valid_workspace_ids) + logger.info(f"Creating account groups for workspaces IDs : {workspace_ids_str}") return valid_workspace_ids def _add_members_to_acc_group( @@ -111,10 +112,10 @@ def _add_members_to_acc_group( schemas=[PatchSchema.URN_IETF_PARAMS_SCIM_API_MESSAGES_2_0_PATCH_OP], ) - def _chunks(self, lst, n): + def _chunks(self, lst, chunk_size): """Yield successive n-sized chunks from lst.""" - for i in range(0, len(lst), n): - yield lst[i : i + n] + for i in range(0, len(lst), chunk_size): + yield lst[i : i + chunk_size] def _get_valid_workspaces_groups(self, prompts: Prompts, workspace_ids: list[int]) -> dict[str, Group]: all_workspaces_groups: dict[str, Group] = {} diff --git a/src/databricks/labs/ucx/azure/credentials.py b/src/databricks/labs/ucx/azure/credentials.py index c1182587c0..01df841a31 100644 --- a/src/databricks/labs/ucx/azure/credentials.py +++ b/src/databricks/labs/ucx/azure/credentials.py @@ -140,7 +140,6 @@ def validate(self, permission_mapping: StoragePermissionMapping) -> StorageCrede class ServicePrincipalMigration(SecretsMixin): - def __init__( self, installation: Installation, @@ -247,7 +246,6 @@ def save(self, migration_results: list[StorageCredentialValidationResult]) -> st return self._installation.save(migration_results, filename=self._output_file) def run(self, prompts: Prompts, include_names: set[str] | None = None) -> list[StorageCredentialValidationResult]: - sp_list_with_secret = self._generate_migration_list(include_names) plan_confirmed = prompts.confirm( diff --git a/tests/integration/test_account.py b/tests/integration/test_account.py index e3f3e4ae06..8ce75cd8da 100644 --- a/tests/integration/test_account.py +++ b/tests/integration/test_account.py @@ -11,7 +11,7 @@ def test_create_account_level_groups(make_ucx_group, make_group, make_user, acc, results = [] for grp in acc.groups.list(): - if grp.display_name in ["regular_group"]: + if grp.display_name in {"regular_group"}: results.append(grp) acc.groups.delete(grp.id) # Avoids flakiness for future runs diff --git a/tests/unit/test_account.py b/tests/unit/test_account.py index 82bbc2075b..6d4746d049 100644 --- a/tests/unit/test_account.py +++ b/tests/unit/test_account.py @@ -82,7 +82,7 @@ def test_create_acc_groups_should_create_acc_group_if_no_group_found_in_account( ws = create_autospec(WorkspaceClient) - def workspace_client(**kwargs) -> WorkspaceClient: + def workspace_client() -> WorkspaceClient: return ws group = Group( @@ -130,7 +130,7 @@ def test_create_acc_groups_should_filter_system_groups(): ws = create_autospec(WorkspaceClient) - def workspace_client(**kwargs) -> WorkspaceClient: + def workspace_client() -> WorkspaceClient: return ws group = Group( @@ -160,7 +160,7 @@ def test_create_acc_groups_should_filter_account_groups_in_workspace(): ws = create_autospec(WorkspaceClient) - def workspace_client(**kwargs) -> WorkspaceClient: + def workspace_client() -> WorkspaceClient: return ws group = Group(id="12", display_name="test_account", meta=ResourceMeta("Account")) @@ -186,7 +186,7 @@ def test_create_acc_groups_should_create_acc_group_with_appropriate_members(): ws = create_autospec(WorkspaceClient) - def workspace_client(**kwargs) -> WorkspaceClient: + def workspace_client() -> WorkspaceClient: return ws account_workspaces = AccountWorkspaces(acc_client, workspace_client) @@ -302,7 +302,7 @@ def test_create_acc_groups_should_not_create_group_if_exists_in_account(): ws = create_autospec(WorkspaceClient) - def workspace_client(**kwargs) -> WorkspaceClient: + def workspace_client() -> WorkspaceClient: return ws ws.groups.list.return_value = [group] @@ -326,17 +326,15 @@ def test_create_acc_groups_should_create_groups_accross_workspaces(): ws1 = create_autospec(WorkspaceClient) ws2 = create_autospec(WorkspaceClient) - def workspace_client(host, product, **kwargs) -> WorkspaceClient: + def workspace_client(host) -> WorkspaceClient: if host == "https://abc.cloud.databricks.com": return ws1 - else: - return ws2 + return ws2 - def get_workspace_client(workspace, **kwargs) -> WorkspaceClient: + def get_workspace_client(workspace) -> WorkspaceClient: if workspace.workspace_id == 123: return ws1 - else: - return ws2 + return ws2 group = Group(id="12", display_name="de", members=[]) group2 = Group(id="12", display_name="security_grp", members=[]) @@ -368,17 +366,15 @@ def test_create_acc_groups_should_filter_groups_accross_workspaces(): ws1 = create_autospec(WorkspaceClient) ws2 = create_autospec(WorkspaceClient) - def workspace_client(host, product, **kwargs) -> WorkspaceClient: + def workspace_client(host) -> WorkspaceClient: if host == "https://abc.cloud.databricks.com": return ws1 - else: - return ws2 + return ws2 - def get_workspace_client(workspace, **kwargs) -> WorkspaceClient: + def get_workspace_client(workspace) -> WorkspaceClient: if workspace.workspace_id == 123: return ws1 - else: - return ws2 + return ws2 group = Group( id="12", @@ -423,17 +419,15 @@ def test_create_acc_groups_should_create_acc_group_if_exist_in_other_workspaces_ ws1 = create_autospec(WorkspaceClient) ws2 = create_autospec(WorkspaceClient) - def workspace_client(host, product, **kwargs) -> WorkspaceClient: + def workspace_client(host) -> WorkspaceClient: if host == "https://abc.cloud.databricks.com": return ws1 - else: - return ws2 + return ws2 - def get_workspace_client(workspace, **kwargs) -> WorkspaceClient: + def get_workspace_client(workspace) -> WorkspaceClient: if workspace.workspace_id == 123: return ws1 - else: - return ws2 + return ws2 group = Group( id="12", diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index 87ce2f8065..be2e95cb49 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -133,7 +133,7 @@ def test_sync_workspace_info(): def test_create_account_groups(): a = create_autospec(AccountClient) - with (patch("databricks.sdk.WorkspaceClient.get_workspace_id", return_value=None) as s,): + with patch("databricks.sdk.WorkspaceClient.get_workspace_id", return_value=None): create_account_groups(a) a.groups.list.assert_called_with(attributes="id") From 0bc1380e4dc1ae7da952364f5f427c93b84625fa Mon Sep 17 00:00:00 2001 From: william-conti Date: Tue, 27 Feb 2024 18:00:40 +0100 Subject: [PATCH 21/25] fix ci --- tests/integration/hive_metastore/verify_tacl_access.py | 1 - tests/unit/test_cli.py | 4 +++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/integration/hive_metastore/verify_tacl_access.py b/tests/integration/hive_metastore/verify_tacl_access.py index 025caa3524..009e8da12b 100755 --- a/tests/integration/hive_metastore/verify_tacl_access.py +++ b/tests/integration/hive_metastore/verify_tacl_access.py @@ -6,7 +6,6 @@ def main(): - table_name = sys.argv[1] # labs-aws-simple-spn is a config profile that has SPN diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index be2e95cb49..906e8d5a63 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -133,7 +133,9 @@ def test_sync_workspace_info(): def test_create_account_groups(): a = create_autospec(AccountClient) - with patch("databricks.sdk.WorkspaceClient.get_workspace_id", return_value=None): + with patch("databricks.sdk.WorkspaceClient.__init__", return_value=None), patch( + "databricks.sdk.WorkspaceClient.get_workspace_id", return_value=None + ): create_account_groups(a) a.groups.list.assert_called_with(attributes="id") From 72e0395c7fd87856c50b968e1e03d0aedb03d03a Mon Sep 17 00:00:00 2001 From: william-conti Date: Tue, 27 Feb 2024 18:12:19 +0100 Subject: [PATCH 22/25] fixing format --- src/databricks/labs/ucx/azure/credentials.py | 2 ++ tests/integration/hive_metastore/verify_tacl_access.py | 1 + tests/unit/test_cli.py | 5 +++-- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/databricks/labs/ucx/azure/credentials.py b/src/databricks/labs/ucx/azure/credentials.py index 01df841a31..c1182587c0 100644 --- a/src/databricks/labs/ucx/azure/credentials.py +++ b/src/databricks/labs/ucx/azure/credentials.py @@ -140,6 +140,7 @@ def validate(self, permission_mapping: StoragePermissionMapping) -> StorageCrede class ServicePrincipalMigration(SecretsMixin): + def __init__( self, installation: Installation, @@ -246,6 +247,7 @@ def save(self, migration_results: list[StorageCredentialValidationResult]) -> st return self._installation.save(migration_results, filename=self._output_file) def run(self, prompts: Prompts, include_names: set[str] | None = None) -> list[StorageCredentialValidationResult]: + sp_list_with_secret = self._generate_migration_list(include_names) plan_confirmed = prompts.confirm( diff --git a/tests/integration/hive_metastore/verify_tacl_access.py b/tests/integration/hive_metastore/verify_tacl_access.py index 009e8da12b..025caa3524 100755 --- a/tests/integration/hive_metastore/verify_tacl_access.py +++ b/tests/integration/hive_metastore/verify_tacl_access.py @@ -6,6 +6,7 @@ def main(): + table_name = sys.argv[1] # labs-aws-simple-spn is a config profile that has SPN diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index 906e8d5a63..0b559e3d88 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -133,8 +133,9 @@ def test_sync_workspace_info(): def test_create_account_groups(): a = create_autospec(AccountClient) - with patch("databricks.sdk.WorkspaceClient.__init__", return_value=None), patch( - "databricks.sdk.WorkspaceClient.get_workspace_id", return_value=None + with ( + patch("databricks.sdk.WorkspaceClient.__init__", return_value=None), + patch("databricks.sdk.WorkspaceClient.get_workspace_id", return_value=None), ): create_account_groups(a) a.groups.list.assert_called_with(attributes="id") From 3859294c84f09db8bbaec069b629e982eb004c95 Mon Sep 17 00:00:00 2001 From: william-conti Date: Tue, 27 Feb 2024 18:25:29 +0100 Subject: [PATCH 23/25] fixing coverage --- tests/unit/test_account.py | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/tests/unit/test_account.py b/tests/unit/test_account.py index 6d4746d049..5ea6631a10 100644 --- a/tests/unit/test_account.py +++ b/tests/unit/test_account.py @@ -2,6 +2,7 @@ import json from unittest.mock import create_autospec, patch +import pytest from databricks.labs.blueprint.installation import Installation, MockInstallation from databricks.labs.blueprint.tui import MockPrompts from databricks.sdk import AccountClient, WorkspaceClient @@ -102,7 +103,7 @@ def workspace_client() -> WorkspaceClient: acc_client.groups.create.return_value = group account_workspaces = AccountWorkspaces(acc_client, workspace_client) - account_workspaces.create_account_level_groups(MockPrompts({}), [123]) + account_workspaces.create_account_level_groups(MockPrompts({}), [123, 46]) acc_client.groups.create.assert_called_with( display_name="de", @@ -120,6 +121,27 @@ def workspace_client() -> WorkspaceClient: ) +def test_create_acc_groups_should_throw_exception(): + acc_client = create_autospec(AccountClient) + acc_client.config = Config(host="https://accounts.cloud.databricks.com", account_id="123", token="123") + + acc_client.workspaces.list.return_value = [] + + ws = create_autospec(WorkspaceClient) + + def workspace_client() -> WorkspaceClient: + return ws + + group = Group(id="12", display_name="test_account", meta=ResourceMeta("Account")) + + acc_client.get_workspace_client.return_value = ws + acc_client.groups.create.return_value = group + + account_workspaces = AccountWorkspaces(acc_client, workspace_client) + with pytest.raises(ValueError): + account_workspaces.create_account_level_groups(MockPrompts({}), [123]) + + def test_create_acc_groups_should_filter_system_groups(): acc_client = create_autospec(AccountClient) acc_client.config = Config(host="https://accounts.cloud.databricks.com", account_id="123", token="123") @@ -145,7 +167,7 @@ def workspace_client() -> WorkspaceClient: acc_client.groups.create.return_value = group account_workspaces = AccountWorkspaces(acc_client, workspace_client) - account_workspaces.create_account_level_groups(MockPrompts({})) + account_workspaces.create_account_level_groups(MockPrompts({}), [123]) acc_client.groups.create.assert_not_called() @@ -309,7 +331,7 @@ def workspace_client() -> WorkspaceClient: ws.groups.get.return_value = group acc_client.get_workspace_client.return_value = ws account_workspaces = AccountWorkspaces(acc_client, workspace_client) - account_workspaces.create_account_level_groups(MockPrompts({})) + account_workspaces.create_account_level_groups(MockPrompts({}), [123]) acc_client.groups.create.assert_not_called() From 4f331ef197ba95167648298bec644d2dd7539e9e Mon Sep 17 00:00:00 2001 From: william-conti Date: Wed, 28 Feb 2024 09:58:06 +0100 Subject: [PATCH 24/25] removing flakiness --- tests/integration/test_account.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/tests/integration/test_account.py b/tests/integration/test_account.py index 8ce75cd8da..547536c77a 100644 --- a/tests/integration/test_account.py +++ b/tests/integration/test_account.py @@ -1,18 +1,28 @@ from databricks.labs.blueprint.tui import MockPrompts +from databricks.sdk import AccountClient +from databricks.sdk.errors import NotFound from databricks.labs.ucx.account import AccountWorkspaces -def test_create_account_level_groups(make_ucx_group, make_group, make_user, acc, ws): - make_ucx_group("test_ucx_migrate_invalid", "test_ucx_migrate_invalid") +def test_create_account_level_groups(make_ucx_group, make_group, make_user, acc, ws, make_random): + suffix = make_random() + make_ucx_group(f"test_ucx_migrate_invalid_{suffix}", f"test_ucx_migrate_invalid_{suffix}") - make_group(display_name="regular_group", members=[make_user().id]) + make_group(display_name=f"regular_group_{suffix}", members=[make_user().id]) AccountWorkspaces(acc).create_account_level_groups(MockPrompts({}), [ws.get_workspace_id()]) results = [] for grp in acc.groups.list(): - if grp.display_name in {"regular_group"}: + if grp.display_name in {f"regular_group_{suffix}"}: results.append(grp) - acc.groups.delete(grp.id) # Avoids flakiness for future runs + try_delete_group(acc, grp.id) # Avoids flakiness for future runs assert len(results) == 1 + + +def try_delete_group(acc: AccountClient, grp_id: str): + try: + acc.groups.delete(grp_id) + except NotFound: + pass From 735136329c7b2423e949d4cf7003a615832986da Mon Sep 17 00:00:00 2001 From: william-conti Date: Wed, 28 Feb 2024 10:17:58 +0100 Subject: [PATCH 25/25] formatting --- tests/integration/test_account.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/test_account.py b/tests/integration/test_account.py index 547536c77a..4c25aca9ef 100644 --- a/tests/integration/test_account.py +++ b/tests/integration/test_account.py @@ -16,7 +16,7 @@ def test_create_account_level_groups(make_ucx_group, make_group, make_user, acc, for grp in acc.groups.list(): if grp.display_name in {f"regular_group_{suffix}"}: results.append(grp) - try_delete_group(acc, grp.id) # Avoids flakiness for future runs + try_delete_group(acc, grp.id) # Avoids flakiness for future runs assert len(results) == 1