Skip to content

Commit

Permalink
Be explicit about the AuthClient API version
Browse files Browse the repository at this point in the history
  • Loading branch information
Iain-S committed May 30, 2024
1 parent bdda55f commit 3993643
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 44 deletions.
17 changes: 10 additions & 7 deletions status_function/status/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def get_role_assignments_list(auth_client: AuthClient) -> list:
Returns:
A list of role assignments.
"""
return list(auth_client.role_assignments.list())
return list(auth_client.role_assignments.list_for_subscription())


def get_role_def_dict(auth_client: AuthClient, subscription_id: str) -> dict:
Expand Down Expand Up @@ -133,7 +133,9 @@ def get_auth_client(
An authorisation client for the given subscription.
"""
return AuthClient(
credential=CREDENTIALS, subscription_id=subscription.subscription_id
credential=CREDENTIALS,
subscription_id=subscription.subscription_id,
api_version="2022-04-01",
)


Expand Down Expand Up @@ -218,7 +220,7 @@ def get_role_assignment_models(
A list of RoleAssignment objects.
"""
principal_details = []
principal = get_principal(assignment.properties.principal_id, graph_client)
principal = get_principal(assignment.principal_id, graph_client)
if principal:
if isinstance(principal, ADGroup):
principal_details.extend(get_ad_group_principals(principal, graph_client))
Expand All @@ -231,10 +233,10 @@ def get_role_assignment_models(
)
return [
models.RoleAssignment(
role_definition_id=assignment.properties.role_definition_id,
role_definition_id=assignment.role_definition_id,
role_name=role_name,
principal_id=assignment.properties.principal_id,
scope=assignment.properties.scope,
principal_id=assignment.principal_id,
scope=assignment.scope,
**x
)
for x in principal_details
Expand Down Expand Up @@ -262,7 +264,7 @@ def get_subscription_role_assignment_models(
for assignment in assignments_list:
role_assignments_models += get_role_assignment_models(
assignment,
role_def_dict.get(assignment.properties.role_definition_id, "Unknown"),
role_def_dict.get(assignment.role_definition_id, "Unknown"),
graph_client,
)
except CloudError as e:
Expand Down Expand Up @@ -314,6 +316,7 @@ def get_all_status(tenant_id: UUID) -> list[models.SubscriptionStatus]:
role_assignments=tuple(role_assignments_models),
)
)
break

logger.warning("Status data retrieved in %s.", str(datetime.now() - started_at))
return data
Expand Down
95 changes: 58 additions & 37 deletions status_function/tests/test_function_app.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Tests for status package."""
import logging
from datetime import datetime
from importlib import import_module
from types import SimpleNamespace
from typing import Final
from unittest import TestCase, main
Expand All @@ -20,7 +21,7 @@
HTTP_ADAPTER: Final = TypeAdapter(HttpUrl)
VALID_URL: Final = HTTP_ADAPTER.validate_python("https://my.org")

EXPECTED_DICT = {
EXPECTED_DICT: Final = {
"role_definition_id": str(UUID(int=10)),
"role_name": "contributor",
"principal_id": str(UUID(int=100)),
Expand All @@ -30,6 +31,16 @@
"principal_type": None,
}

API_VERSION: Final = "2022-04-01"
# e.g. v2022_04_01
API_VERSION_PACKAGE: Final = "v" + API_VERSION.replace("-", "_")
OPERATIONS_MODULE: Final = import_module(
f"azure.mgmt.authorization.{API_VERSION_PACKAGE}.operations"
)
MODELS_MODULE: Final = import_module(
f"azure.mgmt.authorization.{API_VERSION_PACKAGE}.models"
)


class TestStatus(TestCase):
"""Tests for the __init__.py file."""
Expand Down Expand Up @@ -198,18 +209,19 @@ def test_get_role_assignment_models__with_user(self) -> None:
expected_dict = EXPECTED_DICT.copy()
expected_dict.update(expected_values)

mock_role_assignment = MagicMock()
mock_role_assignment.properties.role_definition_id = str(UUID(int=10))
mock_role_assignment.properties.principal_id = str(UUID(int=100))
mock_role_assignment.properties.scope = "/subscription_id/"
role_assignment = MODELS_MODULE.RoleAssignment(
role_definition_id=str(UUID(int=10)),
principal_id=str(UUID(int=100)),
)
role_assignment.scope = "/subscription_id/"
with patch("status.GraphRbacManagementClient") as mock_grmc:
expected = RoleAssignment(**expected_dict)
with patch("status.get_principal") as mock_get_principal:
mock_get_principal.return_value = User()
with patch("status.get_principal_details") as mock_gud:
mock_gud.return_value = expected_values
actual = status.get_role_assignment_models(
mock_role_assignment,
role_assignment,
"contributor",
mock_grmc,
)
Expand All @@ -225,10 +237,11 @@ def test_get_role_assignment_models__with_service_principal(self) -> None:
"principal_type": ServicePrincipal,
}
expected_dict = EXPECTED_DICT.copy()
mock_role_assignment = MagicMock()
mock_role_assignment.properties.role_definition_id = str(UUID(int=10))
mock_role_assignment.properties.principal_id = str(UUID(int=100))
mock_role_assignment.properties.scope = "/subscription_id/"
role_assignment = MODELS_MODULE.RoleAssignment(
role_definition_id=str(UUID(int=10)),
principal_id=str(UUID(int=100)),
)
role_assignment.scope = "/subscription_id/"

with patch("status.GraphRbacManagementClient") as mock_grmc:
expected_dict.update(expected_values)
Expand All @@ -238,7 +251,7 @@ def test_get_role_assignment_models__with_service_principal(self) -> None:
with patch("status.get_principal_details") as mock_spd:
mock_spd.return_value = expected_values
actual = status.get_role_assignment_models(
mock_role_assignment,
role_assignment,
"contributor",
mock_grmc,
)
Expand All @@ -260,10 +273,11 @@ def test_get_role_assignment_models__with_adgroup(self) -> None:
for i in range(2):
expected_dict_list[i].update(expected_values[i])

mock_role_assignment = MagicMock()
mock_role_assignment.properties.role_definition_id = str(UUID(int=10))
mock_role_assignment.properties.principal_id = str(UUID(int=100))
mock_role_assignment.properties.scope = "/subscription_id/"
role_assignment = MODELS_MODULE.RoleAssignment(
role_definition_id=str(UUID(int=10)),
principal_id=str(UUID(int=100)),
)
role_assignment.scope = "/subscription_id/"

with patch("status.GraphRbacManagementClient") as mock_grmc:
expected = [
Expand All @@ -274,7 +288,7 @@ def test_get_role_assignment_models__with_adgroup(self) -> None:
with patch("status.get_ad_group_principals") as mock_adgu:
mock_adgu.return_value = expected_values
actual = status.get_role_assignment_models(
mock_role_assignment,
role_assignment,
"contributor",
mock_grmc,
)
Expand All @@ -286,18 +300,19 @@ def test_get_role_assignment_models__with_other_role_assignment(self) -> None:
"""
expected_dict = EXPECTED_DICT.copy()

mock_role_assignment = MagicMock()
mock_role_assignment.properties.role_definition_id = str(UUID(int=10))
mock_role_assignment.properties.principal_id = str(UUID(int=100))
mock_role_assignment.properties.scope = "/subscription_id/"
role_assignment = MODELS_MODULE.RoleAssignment(
role_definition_id=str(UUID(int=10)),
principal_id=str(UUID(int=100)),
)
role_assignment.scope = "/subscription_id/"

with patch("status.GraphRbacManagementClient") as mock_grmc:
expected_dict.update({"mail": None})
expected = RoleAssignment(**expected_dict)
with patch("status.get_principal") as mock_get_principal:
mock_get_principal.return_value = SimpleNamespace()
actual = status.get_role_assignment_models(
mock_role_assignment,
role_assignment,
"contributor",
mock_grmc,
)
Expand All @@ -314,16 +329,18 @@ def test_get_subscription_role_assignment_models__no_error(self) -> None:
with patch("status.get_role_def_dict") as mock_grdd:
mock_grdd.return_value = {str(UUID(int=10)): "contributor"}
with patch("status.get_role_assignments_list") as mock_gral:
role_assignment = MODELS_MODULE.RoleAssignment()
role_assignment.scope = "/"
mock_gral.return_value = [
SimpleNamespace(
properties=SimpleNamespace(
role_definition_id=str(UUID(int=10)),
principal_id=str(UUID(int=100 + i)),
scope="/",
)
MODELS_MODULE.RoleAssignment(
role_definition_id=str(UUID(int=10)),
principal_id=str(UUID(int=100 + i)),
)
for i in range(3)
]
for item in mock_gral.return_value:
item.scope = "/"

with patch("status.get_principal") as mock_principal:
mock_principal.return_value = User(display_name="Unknown")
actual = status.get_subscription_role_assignment_models(
Expand Down Expand Up @@ -367,17 +384,20 @@ def test_get_all_status(self) -> None:
)
]

# Import the role assignments class from the specific API version.

mock_role_assignments = MagicMock(
spec=OPERATIONS_MODULE.RoleAssignmentsOperations
)
with patch("status.AuthClient") as mock_auth_client:
mock_assign_func = mock_auth_client.return_value.role_assignments.list
mock_assign_func.return_value = [
SimpleNamespace(
properties=SimpleNamespace(
role_definition_id=str(UUID(int=10)),
principal_id=str(UUID(int=100)),
scope="/",
)
)
]
mock_auth_client.return_value.role_assignments = mock_role_assignments
mock_assign_func = mock_role_assignments.list_for_subscription
role_assignment = MODELS_MODULE.RoleAssignment(
role_definition_id=str(UUID(int=10)),
principal_id=str(UUID(int=100)),
)
role_assignment.scope = "/"
mock_assign_func.return_value = [role_assignment]

mock_defs_func = mock_auth_client.return_value.role_definitions.list
mock_defs_func.return_value = [
Expand Down Expand Up @@ -427,6 +447,7 @@ def test_get_all_status(self) -> None:
mock_auth_client.assert_called_with(
credential=status.CREDENTIALS,
subscription_id=str(UUID(int=1)),
api_version=API_VERSION,
)
mock_defs_func.assert_called_with(
scope="/subscriptions/" + str(UUID(int=1))
Expand Down

0 comments on commit 3993643

Please sign in to comment.