From d889c9d25d9784cce4cb2afc374fc906af548052 Mon Sep 17 00:00:00 2001 From: ianhelle Date: Thu, 19 Sep 2024 16:00:28 -0700 Subject: [PATCH 1/9] Adding more comprehensive unit tests for authentication modules Minor fixes to azure_auth.py and azure_auth_core.py --- msticpy/auth/azure_auth.py | 28 +- msticpy/auth/azure_auth_core.py | 5 +- tests/auth/test_azure_auth.py | 91 ++++++ tests/auth/test_azure_auth_core.py | 461 +++++++++++++++++++++++++++-- 4 files changed, 548 insertions(+), 37 deletions(-) create mode 100644 tests/auth/test_azure_auth.py diff --git a/msticpy/auth/azure_auth.py b/msticpy/auth/azure_auth.py index bf24492b..3202f266 100644 --- a/msticpy/auth/azure_auth.py +++ b/msticpy/auth/azure_auth.py @@ -7,11 +7,11 @@ import os from typing import List, Optional -from azure.common.exceptions import CloudError from azure.identity import DeviceCodeCredential from azure.mgmt.subscription import SubscriptionClient from .._version import VERSION +from ..common.exceptions import MsticpyAzureConnectionError # pylint: enable=unused-import from ..common.provider_settings import get_provider_settings @@ -68,7 +68,7 @@ def az_connect( Raises ------ - CloudError + MsticpyAzureConnectionError If chained token credential creation fails. See Also @@ -111,7 +111,10 @@ def az_connect( credential_scopes=[az_cloud_config.token_uri], ) if not sub_client: - raise CloudError("Could not create a Subscription client.") + raise MsticpyAzureConnectionError( + "Could not create an Azure Subscription client with credentials.", + title="Azure authentication error", + ) return credentials @@ -132,7 +135,9 @@ def az_user_connect( Returns ------- - AzCredentials + AzCredentials - Dataclass combining two types of Azure credentials: + - legacy (ADAL) credentials + - modern (MSAL) credentials """ return az_connect_core( @@ -158,14 +163,13 @@ def fallback_devicecode_creds( Returns ------- - AzCredentials - Named tuple of: - - legacy (ADAL) credentials - - modern (MSAL) credentials + AzCredentials - Dataclass combining two types of Azure credentials: + - legacy (ADAL) credentials + - modern (MSAL) credentials Raises ------ - CloudError + MsticpyAzureConnectionError If chained token credential creation fails. """ @@ -176,7 +180,11 @@ def fallback_devicecode_creds( creds = DeviceCodeCredential(authority=aad_uri, tenant_id=tenant_id) legacy_creds = CredentialWrapper(creds, resource_id=az_config.token_uri) if not creds: - raise CloudError("Could not obtain credentials.") + raise MsticpyAzureConnectionError( + f"Could not obtain credentials for tenant {tenant_id}", + "Please check your Azure configuration and try again.", + title="Azure authentication error", + ) return AzCredentials(legacy_creds, ChainedTokenCredential(creds)) # type: ignore[arg-type] diff --git a/msticpy/auth/azure_auth_core.py b/msticpy/auth/azure_auth_core.py index 8cc19789..679ddb44 100644 --- a/msticpy/auth/azure_auth_core.py +++ b/msticpy/auth/azure_auth_core.py @@ -137,7 +137,8 @@ def _build_env_client( def _build_cli_client(**kwargs) -> AzureCliCredential: """Build a credential from Azure CLI.""" - del kwargs + if tenant_id := kwargs.pop("tenant_id", None): + return AzureCliCredential(tenant_id=tenant_id) return AzureCliCredential() @@ -148,7 +149,7 @@ def _build_msi_client( ) -> ManagedIdentityCredential: """Build a credential from Managed Identity.""" msi_kwargs = kwargs.copy() - if AzureCredEnvNames.AZURE_CLIENT_ID in os.environ: + if "client_id" not in kwargs and AzureCredEnvNames.AZURE_CLIENT_ID in os.environ: msi_kwargs["client_id"] = os.environ[AzureCredEnvNames.AZURE_CLIENT_ID] return ManagedIdentityCredential( diff --git a/tests/auth/test_azure_auth.py b/tests/auth/test_azure_auth.py new file mode 100644 index 00000000..182e70af --- /dev/null +++ b/tests/auth/test_azure_auth.py @@ -0,0 +1,91 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from msticpy.auth.azure_auth import ( + az_connect, + az_user_connect, + fallback_devicecode_creds, + get_default_resource_name, +) +from msticpy.auth.azure_auth_core import AzCredentials + + +@pytest.fixture +def mock_az_credentials(): + return MagicMock(spec=AzCredentials) + + +@patch("msticpy.auth.azure_auth.os") +@patch("msticpy.auth.azure_auth.get_provider_settings") +@patch("msticpy.auth.azure_auth.az_connect_core") +@patch("msticpy.auth.azure_auth.SubscriptionClient") +def test_az_connect( + mock_sub_client, + mock_az_connect_core, + mock_get_provider_settings, + mock_os, + mock_az_credentials, +): + mock_az_credentials.modern = MagicMock() + mock_az_credentials.modern.__bool__.return_value = True + mock_az_credentials.legacy = MagicMock() + mock_az_connect_core.return_value = mock_az_credentials + mock_sub_client.return_value = MagicMock() + mock_os.environ = MagicMock() + az_cli_args = MagicMock() + az_cli_args.get.return_value = "test_value" + az_cli_args.__bool__.return_value = True + az_cli_config = MagicMock() + az_cli_config.__bool__.return_value = True + az_cli_config.args = az_cli_args + + data_provs = MagicMock(spec=dict) + data_provs.get.return_value = az_cli_config + mock_get_provider_settings.return_value = data_provs + + result = az_connect(auth_methods=["env"], tenant_id="test_tenant", silent=True) + + assert result == mock_az_credentials + mock_az_connect_core.assert_called_once_with( + auth_methods=["env"], tenant_id="test_tenant", silent=True + ) + mock_sub_client.assert_called_once() + + +@patch("msticpy.auth.azure_auth.az_connect_core") +def test_az_user_connect(mock_az_connect_core, mock_az_credentials): + mock_az_connect_core.return_value = mock_az_credentials + + result = az_user_connect(tenant_id="test_tenant", silent=True) + + assert result == mock_az_credentials + mock_az_connect_core.assert_called_once_with( + auth_methods=["cli", "interactive"], tenant_id="test_tenant", silent=True + ) + + +@patch("msticpy.auth.azure_auth.AzureCloudConfig") +@patch("msticpy.auth.azure_auth.DeviceCodeCredential") +@patch("msticpy.auth.azure_auth.CredentialWrapper") +def test_fallback_devicecode_creds( + mock_cred_wrapper, mock_device_code_cred, mock_azure_cloud_config +): + mock_azure_cloud_config.return_value = MagicMock() + mock_device_code_cred.return_value = MagicMock() + mock_cred_wrapper.return_value = MagicMock() + + result = fallback_devicecode_creds(cloud="test_cloud", tenant_id="test_tenant") + + assert isinstance(result, AzCredentials) + mock_device_code_cred.assert_called_once() + mock_cred_wrapper.assert_called_once() + + +def test_get_default_resource_name(): + resource_uri = "https://example.com/resource" + expected_result = "https://example.com/resource/.default" + + result = get_default_resource_name(resource_uri) + + assert result == expected_result diff --git a/tests/auth/test_azure_auth_core.py b/tests/auth/test_azure_auth_core.py index 313c7ea2..bd835689 100644 --- a/tests/auth/test_azure_auth_core.py +++ b/tests/auth/test_azure_auth_core.py @@ -4,19 +4,34 @@ # license information. # -------------------------------------------------------------------------- """Module docstring.""" +import logging +import os from datetime import datetime, timedelta +from typing import Tuple from unittest.mock import MagicMock, patch import pytest import pytest_check as check +from azure.identity import ChainedTokenCredential, DeviceCodeCredential from msticpy.auth.azure_auth_core import ( AzCredentials, AzureCliStatus, AzureCloudConfig, - DeviceCodeCredential, + MsticpyAzureConfigError, _az_connect_core, + _build_certificate_client, + _build_cli_client, + _build_client_secret_client, + _build_device_code_client, _build_env_client, + _build_interactive_client, + _build_msi_client, + _build_powershell_client, + _build_vscode_client, + _create_chained_credential, + _filter_all_warnings, + _filter_credential_warning, check_cli_credentials, ) from msticpy.auth.cloud_mappings import default_auth_methods @@ -119,7 +134,7 @@ def test_check_cli_credentials(get_cli_profile, test, expected): _CLI_ID = "d8d9d2f2-5d2d-4d7e-9c5c-5d6d9d1d8d9d" _TENANT_ID = "f8d9d2f2-5d2d-4d7e-9c5c-5d6d9d1d8d9e" -_TEST_ENV_VARS = ( +_TEST_ENV_VARS: list[Tuple[dict[str, str], bool]] = [ ( { "AZURE_CLIENT_ID": _CLI_ID, @@ -160,7 +175,7 @@ def test_check_cli_credentials(get_cli_profile, test, expected): False, ), ({}, False), -) +] @pytest.mark.parametrize("env_vars, expected", _TEST_ENV_VARS) @@ -192,28 +207,7 @@ def test_build_env_client(env_vars, expected, monkeypatch): ], ) def test_az_connect_core(auth_methods, cloud, tenant_id, silent, region, credential): - """ - Test _az_connect_core function with different parameters. - - Parameters - ---------- - auth_methods : list[str] - List of authentication methods to try. - cloud : str - Azure cloud to connect to. - tenant_id : str - Tenant to authenticate against. - silent : bool - Whether to display any output during auth process. - region : str - Azure region to connect to. - credential : AzCredentials - Azure credential to use directly. - - Returns - ------- - None - """ + """Test _az_connect_core function with different parameters.""" # Call the function with the test parameters result = _az_connect_core( auth_methods=auth_methods, @@ -228,3 +222,420 @@ def test_az_connect_core(auth_methods, cloud, tenant_id, silent, region, credent assert isinstance(result, AzCredentials) assert result.legacy is not None assert result.modern is not None + + +@pytest.mark.parametrize( + "env_vars, expected_credential", + [ + ( + { + "AZURE_CLIENT_ID": "test_client_id", + "AZURE_TENANT_ID": "test_tenant_id", + "AZURE_CLIENT_SECRET": "[PLACEHOLDER]", + }, + "EnvironmentCredential", + ), + ( + { + "AZURE_CLIENT_ID": "test_client_id", + "AZURE_TENANT_ID": "test_tenant_id", + "AZURE_CLIENT_CERTIFICATE_PATH": "[PLACEHOLDER]", + }, + "EnvironmentCredential", + ), + ( + { + "AZURE_CLIENT_ID": "test_client_id", + "AZURE_TENANT_ID": "test_tenant_id", + "AZURE_USERNAME": "test_user", + "AZURE_PASSWORD": "[PLACEHOLDER]", + }, + "EnvironmentCredential", + ), + ( + { + "AZURE_CLIENT_ID": "test_client_id", + "AZURE_CLIENT_CERTIFICATE_PATH": "[PLACEHOLDER]", + }, + None, + ), + ( + { + "AZURE_TENANT_ID": "test_tenant_id", + "AZURE_USERNAME": "test_user", + "AZURE_PASSWORD": "[PLACEHOLDER]", + }, + None, + ), + ({}, None), + ], +) +@patch.dict(os.environ, {}, clear=True) +@patch("msticpy.auth.azure_auth_core.EnvironmentCredential", autospec=True) +def test_build_env_client_alt( + mock_env_credential, env_vars, expected_credential, monkeypatch +): + """Test _build_env_client function.""" + for env_var, env_val in env_vars.items(): + monkeypatch.setenv(env_var, env_val) + result = _build_env_client() + if expected_credential: + # assert isinstance(result, mock_env_credential) + mock_env_credential.assert_called_once() + else: + mock_env_credential.assert_not_called() + assert result is None + + +@patch("msticpy.auth.azure_auth_core.AzureCliCredential", autospec=True) +def test_build_cli_client(mock_cli_credential): + """Test _build_cli_client function.""" + result = _build_cli_client() + # assert isinstance(result, mock_cli_credential) + mock_cli_credential.assert_called_once() + + +@pytest.mark.parametrize( + "env_vars, expected_kwargs, tenant_id, aad_uri", + [ + ( + {"AZURE_CLIENT_ID": "test_client_id"}, + {"client_id": "test_client_id"}, + "test_tenant_id", + "test_aad_uri", + ), + ({}, {}, None, None), + ], +) +@patch.dict(os.environ, {}, clear=True) +@patch("msticpy.auth.azure_auth_core.ManagedIdentityCredential", autospec=True) +def test_build_msi_client( + mock_msi_credential, env_vars, expected_kwargs, tenant_id, aad_uri +): + """Test _build_msi_client function.""" + os.environ.update(env_vars) + result = _build_msi_client(tenant_id=tenant_id, aad_uri=aad_uri) + # assert isinstance(result, mock_msi_credential) + mock_msi_credential.assert_called_once_with( + tenant_id=tenant_id, authority=aad_uri, **expected_kwargs + ) + + +@pytest.mark.parametrize( + "tenant_id, aad_uri", + [ + ("test_tenant_id", "test_aad_uri"), + (None, None), + ], +) +@patch("msticpy.auth.azure_auth_core.VisualStudioCodeCredential", autospec=True) +def test_build_vscode_client(mock_vscode_credential, tenant_id, aad_uri): + """Test _build_vscode_client function.""" + result = _build_vscode_client(tenant_id=tenant_id, aad_uri=aad_uri) + # assert isinstance(result, mock_vscode_credential) + mock_vscode_credential.assert_called_once_with( + tenant_id=tenant_id, authority=aad_uri + ) + + +@pytest.mark.parametrize( + "tenant_id, aad_uri, kwargs", + [ + ("test_tenant_id", "test_aad_uri", {"param": "value"}), + (None, None, {}), + ], +) +@patch("msticpy.auth.azure_auth_core.InteractiveBrowserCredential", autospec=True) +def test_build_interactive_client( + mock_interactive_credential, tenant_id, aad_uri, kwargs +): + """Test _build_interactive_client function.""" + _ = _build_interactive_client(tenant_id=tenant_id, aad_uri=aad_uri, **kwargs) + # assert isinstance(result, mock_interactive_credential) + mock_interactive_credential.assert_called_once_with( + tenant_id=tenant_id, authority=aad_uri, **kwargs + ) + + +@pytest.mark.parametrize( + "tenant_id, aad_uri, kwargs", + [ + ("test_tenant_id", "test_aad_uri", {"param": "value"}), + (None, None, {}), + ], +) +@patch("msticpy.auth.azure_auth_core.DeviceCodeCredential", autospec=True) +def test_build_device_code_client( + mock_device_code_credential, tenant_id, aad_uri, kwargs +): + """Test _build_device_code_client function.""" + _ = _build_device_code_client(tenant_id=tenant_id, aad_uri=aad_uri, **kwargs) + mock_device_code_credential.assert_called_once_with( + tenant_id=tenant_id, authority=aad_uri, **kwargs + ) + + +@pytest.mark.parametrize( + "tenant_id, aad_uri, client_id, client_secret, expected_credential", + [ + ( + "test_tenant_id", + "test_aad_uri", + "test_client_id", + "test_client_secret", + "ClientSecretCredential", + ), + ("test_tenant_id", "test_aad_uri", None, "test_client_secret", None), + ("test_tenant_id", "test_aad_uri", "test_client_id", None, None), + ], +) +@patch("msticpy.auth.azure_auth_core.ClientSecretCredential", autospec=True) +def test_build_client_secret_client( + mock_client_secret_credential, + tenant_id, + aad_uri, + client_id, + client_secret, + expected_credential, +): + """Test _build_client_secret_client function.""" + kwargs = {"client_id": client_id, "client_secret": client_secret} + result = _build_client_secret_client(tenant_id=tenant_id, aad_uri=aad_uri, **kwargs) + if expected_credential: + mock_client_secret_credential.assert_called_once_with( + tenant_id=tenant_id, authority=aad_uri, **kwargs + ) + else: + assert result is None + + +@pytest.mark.parametrize( + "tenant_id, aad_uri, client_id, expected_credential", + [ + ("test_tenant_id", "test_aad_uri", "test_client_id", "CertificateCredential"), + ("test_tenant_id", "test_aad_uri", None, None), + ], +) +@patch("msticpy.auth.azure_auth_core.CertificateCredential", autospec=True) +def test_build_certificate_client( + mock_certificate_credential, tenant_id, aad_uri, client_id, expected_credential +): + """Test _build_certificate_client function.""" + kwargs = {"client_id": client_id} + result = _build_certificate_client(tenant_id=tenant_id, aad_uri=aad_uri, **kwargs) + if expected_credential: + mock_certificate_credential.assert_called_once_with( + tenant_id=tenant_id, authority=aad_uri, **kwargs + ) + else: + assert result is None + + +@patch("msticpy.auth.azure_auth_core.AzurePowerShellCredential", autospec=True) +def test_build_powershell_client(mock_powershell_credential): + """Test _build_powershell_client function.""" + result = _build_powershell_client() + # assert isinstance(result, mock_powershell_credential) + mock_powershell_credential.assert_called_once() + + +@pytest.mark.parametrize( + "requested_clients, tenant_id, aad_uri, kwargs, expected_cred_types, expected_exception", + [ + ( + None, + "test_tenant_id", + "test_aad_uri", + {}, + [ + "AzureCliCredential", + "ManagedIdentityCredential", + "InteractiveBrowserCredential", + ], + None, + ), + ( + ["env", "cli"], + "test_tenant_id", + "test_aad_uri", + {}, + ["AzureCliCredential"], + None, + ), + ( + ["unknown"], + "test_tenant_id", + "test_aad_uri", + {}, + [], + MsticpyAzureConfigError, + ), + ( + ["env-test", "cli", "invalid"], + "test_tenant_id", + "test_aad_uri", + {}, + ["AzureCliCredential"], + None, + ), + ], +) +@patch("msticpy.auth.azure_auth_core.EnvironmentCredential", autospec=True) +@patch("msticpy.auth.azure_auth_core.AzureCliCredential", autospec=True) +@patch("msticpy.auth.azure_auth_core.ManagedIdentityCredential", autospec=True) +@patch("msticpy.auth.azure_auth_core.InteractiveBrowserCredential", autospec=True) +def test_create_chained_credential( + mock_interactive_credential, + mock_msi_credential, + mock_cli_credential, + mock_env_credential, + requested_clients, + tenant_id, + aad_uri, + kwargs, + expected_cred_types, + expected_exception, +): + """ + Test _create_chained_credential function. + + Parameters + ---------- + mock_interactive_credential : MagicMock + Mocked InteractiveBrowserCredential class. + mock_msi_credential : MagicMock + Mocked ManagedIdentityCredential class. + mock_cli_credential : MagicMock + Mocked AzureCliCredential class. + mock_env_credential : MagicMock + Mocked EnvironmentCredential class. + mock_clients : dict + Mocked _CLIENTS dictionary. + requested_clients : list[str] + List of clients to chain. + tenant_id : str + The tenant ID to connect to. + aad_uri : str + The URI of the Azure AD cloud to connect to. + kwargs : dict + Additional keyword arguments. + expected_cred_types : list[str] + Expected credential types to be included in the chained credential. + expected_exception : Exception + Expected exception to be raised. + + Returns + ------- + None + """ + if expected_exception: + with pytest.raises(expected_exception): + _create_chained_credential( + aad_uri=aad_uri, + requested_clients=requested_clients, + tenant_id=tenant_id, + **kwargs + ) + else: + result = _create_chained_credential( + aad_uri=aad_uri, + requested_clients=requested_clients, + tenant_id=tenant_id, + **kwargs + ) + assert isinstance(result, ChainedTokenCredential) + cred_classes = {cred.__class__.__name__ for cred in result.credentials} + assert all(expected in cred_classes for expected in expected_cred_types) + + +@pytest.mark.parametrize( + "record_name, record_level, record_message, expected_output", + [ + ("azure.identity", logging.WARNING, "EnvironmentCredential.get_token", False), + ("azure.identity", logging.WARNING, "AzureCliCredential.get_token", False), + ( + "azure.identity", + logging.WARNING, + "ManagedIdentityCredential.get_token", + False, + ), + ("azure.identity", logging.WARNING, "SomeOtherCredential.get_token", False), + ("azure.identity", logging.INFO, "EnvironmentCredential.get_token", True), + ("some.other.logger", logging.WARNING, "EnvironmentCredential.get_token", True), + ], +) +def test_filter_credential_warning( + record_name, record_level, record_message, expected_output +): + """ + Test _filter_credential_warning function. + + Parameters + ---------- + record_name : str + The name of the log record. + record_level : int + The level of the log record. + record_message : str + The message of the log record. + expected_output : bool + The expected output of the function. + + Returns + ------- + None + """ + record = MagicMock() + record.name = record_name + record.levelno = record_level + record.getMessage.return_value = record_message + + result = _filter_credential_warning(record) + assert result == expected_output + + +@pytest.mark.parametrize( + "record_name, record_level, record_message, expected_output", + [ + ("azure.identity", logging.WARNING, "EnvironmentCredential.get_token", False), + ("azure.identity", logging.WARNING, "AzureCliCredential.get_token", False), + ( + "azure.identity", + logging.WARNING, + "ManagedIdentityCredential.get_token", + False, + ), + ("azure.identity", logging.WARNING, "SomeOtherCredential.get_token", False), + ("azure.identity", logging.WARNING, "Some other warning message", True), + ("azure.identity", logging.INFO, "EnvironmentCredential.get_token", True), + ("some.other.logger", logging.WARNING, "EnvironmentCredential.get_token", True), + ], +) +def test_filter_all_warnings( + record_name, record_level, record_message, expected_output +): + """ + Test _filter_all_warnings function. + + Parameters + ---------- + record_name : str + The name of the log record. + record_level : int + The level of the log record. + record_message : str + The message of the log record. + expected_output : bool + The expected output of the function. + + Returns + ------- + None + """ + record = MagicMock() + record.name = record_name + record.levelno = record_level + record.getMessage.return_value = record_message + + result = _filter_all_warnings(record) + assert result == expected_output From 5279fbdf2e3e0fb0022f68482efc99bb96528f31 Mon Sep 17 00:00:00 2001 From: ianhelle Date: Thu, 19 Sep 2024 18:08:29 -0700 Subject: [PATCH 2/9] also bugfix in lookup.py and test fail due to pkg dep --- .github/workflows/python-package.yml | 2 ++ msticpy/context/lookup.py | 20 ++++++++++---------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 441d14f9..455c88a2 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -54,6 +54,7 @@ jobs: fi python -m pip install -e . - name: Install test dependencies + # ToDo - remove pip install xgboost when flaml is fixed run: | if [ -f requirements-dev.txt ]; then python -m pip install -r requirements-dev.txt @@ -63,6 +64,7 @@ jobs: python -m pip install Pygments respx pytest-xdist markdown beautifulsoup4 Pillow async-cache lxml fi python -m pip install "pandas>=1.3.0" "pygeohash>=1.2.0" + python -m pip install "xgboost" - name: Prepare test dummy data run: | mkdir ~/.msticpy diff --git a/msticpy/context/lookup.py b/msticpy/context/lookup.py index ba21f226..33935b34 100644 --- a/msticpy/context/lookup.py +++ b/msticpy/context/lookup.py @@ -19,15 +19,7 @@ import logging import warnings from collections import ChainMap -from typing import ( - TYPE_CHECKING, - Any, - Callable, - ClassVar, - Iterable, - Mapping, - Sized, -) +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Mapping, Sized import nest_asyncio import pandas as pd @@ -59,6 +51,8 @@ logger: logging.Logger = logging.getLogger(__name__) +_HTTP_PROVIDER_LEGAL_KWARGS: list[str] = ["timeout", "ApiID", "AuthKey", "Instance"] + class ProgressCounter: """Progress counter for async tasks.""" @@ -811,7 +805,13 @@ def _load_providers( # instantiate class sending args from settings to init try: - provider_instance: Provider = provider_class(**(settings.args)) + # filter out any args that are not valid for the provider + provider_args = { + key: value + for key, value in settings.args.items() + if key in _HTTP_PROVIDER_LEGAL_KWARGS + } + provider_instance: Provider = provider_class(**(provider_args)) except MsticpyConfigError as mp_ex: # If the TI Provider didn't load, raise an exception err_msg: str = ( From fc2d8e893602506befe0651d0130fc06887bda0b Mon Sep 17 00:00:00 2001 From: Florian BRACQ Date: Fri, 20 Sep 2024 11:07:05 +0000 Subject: [PATCH 3/9] Explode kwargs arguments --- msticpy/auth/azure_auth.py | 23 +++++++++++++++-------- msticpy/auth/azure_auth_core.py | 19 ++++++++++++------- tests/auth/test_azure_auth.py | 5 ++++- tests/auth/test_azure_auth_core.py | 20 ++++++++++++++------ 4 files changed, 45 insertions(+), 22 deletions(-) diff --git a/msticpy/auth/azure_auth.py b/msticpy/auth/azure_auth.py index 3202f266..bc4bb709 100644 --- a/msticpy/auth/azure_auth.py +++ b/msticpy/auth/azure_auth.py @@ -4,8 +4,8 @@ # license information. # -------------------------------------------------------------------------- """Azure authentication handling.""" +from __future__ import annotations import os -from typing import List, Optional from azure.identity import DeviceCodeCredential from azure.mgmt.subscription import SubscriptionClient @@ -31,9 +31,11 @@ def az_connect( - auth_methods: Optional[List[str]] = None, - tenant_id: Optional[str] = None, + auth_methods: list[str] | None = None, + tenant_id: str | None = None, + *, silent: bool = False, + cloud: str | None = None, **kwargs, ) -> AzCredentials: """ @@ -76,7 +78,7 @@ def az_connect( list_auth_methods """ - az_cloud_config = AzureCloudConfig(cloud=kwargs.get("cloud")) + az_cloud_config = AzureCloudConfig(cloud=cloud) # Use auth_methods param or configuration defaults data_provs = get_provider_settings(config_section="DataProviders") auth_methods = auth_methods or az_cloud_config.auth_methods @@ -103,6 +105,7 @@ def az_connect( auth_methods=auth_methods, tenant_id=tenant_id, silent=silent, + cloud=cloud, **kwargs, ) sub_client = SubscriptionClient( @@ -120,7 +123,9 @@ def az_connect( def az_user_connect( - tenant_id: Optional[str] = None, silent: bool = False + tenant_id: str | None = None, + *, + silent: bool = False, ) -> AzCredentials: """ Authenticate to the SDK using user based authentication methods, Azure CLI or interactive logon. @@ -146,8 +151,10 @@ def az_user_connect( def fallback_devicecode_creds( - cloud: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs -): + cloud: str | None = None, + tenant_id: str | None = None, + region: str | None = None, +) -> AzCredentials: """ Authenticate using device code as a fallback method. @@ -173,7 +180,7 @@ def fallback_devicecode_creds( If chained token credential creation fails. """ - cloud = cloud or kwargs.pop("region", AzureCloudConfig().cloud) + cloud = cloud or region or AzureCloudConfig().cloud az_config = AzureCloudConfig(cloud) aad_uri = az_config.authority_uri tenant_id = tenant_id or az_config.tenant_id diff --git a/msticpy/auth/azure_auth_core.py b/msticpy/auth/azure_auth_core.py index 679ddb44..5c2f37e6 100644 --- a/msticpy/auth/azure_auth_core.py +++ b/msticpy/auth/azure_auth_core.py @@ -12,7 +12,7 @@ from dataclasses import dataclass from datetime import datetime from enum import Enum -from typing import Callable, ClassVar +from typing import Any, Callable, ClassVar from azure.common.credentials import get_cli_profile from azure.core.credentials import TokenCredential @@ -135,9 +135,13 @@ def _build_env_client( return None -def _build_cli_client(**kwargs) -> AzureCliCredential: +def _build_cli_client( + tenant_id: str | None = None, + **kwargs, +) -> AzureCliCredential: """Build a credential from Azure CLI.""" - if tenant_id := kwargs.pop("tenant_id", None): + del kwargs + if tenant_id: return AzureCliCredential(tenant_id=tenant_id) return AzureCliCredential() @@ -145,16 +149,17 @@ def _build_cli_client(**kwargs) -> AzureCliCredential: def _build_msi_client( tenant_id: str | None = None, aad_uri: str | None = None, + client_id: str | None = None, **kwargs, ) -> ManagedIdentityCredential: """Build a credential from Managed Identity.""" - msi_kwargs = kwargs.copy() - if "client_id" not in kwargs and AzureCredEnvNames.AZURE_CLIENT_ID in os.environ: - msi_kwargs["client_id"] = os.environ[AzureCredEnvNames.AZURE_CLIENT_ID] + msi_kwargs: dict[str, Any] = kwargs.copy() + client_id = client_id or os.environ.get(AzureCredEnvNames.AZURE_CLIENT_ID) return ManagedIdentityCredential( tenant_id=tenant_id, authority=aad_uri, + client_id=client_id, **msi_kwargs, ) @@ -214,10 +219,10 @@ def _build_client_secret_client( def _build_certificate_client( tenant_id: str | None = None, aad_uri: str | None = None, + client_id: str | None = None, **kwargs, ) -> CertificateCredential | None: """Build a credential from Certificate.""" - client_id = kwargs.pop("client_id", None) if not client_id: logger.info( "'certificate' credential requested but client_id param not supplied" diff --git a/tests/auth/test_azure_auth.py b/tests/auth/test_azure_auth.py index 182e70af..59c94a4b 100644 --- a/tests/auth/test_azure_auth.py +++ b/tests/auth/test_azure_auth.py @@ -48,7 +48,10 @@ def test_az_connect( assert result == mock_az_credentials mock_az_connect_core.assert_called_once_with( - auth_methods=["env"], tenant_id="test_tenant", silent=True + auth_methods=["env"], + tenant_id="test_tenant", + silent=True, + cloud=None, ) mock_sub_client.assert_called_once() diff --git a/tests/auth/test_azure_auth_core.py b/tests/auth/test_azure_auth_core.py index bd835689..7c929c79 100644 --- a/tests/auth/test_azure_auth_core.py +++ b/tests/auth/test_azure_auth_core.py @@ -296,28 +296,36 @@ def test_build_cli_client(mock_cli_credential): @pytest.mark.parametrize( - "env_vars, expected_kwargs, tenant_id, aad_uri", + "env_vars, expected_kwargs, tenant_id, aad_uri, client_id", [ ( {"AZURE_CLIENT_ID": "test_client_id"}, - {"client_id": "test_client_id"}, + {}, "test_tenant_id", "test_aad_uri", + "test_client_id", ), - ({}, {}, None, None), + ({}, {}, None, None, None), ], ) @patch.dict(os.environ, {}, clear=True) @patch("msticpy.auth.azure_auth_core.ManagedIdentityCredential", autospec=True) def test_build_msi_client( - mock_msi_credential, env_vars, expected_kwargs, tenant_id, aad_uri + mock_msi_credential, + env_vars, + expected_kwargs, + tenant_id, + aad_uri, + client_id, ): """Test _build_msi_client function.""" os.environ.update(env_vars) - result = _build_msi_client(tenant_id=tenant_id, aad_uri=aad_uri) + result = _build_msi_client( + tenant_id=tenant_id, aad_uri=aad_uri, client_id=client_id + ) # assert isinstance(result, mock_msi_credential) mock_msi_credential.assert_called_once_with( - tenant_id=tenant_id, authority=aad_uri, **expected_kwargs + tenant_id=tenant_id, authority=aad_uri, client_id=client_id, **expected_kwargs ) From d40ddec5bcc598dba12e9b6041953c5f49508a4a Mon Sep 17 00:00:00 2001 From: Florian BRACQ Date: Fri, 20 Sep 2024 11:08:05 +0000 Subject: [PATCH 4/9] Update typing and extrat error messages from exceptions --- msticpy/auth/azure_auth.py | 30 ++++++++++++++++++++---------- msticpy/auth/azure_auth_core.py | 10 +++++----- 2 files changed, 25 insertions(+), 15 deletions(-) diff --git a/msticpy/auth/azure_auth.py b/msticpy/auth/azure_auth.py index bc4bb709..0c59ec05 100644 --- a/msticpy/auth/azure_auth.py +++ b/msticpy/auth/azure_auth.py @@ -10,6 +10,8 @@ from azure.identity import DeviceCodeCredential from azure.mgmt.subscription import SubscriptionClient +from msticpy.common.provider_settings import ProviderSettings + from .._version import VERSION from ..common.exceptions import MsticpyAzureConnectionError @@ -80,12 +82,14 @@ def az_connect( """ az_cloud_config = AzureCloudConfig(cloud=cloud) # Use auth_methods param or configuration defaults - data_provs = get_provider_settings(config_section="DataProviders") + data_provs: dict[str, ProviderSettings] = get_provider_settings( + config_section="DataProviders" + ) auth_methods = auth_methods or az_cloud_config.auth_methods tenant_id = tenant_id or az_cloud_config.tenant_id # Ignore AzCLI settings except for authentication creds for EnvCred - az_cli_config = data_provs.get("AzureCLI") + az_cli_config: ProviderSettings | None = data_provs.get("AzureCLI") if ( az_cli_config and az_cli_config.args @@ -101,7 +105,7 @@ def az_connect( os.environ[AzureCredEnvNames.AZURE_CLIENT_SECRET] = ( az_cli_config.args.get("clientSecret") or "" ) - credentials = az_connect_core( + credentials: AzCredentials = az_connect_core( auth_methods=auth_methods, tenant_id=tenant_id, silent=silent, @@ -114,8 +118,9 @@ def az_connect( credential_scopes=[az_cloud_config.token_uri], ) if not sub_client: + err_msg: str = "Could not create an Azure Subscription client with credentials." raise MsticpyAzureConnectionError( - "Could not create an Azure Subscription client with credentials.", + err_msg, title="Azure authentication error", ) @@ -146,7 +151,9 @@ def az_user_connect( """ return az_connect_core( - auth_methods=["cli", "interactive"], tenant_id=tenant_id, silent=silent + auth_methods=["cli", "interactive"], + tenant_id=tenant_id, + silent=silent, ) @@ -181,15 +188,18 @@ def fallback_devicecode_creds( """ cloud = cloud or region or AzureCloudConfig().cloud - az_config = AzureCloudConfig(cloud) - aad_uri = az_config.authority_uri + az_config: AzureCloudConfig = AzureCloudConfig(cloud) + aad_uri: str = az_config.authority_uri tenant_id = tenant_id or az_config.tenant_id creds = DeviceCodeCredential(authority=aad_uri, tenant_id=tenant_id) legacy_creds = CredentialWrapper(creds, resource_id=az_config.token_uri) if not creds: + err_msg: str = ( + f"Could not obtain credentials for tenant {tenant_id}" + "Please check your Azure configuration and try again." + ) raise MsticpyAzureConnectionError( - f"Could not obtain credentials for tenant {tenant_id}", - "Please check your Azure configuration and try again.", + err_msg, title="Azure authentication error", ) @@ -198,5 +208,5 @@ def fallback_devicecode_creds( def get_default_resource_name(resource_uri: str) -> str: """Get a default resource name for a resource URI.""" - separator = "" if resource_uri.strip().endswith("/") else "/" + separator: str = "" if resource_uri.strip().endswith("/") else "/" return f"{resource_uri}{separator}.default" diff --git a/msticpy/auth/azure_auth_core.py b/msticpy/auth/azure_auth_core.py index 5c2f37e6..d685f943 100644 --- a/msticpy/auth/azure_auth_core.py +++ b/msticpy/auth/azure_auth_core.py @@ -242,7 +242,7 @@ def _build_powershell_client(**kwargs) -> AzurePowerShellCredential: return AzurePowerShellCredential() -_CLIENTS: dict[str, Callable] = dict( +_CLIENTS: dict[str, Callable[..., TokenCredential | None]] = dict( { "env": _build_env_client, "cli": _build_cli_client, @@ -411,15 +411,15 @@ def _create_chained_credential( if not requested_clients: requested_clients = ["env", "cli", "msi", "interactive"] logger.info("No auth methods requested defaulting to: %s", requested_clients) - cred_list = [] + cred_list: list[TokenCredential] = [] invalid_cred_types: list[str] = [] unusable_cred_type: list[str] = [] - for cred_type in requested_clients: # type: ignore[union-attr] + for cred_type in requested_clients: if cred_type not in _CLIENTS: invalid_cred_types.append(cred_type) logger.info("Unknown authentication type requested: %s", cred_type) continue - cred_client = _CLIENTS[cred_type]( + cred_client: TokenCredential | None = _CLIENTS[cred_type]( tenant_id=tenant_id, aad_uri=aad_uri, **kwargs, @@ -433,7 +433,7 @@ def _create_chained_credential( ", ".join(cred.__class__.__name__ for cred in cred_list if cred is not None), ) if not cred_list: - exception_args = [ + exception_args: list[str] = [ "Cannot authenticate - no valid credential types.", "At least one valid authentication method required.", f"Configured auth_types: {','.join(requested_clients)}", From 063d2acbd20f45b7aad64d9f5fce4844cbc5684a Mon Sep 17 00:00:00 2001 From: Florian BRACQ Date: Fri, 20 Sep 2024 11:51:24 +0000 Subject: [PATCH 5/9] replace Tuple with tuple --- tests/auth/test_azure_auth_core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/auth/test_azure_auth_core.py b/tests/auth/test_azure_auth_core.py index 7c929c79..c9dc363f 100644 --- a/tests/auth/test_azure_auth_core.py +++ b/tests/auth/test_azure_auth_core.py @@ -4,10 +4,10 @@ # license information. # -------------------------------------------------------------------------- """Module docstring.""" +from __future__ import annotations import logging import os from datetime import datetime, timedelta -from typing import Tuple from unittest.mock import MagicMock, patch import pytest @@ -134,7 +134,7 @@ def test_check_cli_credentials(get_cli_profile, test, expected): _CLI_ID = "d8d9d2f2-5d2d-4d7e-9c5c-5d6d9d1d8d9d" _TENANT_ID = "f8d9d2f2-5d2d-4d7e-9c5c-5d6d9d1d8d9e" -_TEST_ENV_VARS: list[Tuple[dict[str, str], bool]] = [ +_TEST_ENV_VARS: list[tuple[dict[str, str], bool]] = [ ( { "AZURE_CLIENT_ID": _CLI_ID, From 846272f8af42454ac9b3c1c5dbc0cedeccbca8f1 Mon Sep 17 00:00:00 2001 From: ianhelle Date: Fri, 20 Sep 2024 10:29:04 -0700 Subject: [PATCH 6/9] Adding backward compat for named tuples - in ip_utils and azure_auth_core.py Fixing some import mocking itemrs in docs/source/conf.py --- docs/source/conf.py | 8 +++++++- msticpy/auth/azure_auth_core.py | 13 +++++++++++-- msticpy/context/ip_utils.py | 13 +++++++++++-- 3 files changed, 29 insertions(+), 5 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 6c06a962..119faa09 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -228,6 +228,8 @@ "azure.keyvault.secrets", "azure.keyvault", "azure.kusto.data", + "azure.kusto.data.helpers", + "azure.kusto.data.response", "azure.mgmt.compute.models", "azure.mgmt.compute", "azure.mgmt.keyvault.models", @@ -235,8 +237,11 @@ "azure.mgmt.monitor", "azure.mgmt.network", "azure.mgmt.resource", + "azure.mgmt.resource.subscriptions", "azure.mgmt.resourcegraph", + "azure.mgmt.resourcegraph.models", "azure.mgmt.subscription", + "azure.mgmt.subscription.models", "azure.monitor.query", "azure.storage.blob", "azure.storage", @@ -250,11 +255,12 @@ "ipwhois", "IPython", "ipywidgets", + "jwt", "keyring", "Kqlmagic", "matplotlib.pyplot", "matplotlib", - "mo-sql-parsing", + "mo_sql_parsing", "msal", "msal_extensions", "msrest", diff --git a/msticpy/auth/azure_auth_core.py b/msticpy/auth/azure_auth_core.py index d685f943..34169b9f 100644 --- a/msticpy/auth/azure_auth_core.py +++ b/msticpy/auth/azure_auth_core.py @@ -9,10 +9,10 @@ import logging import os import sys -from dataclasses import dataclass +from dataclasses import asdict, dataclass from datetime import datetime from enum import Enum -from typing import Any, Callable, ClassVar +from typing import Any, Callable, ClassVar, Iterator from azure.common.credentials import get_cli_profile from azure.core.credentials import TokenCredential @@ -55,6 +55,15 @@ class AzCredentials: legacy: TokenCredential modern: ChainedTokenCredential + # Backward compatibility with namedtuple + def __iter__(self) -> Iterator[Any]: + """Iterate over properties.""" + return iter(asdict(self).values()) + + def __getitem__(self, item) -> Any: + """Get item from properties.""" + return list(asdict(self).values())[item] + # pylint: disable=too-few-public-methods class AzureCredEnvNames: diff --git a/msticpy/context/ip_utils.py b/msticpy/context/ip_utils.py index cc821d9b..0f8796a1 100644 --- a/msticpy/context/ip_utils.py +++ b/msticpy/context/ip_utils.py @@ -19,10 +19,10 @@ import re import socket import warnings -from dataclasses import dataclass, field +from dataclasses import asdict, dataclass, field from functools import lru_cache from time import sleep -from typing import Any, Callable +from typing import Any, Callable, Iterator import httpx import pandas as pd @@ -604,6 +604,15 @@ class _IpWhoIsResult: name: str | None = None properties: dict[str, Any] = field(default_factory=dict) + # Backward compatibility with namedtuple + def __iter__(self) -> Iterator[Any]: + """Iterate over properties.""" + return iter(asdict(self).values()) + + def __getitem__(self, item): + """Get item from properties.""" + return list(asdict(self).values())[item] + @lru_cache(maxsize=1024) def _whois_lookup( From 48a53f76bb34e48945ca780f3bca5c2a9e49a502 Mon Sep 17 00:00:00 2001 From: ianhelle Date: Fri, 20 Sep 2024 11:02:50 -0700 Subject: [PATCH 7/9] adding backward compat params to ti lookup_iocs check for no results before trying pd.concat in provider_base.py --- msticpy/context/provider_base.py | 2 ++ msticpy/context/tilookup.py | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/msticpy/context/provider_base.py b/msticpy/context/provider_base.py index 2ff1ef03..b7348849 100644 --- a/msticpy/context/provider_base.py +++ b/msticpy/context/provider_base.py @@ -211,6 +211,8 @@ def lookup_items( ) results.append(item_result) + if not results: + return pd.DataFrame() return pd.concat(results) async def lookup_items_async( # noqa:PLR0913 diff --git a/msticpy/context/tilookup.py b/msticpy/context/tilookup.py index 4a85374d..809e7b88 100644 --- a/msticpy/context/tilookup.py +++ b/msticpy/context/tilookup.py @@ -145,6 +145,8 @@ def lookup_iocs( # pylint: disable=too-many-arguments #noqa: PLR0913 *, start: dt.datetime | None = None, end: dt.datetime | None = None, + col: str | None = None, + column: str | None = None, ) -> pd.DataFrame: """ Lookup Threat Intelligence reports for a collection of IoCs in active providers. @@ -200,7 +202,7 @@ def lookup_iocs( # pylint: disable=too-many-arguments #noqa: PLR0913 return _make_sync( self._lookup_iocs_async( data=data, - ioc_col=ioc_col, + ioc_col=ioc_col or column or col, ioc_type_col=ioc_type_col, ioc_query_type=ioc_query_type, providers=providers, From e614ee939475d5b8e4873caeb6f62204ae081d87 Mon Sep 17 00:00:00 2001 From: ianhelle Date: Fri, 20 Sep 2024 11:05:17 -0700 Subject: [PATCH 8/9] Updating version to 2.13.2 --- msticpy/_version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/msticpy/_version.py b/msticpy/_version.py index fb0fc2d9..bf6f6b7d 100644 --- a/msticpy/_version.py +++ b/msticpy/_version.py @@ -1,3 +1,3 @@ """Version file.""" -VERSION = "2.13.1" +VERSION = "2.13.2" From 7b98465e0329cefecbf651f9f58fabcc834178b5 Mon Sep 17 00:00:00 2001 From: ianhelle Date: Fri, 20 Sep 2024 13:51:12 -0700 Subject: [PATCH 9/9] Suppressing pylint too-many-positional-arguments --- .pylintrc | 1 + 1 file changed, 1 insertion(+) diff --git a/.pylintrc b/.pylintrc index 6ac718a2..ffd2524c 100644 --- a/.pylintrc +++ b/.pylintrc @@ -68,6 +68,7 @@ disable=raw-checker-failed, useless-suppression, deprecated-pragma, use-symbolic-message-instead, + too-many-positional-arguments, # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option