Skip to content

Commit

Permalink
Added AAD support
Browse files Browse the repository at this point in the history
  • Loading branch information
jorge-beauregard committed Jan 11, 2021
1 parent be3b227 commit 56e1431
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
)
import six
from msrest.serialization import TZ_UTC
from .utils import create_access_token
from .utils import create_access_token, get_authentication_policy

class CommunicationUserCredential(object):
"""Credential type used for authenticating to an Azure Communication service.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,31 @@ def create_access_token(token):
return AccessToken(token, datetime.fromtimestamp(payload['exp']).replace(tzinfo=TZ_UTC))
except ValueError:
raise ValueError(token_parse_err_msg)

def get_authentication_policy(
endpoint, # type: str
credential # type: TokenCredential or str
):
# type: (...) -> BearerTokenCredentialPolicy or HMACCredentialPolicy
"""Returns the correct authentication policy based
on which credential is being passed.
:param endpoint: The endpoint to which we are authenticating to.
:type endpoint: str
:param credential: The credential we use to authenticate to the service
:type credential: TokenCredential or str
:rtype: ~azure.core.pipeline.policies.BearerTokenCredentialPolicy
~HMACCredentialsPolicy
"""

if credential is None:
raise ValueError("Parameter 'credential' must not be None.")
if hasattr(credential, "get_token"):
from azure.core.pipeline.policies import BearerTokenCredentialPolicy
return BearerTokenCredentialPolicy(
credential, "https://communication.azure.com//.default")
if isinstance(credential, str):
from .._shared.policy import HMACCredentialsPolicy
return HMACCredentialsPolicy(endpoint, credential)

raise TypeError("Unsupported credential: {}. Use an access token string to use HMACCredentialsPolicy"
"or a token credential from azure.identity".format(type(credential)))
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
from azure.communication.sms._generated.models import SendSmsResponse

from ._generated._azure_communication_sms_service import AzureCommunicationSMSService
from ._shared.policy import HMACCredentialsPolicy
from ._shared.utils import parse_connection_str
from ._shared.utils import parse_connection_str, get_authentication_policy
from ._version import SDK_MONIKER

class SmsClient(object):
Expand Down Expand Up @@ -41,8 +40,7 @@ def __init__(
"invalid credential from connection string.")

self._endpoint = endpoint
self._authentication_policy = HMACCredentialsPolicy(endpoint, credential)

self._authentication_policy = get_authentication_policy(endpoint, credential)
self._sms_service_client = AzureCommunicationSMSService(
self._endpoint,
authentication_policy=self._authentication_policy,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
from azure.communication.sms._generated.models import SendSmsResponse

from .._generated.aio._azure_communication_sms_service import AzureCommunicationSMSService
from .._shared.policy import HMACCredentialsPolicy
from .._shared.utils import parse_connection_str
from .._shared.utils import parse_connection_str, get_authentication_policy
from .._version import SDK_MONIKER

class SmsClient(object):
Expand Down Expand Up @@ -41,7 +40,7 @@ def __init__(
"invalid credential from connection string.")

self._endpoint = endpoint
self._authentication_policy = HMACCredentialsPolicy(endpoint, credential)
self._authentication_policy = get_authentication_policy(endpoint, credential)

self._sms_service_client = AzureCommunicationSMSService(
self._endpoint,
Expand Down

0 comments on commit 56e1431

Please sign in to comment.