Skip to content

Commit

Permalink
Be explicit about the AuthClient API version
Browse files Browse the repository at this point in the history
  • Loading branch information
Iain-S committed May 30, 2024
1 parent 48dccb6 commit a58be53
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 45 deletions.
18 changes: 10 additions & 8 deletions status_function/status/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def get_role_assignments_list(auth_client: AuthClient) -> list:
Returns:
A list of role assignments.
"""
return list(auth_client.role_assignments.list())
return list(auth_client.role_assignments.list_for_subscription())


def get_role_def_dict(auth_client: AuthClient, subscription_id: str) -> dict:
Expand Down Expand Up @@ -133,7 +133,9 @@ def get_auth_client(
An authorisation client for the given subscription.
"""
return AuthClient(
credential=CREDENTIALS, subscription_id=subscription.subscription_id
credential=CREDENTIALS,
subscription_id=subscription.subscription_id,
api_version="2022-04-01",
)


Expand Down Expand Up @@ -218,7 +220,7 @@ def get_role_assignment_models(
A list of RoleAssignment objects.
"""
principal_details = []
principal = get_principal(assignment.properties.principal_id, graph_client)
principal = get_principal(assignment.principal_id, graph_client)
if principal:
if isinstance(principal, ADGroup):
principal_details.extend(get_ad_group_principals(principal, graph_client))
Expand All @@ -227,14 +229,14 @@ def get_role_assignment_models(
else:
logger.warning(
"Could not retrieve principal data for principal id %s",
assignment.properties.principal_id,
assignment.principal_id,
)
return [
models.RoleAssignment(
role_definition_id=assignment.properties.role_definition_id,
role_definition_id=assignment.role_definition_id,
role_name=role_name,
principal_id=assignment.properties.principal_id,
scope=assignment.properties.scope,
principal_id=assignment.principal_id,
scope=assignment.scope,
**x
)
for x in principal_details
Expand Down Expand Up @@ -262,7 +264,7 @@ def get_subscription_role_assignment_models(
for assignment in assignments_list:
role_assignments_models += get_role_assignment_models(
assignment,
role_def_dict.get(assignment.properties.role_definition_id, "Unknown"),
role_def_dict.get(assignment.role_definition_id, "Unknown"),
graph_client,
)
except CloudError as e:
Expand Down
114 changes: 77 additions & 37 deletions status_function/tests/test_function_app.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
"""Tests for status package."""
import logging
from datetime import datetime
from importlib import import_module
from types import SimpleNamespace
from typing import Final
from unittest import TestCase, main
from unittest.mock import MagicMock, call, patch
from uuid import UUID

import jwt
from azure.graphrbac import GraphRbacManagementClient
from azure.graphrbac.models import ADGroup, ServicePrincipal, User
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
Expand All @@ -20,7 +22,7 @@
HTTP_ADAPTER: Final = TypeAdapter(HttpUrl)
VALID_URL: Final = HTTP_ADAPTER.validate_python("https://my.org")

EXPECTED_DICT = {
EXPECTED_DICT: Final = {
"role_definition_id": str(UUID(int=10)),
"role_name": "contributor",
"principal_id": str(UUID(int=100)),
Expand All @@ -30,6 +32,16 @@
"principal_type": None,
}

API_VERSION: Final = "2022-04-01"
# e.g. v2022_04_01
API_VERSION_PACKAGE: Final = "v" + API_VERSION.replace("-", "_")
OPERATIONS_MODULE: Final = import_module(
f"azure.mgmt.authorization.{API_VERSION_PACKAGE}.operations"
)
MODELS_MODULE: Final = import_module(
f"azure.mgmt.authorization.{API_VERSION_PACKAGE}.models"
)


class TestStatus(TestCase):
"""Tests for the __init__.py file."""
Expand Down Expand Up @@ -198,18 +210,19 @@ def test_get_role_assignment_models__with_user(self) -> None:
expected_dict = EXPECTED_DICT.copy()
expected_dict.update(expected_values)

mock_role_assignment = MagicMock()
mock_role_assignment.properties.role_definition_id = str(UUID(int=10))
mock_role_assignment.properties.principal_id = str(UUID(int=100))
mock_role_assignment.properties.scope = "/subscription_id/"
role_assignment = MODELS_MODULE.RoleAssignment(
role_definition_id=str(UUID(int=10)),
principal_id=str(UUID(int=100)),
)
role_assignment.scope = "/subscription_id/"
with patch("status.GraphRbacManagementClient") as mock_grmc:
expected = RoleAssignment(**expected_dict)
with patch("status.get_principal") as mock_get_principal:
mock_get_principal.return_value = User()
with patch("status.get_principal_details") as mock_gud:
mock_gud.return_value = expected_values
actual = status.get_role_assignment_models(
mock_role_assignment,
role_assignment,
"contributor",
mock_grmc,
)
Expand All @@ -225,10 +238,11 @@ def test_get_role_assignment_models__with_service_principal(self) -> None:
"principal_type": ServicePrincipal,
}
expected_dict = EXPECTED_DICT.copy()
mock_role_assignment = MagicMock()
mock_role_assignment.properties.role_definition_id = str(UUID(int=10))
mock_role_assignment.properties.principal_id = str(UUID(int=100))
mock_role_assignment.properties.scope = "/subscription_id/"
role_assignment = MODELS_MODULE.RoleAssignment(
role_definition_id=str(UUID(int=10)),
principal_id=str(UUID(int=100)),
)
role_assignment.scope = "/subscription_id/"

with patch("status.GraphRbacManagementClient") as mock_grmc:
expected_dict.update(expected_values)
Expand All @@ -238,7 +252,7 @@ def test_get_role_assignment_models__with_service_principal(self) -> None:
with patch("status.get_principal_details") as mock_spd:
mock_spd.return_value = expected_values
actual = status.get_role_assignment_models(
mock_role_assignment,
role_assignment,
"contributor",
mock_grmc,
)
Expand All @@ -260,10 +274,11 @@ def test_get_role_assignment_models__with_adgroup(self) -> None:
for i in range(2):
expected_dict_list[i].update(expected_values[i])

mock_role_assignment = MagicMock()
mock_role_assignment.properties.role_definition_id = str(UUID(int=10))
mock_role_assignment.properties.principal_id = str(UUID(int=100))
mock_role_assignment.properties.scope = "/subscription_id/"
role_assignment = MODELS_MODULE.RoleAssignment(
role_definition_id=str(UUID(int=10)),
principal_id=str(UUID(int=100)),
)
role_assignment.scope = "/subscription_id/"

with patch("status.GraphRbacManagementClient") as mock_grmc:
expected = [
Expand All @@ -274,7 +289,7 @@ def test_get_role_assignment_models__with_adgroup(self) -> None:
with patch("status.get_ad_group_principals") as mock_adgu:
mock_adgu.return_value = expected_values
actual = status.get_role_assignment_models(
mock_role_assignment,
role_assignment,
"contributor",
mock_grmc,
)
Expand All @@ -286,23 +301,42 @@ def test_get_role_assignment_models__with_other_role_assignment(self) -> None:
"""
expected_dict = EXPECTED_DICT.copy()

mock_role_assignment = MagicMock()
mock_role_assignment.properties.role_definition_id = str(UUID(int=10))
mock_role_assignment.properties.principal_id = str(UUID(int=100))
mock_role_assignment.properties.scope = "/subscription_id/"
role_assignment = MODELS_MODULE.RoleAssignment(
role_definition_id=str(UUID(int=10)),
principal_id=str(UUID(int=100)),
)
role_assignment.scope = "/subscription_id/"

with patch("status.GraphRbacManagementClient") as mock_grmc:
expected_dict.update({"mail": None})
expected = RoleAssignment(**expected_dict)
with patch("status.get_principal") as mock_get_principal:
mock_get_principal.return_value = SimpleNamespace()
actual = status.get_role_assignment_models(
mock_role_assignment,
role_assignment,
"contributor",
mock_grmc,
)
self.assertListEqual([expected], actual)

def test_get_role_assignment_models__no_principal(self) -> None:
"""test get_role_assignment_models can handle not finding a principal"""
with patch("status.logger") as mock_logger, patch(
"status.get_principal"
) as mock_get_principal:
mock_get_principal.return_value = None
principal_id = str(UUID(int=100))
role_assignment = MODELS_MODULE.RoleAssignment(
role_definition_id=str(UUID(int=10)),
role_name="Contributor",
principal_id=principal_id,
)
role_assignment.scope = "/"
status.get_role_assignment_models(role_assignment, "somerole", None)
mock_logger.warning.assert_called_with(
"Could not retrieve principal data for principal id %s", principal_id
)

def test_get_subscription_role_assignment_models__no_error(self) -> None:
"""test get_subscription_role_assignment_models returns a list of
RoleAssignments"""
Expand All @@ -314,16 +348,18 @@ def test_get_subscription_role_assignment_models__no_error(self) -> None:
with patch("status.get_role_def_dict") as mock_grdd:
mock_grdd.return_value = {str(UUID(int=10)): "contributor"}
with patch("status.get_role_assignments_list") as mock_gral:
role_assignment = MODELS_MODULE.RoleAssignment()
role_assignment.scope = "/"
mock_gral.return_value = [
SimpleNamespace(
properties=SimpleNamespace(
role_definition_id=str(UUID(int=10)),
principal_id=str(UUID(int=100 + i)),
scope="/",
)
MODELS_MODULE.RoleAssignment(
role_definition_id=str(UUID(int=10)),
principal_id=str(UUID(int=100 + i)),
)
for i in range(3)
]
for item in mock_gral.return_value:
item.scope = "/"

with patch("status.get_principal") as mock_principal:
mock_principal.return_value = User(display_name="Unknown")
actual = status.get_subscription_role_assignment_models(
Expand Down Expand Up @@ -367,17 +403,20 @@ def test_get_all_status(self) -> None:
)
]

# Import the role assignments class from the specific API version.

mock_role_assignments = MagicMock(
spec=OPERATIONS_MODULE.RoleAssignmentsOperations
)
with patch("status.AuthClient") as mock_auth_client:
mock_assign_func = mock_auth_client.return_value.role_assignments.list
mock_assign_func.return_value = [
SimpleNamespace(
properties=SimpleNamespace(
role_definition_id=str(UUID(int=10)),
principal_id=str(UUID(int=100)),
scope="/",
)
)
]
mock_auth_client.return_value.role_assignments = mock_role_assignments
mock_assign_func = mock_role_assignments.list_for_subscription
role_assignment = MODELS_MODULE.RoleAssignment(
role_definition_id=str(UUID(int=10)),
principal_id=str(UUID(int=100)),
)
role_assignment.scope = "/"
mock_assign_func.return_value = [role_assignment]

mock_defs_func = mock_auth_client.return_value.role_definitions.list
mock_defs_func.return_value = [
Expand Down Expand Up @@ -427,6 +466,7 @@ def test_get_all_status(self) -> None:
mock_auth_client.assert_called_with(
credential=status.CREDENTIALS,
subscription_id=str(UUID(int=1)),
api_version=API_VERSION,
)
mock_defs_func.assert_called_with(
scope="/subscriptions/" + str(UUID(int=1))
Expand Down

0 comments on commit a58be53

Please sign in to comment.