Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Authentication module unit test #800

Merged
merged 9 commits into from
Sep 23, 2024
2 changes: 2 additions & 0 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ jobs:
fi
python -m pip install -e .
- name: Install test dependencies
# ToDo - remove pip install xgboost when flaml is fixed
run: |
if [ -f requirements-dev.txt ]; then
python -m pip install -r requirements-dev.txt
Expand All @@ -63,6 +64,7 @@ jobs:
python -m pip install Pygments respx pytest-xdist markdown beautifulsoup4 Pillow async-cache lxml
fi
python -m pip install "pandas>=1.3.0" "pygeohash>=1.2.0"
python -m pip install "xgboost"
- name: Prepare test dummy data
run: |
mkdir ~/.msticpy
Expand Down
1 change: 1 addition & 0 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ disable=raw-checker-failed,
useless-suppression,
deprecated-pragma,
use-symbolic-message-instead,
too-many-positional-arguments,

# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
Expand Down
8 changes: 7 additions & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,15 +228,20 @@
"azure.keyvault.secrets",
"azure.keyvault",
"azure.kusto.data",
"azure.kusto.data.helpers",
"azure.kusto.data.response",
"azure.mgmt.compute.models",
"azure.mgmt.compute",
"azure.mgmt.keyvault.models",
"azure.mgmt.keyvault",
"azure.mgmt.monitor",
"azure.mgmt.network",
"azure.mgmt.resource",
"azure.mgmt.resource.subscriptions",
"azure.mgmt.resourcegraph",
"azure.mgmt.resourcegraph.models",
"azure.mgmt.subscription",
"azure.mgmt.subscription.models",
"azure.monitor.query",
"azure.storage.blob",
"azure.storage",
Expand All @@ -250,11 +255,12 @@
"ipwhois",
"IPython",
"ipywidgets",
"jwt",
"keyring",
"Kqlmagic",
"matplotlib.pyplot",
"matplotlib",
"mo-sql-parsing",
"mo_sql_parsing",
"msal",
"msal_extensions",
"msrest",
Expand Down
2 changes: 1 addition & 1 deletion msticpy/_version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""Version file."""

VERSION = "2.13.1"
VERSION = "2.13.2"
75 changes: 50 additions & 25 deletions msticpy/auth/azure_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@
# license information.
# --------------------------------------------------------------------------
"""Azure authentication handling."""
from __future__ import annotations
import os
from typing import List, Optional

from azure.common.exceptions import CloudError
from azure.identity import DeviceCodeCredential
from azure.mgmt.subscription import SubscriptionClient

from msticpy.common.provider_settings import ProviderSettings

from .._version import VERSION
from ..common.exceptions import MsticpyAzureConnectionError

# pylint: enable=unused-import
from ..common.provider_settings import get_provider_settings
Expand All @@ -31,9 +33,11 @@


def az_connect(
auth_methods: Optional[List[str]] = None,
tenant_id: Optional[str] = None,
auth_methods: list[str] | None = None,
tenant_id: str | None = None,
*,
silent: bool = False,
cloud: str | None = None,
**kwargs,
) -> AzCredentials:
"""
Expand Down Expand Up @@ -68,22 +72,24 @@ def az_connect(

Raises
------
CloudError
MsticpyAzureConnectionError
If chained token credential creation fails.

See Also
--------
list_auth_methods

"""
az_cloud_config = AzureCloudConfig(cloud=kwargs.get("cloud"))
az_cloud_config = AzureCloudConfig(cloud=cloud)
# Use auth_methods param or configuration defaults
data_provs = get_provider_settings(config_section="DataProviders")
data_provs: dict[str, ProviderSettings] = get_provider_settings(
config_section="DataProviders"
)
auth_methods = auth_methods or az_cloud_config.auth_methods
tenant_id = tenant_id or az_cloud_config.tenant_id

# Ignore AzCLI settings except for authentication creds for EnvCred
az_cli_config = data_provs.get("AzureCLI")
az_cli_config: ProviderSettings | None = data_provs.get("AzureCLI")
if (
az_cli_config
and az_cli_config.args
Expand All @@ -99,10 +105,11 @@ def az_connect(
os.environ[AzureCredEnvNames.AZURE_CLIENT_SECRET] = (
az_cli_config.args.get("clientSecret") or ""
)
credentials = az_connect_core(
credentials: AzCredentials = az_connect_core(
auth_methods=auth_methods,
tenant_id=tenant_id,
silent=silent,
cloud=cloud,
**kwargs,
)
sub_client = SubscriptionClient(
Expand All @@ -111,13 +118,19 @@ def az_connect(
credential_scopes=[az_cloud_config.token_uri],
)
if not sub_client:
raise CloudError("Could not create a Subscription client.")
err_msg: str = "Could not create an Azure Subscription client with credentials."
raise MsticpyAzureConnectionError(
err_msg,
title="Azure authentication error",
)

return credentials


def az_user_connect(
tenant_id: Optional[str] = None, silent: bool = False
tenant_id: str | None = None,
*,
silent: bool = False,
) -> AzCredentials:
"""
Authenticate to the SDK using user based authentication methods, Azure CLI or interactive logon.
Expand All @@ -132,17 +145,23 @@ def az_user_connect(

Returns
-------
AzCredentials
AzCredentials - Dataclass combining two types of Azure credentials:
- legacy (ADAL) credentials
- modern (MSAL) credentials

"""
return az_connect_core(
auth_methods=["cli", "interactive"], tenant_id=tenant_id, silent=silent
auth_methods=["cli", "interactive"],
tenant_id=tenant_id,
silent=silent,
)


def fallback_devicecode_creds(
cloud: Optional[str] = None, tenant_id: Optional[str] = None, **kwargs
):
cloud: str | None = None,
tenant_id: str | None = None,
region: str | None = None,
) -> AzCredentials:
"""
Authenticate using device code as a fallback method.

Expand All @@ -158,30 +177,36 @@ def fallback_devicecode_creds(

Returns
-------
AzCredentials
Named tuple of:
- legacy (ADAL) credentials
- modern (MSAL) credentials
AzCredentials - Dataclass combining two types of Azure credentials:
- legacy (ADAL) credentials
- modern (MSAL) credentials

Raises
------
CloudError
MsticpyAzureConnectionError
If chained token credential creation fails.

"""
cloud = cloud or kwargs.pop("region", AzureCloudConfig().cloud)
az_config = AzureCloudConfig(cloud)
aad_uri = az_config.authority_uri
cloud = cloud or region or AzureCloudConfig().cloud
az_config: AzureCloudConfig = AzureCloudConfig(cloud)
aad_uri: str = az_config.authority_uri
tenant_id = tenant_id or az_config.tenant_id
creds = DeviceCodeCredential(authority=aad_uri, tenant_id=tenant_id)
legacy_creds = CredentialWrapper(creds, resource_id=az_config.token_uri)
if not creds:
raise CloudError("Could not obtain credentials.")
err_msg: str = (
f"Could not obtain credentials for tenant {tenant_id}"
"Please check your Azure configuration and try again."
)
raise MsticpyAzureConnectionError(
err_msg,
title="Azure authentication error",
)

return AzCredentials(legacy_creds, ChainedTokenCredential(creds)) # type: ignore[arg-type]


def get_default_resource_name(resource_uri: str) -> str:
"""Get a default resource name for a resource URI."""
separator = "" if resource_uri.strip().endswith("/") else "/"
separator: str = "" if resource_uri.strip().endswith("/") else "/"
return f"{resource_uri}{separator}.default"
39 changes: 27 additions & 12 deletions msticpy/auth/azure_auth_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
import logging
import os
import sys
from dataclasses import dataclass
from dataclasses import asdict, dataclass
from datetime import datetime
from enum import Enum
from typing import Callable, ClassVar
from typing import Any, Callable, ClassVar, Iterator

from azure.common.credentials import get_cli_profile
from azure.core.credentials import TokenCredential
Expand Down Expand Up @@ -55,6 +55,15 @@ class AzCredentials:
legacy: TokenCredential
modern: ChainedTokenCredential

# Backward compatibility with namedtuple
def __iter__(self) -> Iterator[Any]:
"""Iterate over properties."""
return iter(asdict(self).values())

def __getitem__(self, item) -> Any:
"""Get item from properties."""
return list(asdict(self).values())[item]


# pylint: disable=too-few-public-methods
class AzureCredEnvNames:
Expand Down Expand Up @@ -135,25 +144,31 @@ def _build_env_client(
return None


def _build_cli_client(**kwargs) -> AzureCliCredential:
def _build_cli_client(
tenant_id: str | None = None,
**kwargs,
) -> AzureCliCredential:
"""Build a credential from Azure CLI."""
del kwargs
if tenant_id:
return AzureCliCredential(tenant_id=tenant_id)
return AzureCliCredential()


def _build_msi_client(
tenant_id: str | None = None,
aad_uri: str | None = None,
client_id: str | None = None,
**kwargs,
) -> ManagedIdentityCredential:
"""Build a credential from Managed Identity."""
msi_kwargs = kwargs.copy()
if AzureCredEnvNames.AZURE_CLIENT_ID in os.environ:
msi_kwargs["client_id"] = os.environ[AzureCredEnvNames.AZURE_CLIENT_ID]
msi_kwargs: dict[str, Any] = kwargs.copy()
client_id = client_id or os.environ.get(AzureCredEnvNames.AZURE_CLIENT_ID)

return ManagedIdentityCredential(
tenant_id=tenant_id,
authority=aad_uri,
client_id=client_id,
**msi_kwargs,
)

Expand Down Expand Up @@ -213,10 +228,10 @@ def _build_client_secret_client(
def _build_certificate_client(
tenant_id: str | None = None,
aad_uri: str | None = None,
client_id: str | None = None,
**kwargs,
) -> CertificateCredential | None:
"""Build a credential from Certificate."""
client_id = kwargs.pop("client_id", None)
if not client_id:
logger.info(
"'certificate' credential requested but client_id param not supplied"
Expand All @@ -236,7 +251,7 @@ def _build_powershell_client(**kwargs) -> AzurePowerShellCredential:
return AzurePowerShellCredential()


_CLIENTS: dict[str, Callable] = dict(
_CLIENTS: dict[str, Callable[..., TokenCredential | None]] = dict(
{
"env": _build_env_client,
"cli": _build_cli_client,
Expand Down Expand Up @@ -405,15 +420,15 @@ def _create_chained_credential(
if not requested_clients:
requested_clients = ["env", "cli", "msi", "interactive"]
logger.info("No auth methods requested defaulting to: %s", requested_clients)
cred_list = []
cred_list: list[TokenCredential] = []
invalid_cred_types: list[str] = []
unusable_cred_type: list[str] = []
for cred_type in requested_clients: # type: ignore[union-attr]
for cred_type in requested_clients:
if cred_type not in _CLIENTS:
invalid_cred_types.append(cred_type)
logger.info("Unknown authentication type requested: %s", cred_type)
continue
cred_client = _CLIENTS[cred_type](
cred_client: TokenCredential | None = _CLIENTS[cred_type](
tenant_id=tenant_id,
aad_uri=aad_uri,
**kwargs,
Expand All @@ -427,7 +442,7 @@ def _create_chained_credential(
", ".join(cred.__class__.__name__ for cred in cred_list if cred is not None),
)
if not cred_list:
exception_args = [
exception_args: list[str] = [
"Cannot authenticate - no valid credential types.",
"At least one valid authentication method required.",
f"Configured auth_types: {','.join(requested_clients)}",
Expand Down
13 changes: 11 additions & 2 deletions msticpy/context/ip_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
import re
import socket
import warnings
from dataclasses import dataclass, field
from dataclasses import asdict, dataclass, field
from functools import lru_cache
from time import sleep
from typing import Any, Callable
from typing import Any, Callable, Iterator

import httpx
import pandas as pd
Expand Down Expand Up @@ -604,6 +604,15 @@ class _IpWhoIsResult:
name: str | None = None
properties: dict[str, Any] = field(default_factory=dict)

# Backward compatibility with namedtuple
def __iter__(self) -> Iterator[Any]:
"""Iterate over properties."""
return iter(asdict(self).values())

def __getitem__(self, item):
"""Get item from properties."""
return list(asdict(self).values())[item]


@lru_cache(maxsize=1024)
def _whois_lookup(
Expand Down
Loading