diff --git a/sdk/keyvault/azure-keyvault-certificates/CHANGELOG.md b/sdk/keyvault/azure-keyvault-certificates/CHANGELOG.md index 9c628f23b4019..fda279009b7ee 100644 --- a/sdk/keyvault/azure-keyvault-certificates/CHANGELOG.md +++ b/sdk/keyvault/azure-keyvault-certificates/CHANGELOG.md @@ -1,7 +1,9 @@ # Release History ## 4.0.2 (Unreleased) - +- `CertificateClient` instances have a `close` method which closes opened +sockets. Used as a context manager, a `CertificateClient` closes opened sockets +on exit. ([#9906](https://github.com/Azure/azure-sdk-for-python/pull/9906)) ## 4.0.1 (2020-02-11) - `azure.keyvault.certificates` defines `__version__` diff --git a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_client.py b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_client.py index 8133cff899908..d47cfab24b058 100644 --- a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_client.py +++ b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_client.py @@ -97,7 +97,7 @@ def begin_create_certificate(self, certificate_name, policy, **kwargs): tags = kwargs.pop("tags", None) if enabled is not None: - attributes = self._client.models.CertificateAttributes(enabled=enabled) + attributes = self._models.CertificateAttributes(enabled=enabled) else: attributes = None @@ -359,7 +359,7 @@ def import_certificate(self, certificate_name, certificate_bytes, **kwargs): policy = kwargs.pop("policy", None) if enabled is not None: - attributes = self._client.models.CertificateAttributes(enabled=enabled) + attributes = self._models.CertificateAttributes(enabled=enabled) else: attributes = None base64_encoded_certificate = base64.b64encode(certificate_bytes).decode("utf-8") @@ -442,7 +442,7 @@ def update_certificate_properties(self, certificate_name, version=None, **kwargs enabled = kwargs.pop("enabled", None) if enabled is not None: - attributes = self._client.models.CertificateAttributes(enabled=enabled) + attributes = self._models.CertificateAttributes(enabled=enabled) else: attributes = None @@ -760,7 +760,7 @@ def merge_certificate(self, certificate_name, x509_certificates, **kwargs): enabled = kwargs.pop("enabled", None) if enabled is not None: - attributes = self._client.models.CertificateAttributes(enabled=enabled) + attributes = self._models.CertificateAttributes(enabled=enabled) else: attributes = None bundle = self._client.merge_certificate( @@ -832,12 +832,12 @@ def create_issuer(self, issuer_name, provider, **kwargs): admin_contacts = kwargs.pop("admin_contacts", None) if account_id or password: - issuer_credentials = self._client.models.IssuerCredentials(account_id=account_id, password=password) + issuer_credentials = self._models.IssuerCredentials(account_id=account_id, password=password) else: issuer_credentials = None if admin_contacts: admin_details = [ - self._client.models.AdministratorDetails( + self._models.AdministratorDetails( first_name=contact.first_name, last_name=contact.last_name, email_address=contact.email, @@ -848,13 +848,13 @@ def create_issuer(self, issuer_name, provider, **kwargs): else: admin_details = None if organization_id or admin_details: - organization_details = self._client.models.OrganizationDetails( + organization_details = self._models.OrganizationDetails( id=organization_id, admin_details=admin_details ) else: organization_details = None if enabled is not None: - issuer_attributes = self._client.models.IssuerAttributes(enabled=enabled) + issuer_attributes = self._models.IssuerAttributes(enabled=enabled) else: issuer_attributes = None issuer_bundle = self._client.set_certificate_issuer( @@ -895,12 +895,12 @@ def update_issuer(self, issuer_name, **kwargs): admin_contacts = kwargs.pop("admin_contacts", None) if account_id or password: - issuer_credentials = self._client.models.IssuerCredentials(account_id=account_id, password=password) + issuer_credentials = self._models.IssuerCredentials(account_id=account_id, password=password) else: issuer_credentials = None if admin_contacts: admin_details = [ - self._client.models.AdministratorDetails( + self._models.AdministratorDetails( first_name=contact.first_name, last_name=contact.last_name, email_address=contact.email, @@ -911,13 +911,13 @@ def update_issuer(self, issuer_name, **kwargs): else: admin_details = None if organization_id or admin_details: - organization_details = self._client.models.OrganizationDetails( + organization_details = self._models.OrganizationDetails( id=organization_id, admin_details=admin_details ) else: organization_details = None if enabled is not None: - issuer_attributes = self._client.models.IssuerAttributes(enabled=enabled) + issuer_attributes = self._models.IssuerAttributes(enabled=enabled) else: issuer_attributes = None issuer_bundle = self._client.update_certificate_issuer( diff --git a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/_generated/__init__.py b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/_generated/__init__.py index efc5f67755a24..b74cfa3b899cc 100644 --- a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/_generated/__init__.py +++ b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/_generated/__init__.py @@ -2,6 +2,3 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -from .key_vault_client import KeyVaultClient - -__all__ = ["KeyVaultClient"] diff --git a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/_generated/key_vault_client.py b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/_generated/key_vault_client.py deleted file mode 100644 index fb29f4364313b..0000000000000 --- a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/_generated/key_vault_client.py +++ /dev/null @@ -1,199 +0,0 @@ -# ------------------------------------ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -# ------------------------------------ -from azure.profiles import KnownProfiles, ProfileDefinition -from azure.profiles.multiapiclient import MultiApiClientMixin - -from .v7_0.version import VERSION as V7_0_VERSION -from .v2016_10_01.version import VERSION as V2016_10_01_VERSION - - -class KeyVaultClient(MultiApiClientMixin): - """The key vault client performs cryptographic key operations and vault operations against the Key Vault service. - Implementation depends on the API version: - - * 2016-10-01: :class:`v2016_10_01.KeyVaultClient` - * 7.0: :class:`v7_0.KeyVaultClient` - - :param credentials: Credentials needed for the client to connect to Azure. - :type credentials: :mod:`A msrestazure Credentials - object` - - :param str api_version: API version to use if no profile is provided, or if - missing in profile. - :param profile: A profile definition, from KnownProfiles to dict. - :type profile: azure.profiles.KnownProfiles - """ - - DEFAULT_API_VERSION = V7_0_VERSION - _PROFILE_TAG = "azure.keyvault.KeyVaultClient" - LATEST_PROFILE = ProfileDefinition({_PROFILE_TAG: {None: DEFAULT_API_VERSION}}, _PROFILE_TAG + " latest") - - _init_complete = False - - def __init__(self, credentials, pipeline=None, api_version=None, aio=False, profile=KnownProfiles.default): - self._client_impls = {} - self._pipeline = pipeline - self._entered = False - self._aio = aio - super(KeyVaultClient, self).__init__(api_version=api_version, profile=profile) - - self._credentials = credentials - self._init_complete = True - - @staticmethod - def get_configuration_class(api_version, aio=False): - """ - Get the versioned configuration implementation corresponding to the current profile. - :return: The versioned configuration implementation. - """ - if api_version == V7_0_VERSION: - if aio: - from .v7_0.aio._configuration_async import KeyVaultClientConfiguration as ImplConfig - else: - from .v7_0._configuration import KeyVaultClientConfiguration as ImplConfig - elif api_version == V2016_10_01_VERSION: - if aio: - from .v2016_10_01.aio._configuration_async import KeyVaultClientConfiguration as ImplConfig - else: - from .v2016_10_01._configuration import KeyVaultClientConfiguration as ImplConfig - else: - raise NotImplementedError("API version {} is not available".format(api_version)) - return ImplConfig - - @property - def models(self): - """Module depends on the API version: - * 2016-10-01: :mod:`v2016_10_01.models` - * 7.0: :mod:`v7_0.models` - """ - api_version = self._get_api_version(None) - - if api_version == V7_0_VERSION: - from .v7_0 import models as impl_models - elif api_version == V2016_10_01_VERSION: - from .v2016_10_01 import models as impl_models - else: - raise NotImplementedError("APIVersion {} is not available".format(api_version)) - return impl_models - - def _get_client_impl(self): - """ - Get the versioned client implementation corresponding to the current profile. - :return: The versioned client implementation. - """ - api_version = self._get_api_version(None) - if api_version not in self._client_impls: - self._create_client_impl(api_version) - return self._client_impls[api_version] - - def _create_client_impl(self, api_version): - """ - Creates the client implementation corresponding to the specified api_version. - :param api_version: - :return: - """ - if api_version == V7_0_VERSION: - if self._aio: - from .v7_0.aio import KeyVaultClient as ImplClient - else: - from .v7_0 import KeyVaultClient as ImplClient - elif api_version == V2016_10_01_VERSION: - if self._aio: - from .v2016_10_01.aio import KeyVaultClient as ImplClient - else: - from .v2016_10_01 import KeyVaultClient as ImplClient - else: - raise NotImplementedError("API version {} is not available".format(api_version)) - - impl = ImplClient(credentials=self._credentials, pipeline=self._pipeline) - - # if __enter__ has previously been called and the impl client has __enter__ defined we need to call it - if self._entered: - if hasattr(impl, "__enter__"): - impl.__enter__() - elif hasattr(impl, "__aenter__"): - impl.__aenter__() - - self._client_impls[api_version] = impl - return impl - - def __aenter__(self, *args, **kwargs): - """ - Calls __aenter__ on all client implementations which support it - :param args: positional arguments to relay to client implementations of __aenter__ - :param kwargs: keyword arguments to relay to client implementations of __aenter__ - :return: returns the current KeyVaultClient instance - """ - for _, impl in self._client_impls.items(): - if hasattr(impl, "__aenter__"): - impl.__aenter__(*args, **kwargs) - - # mark the current KeyVaultClient as _entered so that client implementations instantiated - # subsequently will also have __aenter__ called on them as appropriate - self._entered = True - return self - - def __enter__(self, *args, **kwargs): - """ - Calls __enter__ on all client implementations which support it - :param args: positional arguments to relay to client implementations of __enter__ - :param kwargs: keyword arguments to relay to client implementations of __enter__ - :return: returns the current KeyVaultClient instance - """ - for _, impl in self._client_impls.items(): - if hasattr(impl, "__enter__"): - impl.__enter__(*args, **kwargs) - - # mark the current KeyVaultClient as _entered so that client implementations instantiated - # subsequently will also have __enter__ called on them as appropriate - self._entered = True - return self - - def __aexit__(self, *args, **kwargs): - """ - Calls __aexit__ on all client implementations which support it - :param args: positional arguments to relay to client implementations of __aexit__ - :param kwargs: keyword arguments to relay to client implementations of __aexit__ - :return: returns the current KeyVaultClient instance - """ - for _, impl in self._client_impls.items(): - if hasattr(impl, "__aexit__"): - impl.__aexit__(*args, **kwargs) - return self - - def __exit__(self, *args, **kwargs): - """ - Calls __exit__ on all client implementations which support it - :param args: positional arguments to relay to client implementations of __enter__ - :param kwargs: keyword arguments to relay to client implementations of __enter__ - :return: returns the current KeyVaultClient instance - """ - for _, impl in self._client_impls.items(): - if hasattr(impl, "__exit__"): - impl.__exit__(*args, **kwargs) - return self - - def __getattr__(self, name): - """ - In the case that the attribute is not defined on the custom KeyVaultClient. Attempt to get - the attribute from the versioned client implementation corresponding to the current profile. - :param name: Name of the attribute retrieve from the current versioned client implementation - :return: The value of the specified attribute on the current client implementation. - """ - impl = self._get_client_impl() - return getattr(impl, name) - - def __setattr__(self, name, attr): - """ - Sets the specified attribute either on the custom KeyVaultClient or the current underlying implementation. - :param name: Name of the attribute to set - :param attr: Value of the attribute to set - :return: None - """ - if self._init_complete and not hasattr(self, name): - impl = self._get_client_impl() - setattr(impl, name, attr) - else: - super(KeyVaultClient, self).__setattr__(name, attr) diff --git a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_client_base.py b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_client_base.py index 6f4a1e563077f..84c827a3a938e 100644 --- a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_client_base.py +++ b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/async_client_base.py @@ -2,108 +2,80 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -from typing import Any, TYPE_CHECKING +from typing import TYPE_CHECKING -from azure.core.configuration import Configuration from azure.core.pipeline import AsyncPipeline -from azure.core.pipeline.policies import( - ContentDecodePolicy, UserAgentPolicy, DistributedTracingPolicy, HttpLoggingPolicy -) -from azure.core.pipeline.transport import AsyncHttpTransport -from ._generated import KeyVaultClient from . import AsyncChallengeAuthPolicy -from .._user_agent import USER_AGENT - +from .client_base import _get_policies +from .multi_api import load_generated_api if TYPE_CHECKING: try: # pylint:disable=unused-import - from azure.core.credentials import TokenCredential + from typing import Any + from azure.core.configuration import Configuration + from azure.core.pipeline.transport import AsyncHttpTransport + from azure.core.credentials_async import AsyncTokenCredential except ImportError: - # TokenCredential is a typing_extensions.Protocol; we don't depend on that package + # AsyncTokenCredential is a typing_extensions.Protocol; we don't depend on that package pass -class AsyncKeyVaultClientBase: - """Base class for async Key Vault clients""" - - @staticmethod - def _create_config(credential: "TokenCredential", api_version: str = None, **kwargs: "Any") -> Configuration: - if api_version is None: - api_version = KeyVaultClient.DEFAULT_API_VERSION - config = KeyVaultClient.get_configuration_class(api_version, aio=True)(credential, **kwargs) - config.authentication_policy = AsyncChallengeAuthPolicy(credential) - - # replace the autorest-generated UserAgentPolicy and its hard-coded user agent - # https://github.com/Azure/azure-sdk-for-python/issues/6637 - config.user_agent_policy = UserAgentPolicy(base_user_agent=USER_AGENT, **kwargs) - - # Override config policies if found in kwargs - # TODO: should be unnecessary after next regeneration (written 2019-08-02) - if "user_agent_policy" in kwargs: - config.user_agent_policy = kwargs["user_agent_policy"] - if "headers_policy" in kwargs: - config.headers_policy = kwargs["headers_policy"] - if "proxy_policy" in kwargs: - config.proxy_policy = kwargs["proxy_policy"] - if "logging_policy" in kwargs: - config.logging_policy = kwargs["logging_policy"] - if "retry_policy" in kwargs: - config.retry_policy = kwargs["retry_policy"] - if "custom_hook_policy" in kwargs: - config.custom_hook_policy = kwargs["custom_hook_policy"] - if "redirect_policy" in kwargs: - config.redirect_policy = kwargs["redirect_policy"] - - return config - - def __init__(self, vault_url: str, credential: "TokenCredential", **kwargs: "Any") -> None: +def _build_pipeline(config: "Configuration", transport: "AsyncHttpTransport" = None, **kwargs: "Any") -> AsyncPipeline: + policies = _get_policies(config, **kwargs) + if transport is None: + from azure.core.pipeline.transport import AioHttpTransport + + transport = AioHttpTransport(**kwargs) + + return AsyncPipeline(transport, policies=policies) + + +class AsyncKeyVaultClientBase(object): + def __init__(self, vault_url: str, credential: "AsyncTokenCredential", **kwargs: "Any") -> None: if not credential: raise ValueError( - "credential should be an object supporting the TokenCredential protocol, " + "credential should be an object supporting the AsyncTokenCredential protocol, " "such as a credential from azure-identity" ) if not vault_url: raise ValueError("vault_url must be the URL of an Azure Key Vault") self._vault_url = vault_url.strip(" /") - client = kwargs.get("generated_client") if client: # caller provided a configured client -> nothing left to initialize self._client = client return - config = self._create_config(credential, **kwargs) - transport = kwargs.pop("transport", None) - pipeline = kwargs.pop("pipeline", None) or self._build_pipeline(config, transport=transport, **kwargs) - self._client = KeyVaultClient(credential, pipeline=pipeline, aio=True) - - @staticmethod - def _build_pipeline(config: Configuration, transport: AsyncHttpTransport, **kwargs: "Any") -> AsyncPipeline: - logging_policy = HttpLoggingPolicy(**kwargs) - logging_policy.allowed_header_names.add("x-ms-keyvault-network-info") - policies = [ - config.headers_policy, - config.user_agent_policy, - config.proxy_policy, - ContentDecodePolicy(), - config.redirect_policy, - config.retry_policy, - config.authentication_policy, - config.logging_policy, - DistributedTracingPolicy(**kwargs), - logging_policy, - ] - - if transport is None: - from azure.core.pipeline.transport import AioHttpTransport - - transport = AioHttpTransport(**kwargs) - - return AsyncPipeline(transport, policies=policies) + api_version = kwargs.pop("api_version", None) + generated = load_generated_api(api_version, aio=True) + + pipeline = kwargs.pop("pipeline", None) + if not pipeline: + config = generated.config_cls(credential, **kwargs) + config.authentication_policy = AsyncChallengeAuthPolicy(credential) + pipeline = _build_pipeline(config, **kwargs) + + # generated clients don't use their credentials parameter + self._client = generated.client_cls(credentials="", pipeline=pipeline) + self._models = generated.models @property def vault_url(self) -> str: return self._vault_url + + async def __aenter__(self) -> "AsyncKeyVaultClientBase": + await self._client.__aenter__() + return self + + async def __aexit__(self, *args: "Any") -> None: + await self._client.__aexit__(*args) + + async def close(self) -> None: + """Close sockets opened by the client. + + Calling this method is unnecessary when using the client as a context manager. + """ + await self._client.__aexit__() diff --git a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/client_base.py b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/client_base.py index 1e1b7692d379a..b1e1a2e997d2e 100644 --- a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/client_base.py +++ b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/client_base.py @@ -5,58 +5,54 @@ from typing import TYPE_CHECKING from azure.core.pipeline import Pipeline -from azure.core.pipeline.policies import( - ContentDecodePolicy, UserAgentPolicy, DistributedTracingPolicy, HttpLoggingPolicy +from azure.core.pipeline.policies import ( + ContentDecodePolicy, + UserAgentPolicy, + DistributedTracingPolicy, + HttpLoggingPolicy, ) from azure.core.pipeline.transport import RequestsTransport -from ._generated import KeyVaultClient + +from .multi_api import load_generated_api from .challenge_auth_policy import ChallengeAuthPolicy from .._user_agent import USER_AGENT if TYPE_CHECKING: # pylint:disable=unused-import - from typing import Any, Optional + from typing import Any from azure.core.credentials import TokenCredential from azure.core.pipeline.transport import HttpTransport from azure.core.configuration import Configuration -KEY_VAULT_SCOPE = "https://vault.azure.net/.default" +def _get_policies(config, **kwargs): + logging_policy = HttpLoggingPolicy(**kwargs) + logging_policy.allowed_header_names.add("x-ms-keyvault-network-info") -class KeyVaultClientBase(object): - """Base class for Key Vault clients""" - - @staticmethod - def _create_config(credential, api_version=None, **kwargs): - # type: (TokenCredential, Optional[str], **Any) -> Configuration - if api_version is None: - api_version = KeyVaultClient.DEFAULT_API_VERSION - config = KeyVaultClient.get_configuration_class(api_version, aio=False)(credential, **kwargs) - config.authentication_policy = ChallengeAuthPolicy(credential) - - # replace the autorest-generated UserAgentPolicy and its hard-coded user agent - # https://github.com/Azure/azure-sdk-for-python/issues/6637 - config.user_agent_policy = UserAgentPolicy(base_user_agent=USER_AGENT, **kwargs) - - # Override config policies if found in kwargs - # TODO: should be unnecessary after next regeneration (written 2019-08-02) - if "user_agent_policy" in kwargs: - config.user_agent_policy = kwargs["user_agent_policy"] - if "headers_policy" in kwargs: - config.headers_policy = kwargs["headers_policy"] - if "proxy_policy" in kwargs: - config.proxy_policy = kwargs["proxy_policy"] - if "logging_policy" in kwargs: - config.logging_policy = kwargs["logging_policy"] - if "retry_policy" in kwargs: - config.retry_policy = kwargs["retry_policy"] - if "custom_hook_policy" in kwargs: - config.custom_hook_policy = kwargs["custom_hook_policy"] - if "redirect_policy" in kwargs: - config.redirect_policy = kwargs["redirect_policy"] - - return config + return [ + config.headers_policy, + UserAgentPolicy(base_user_agent=USER_AGENT, **kwargs), + config.proxy_policy, + ContentDecodePolicy(), + config.redirect_policy, + config.retry_policy, + config.authentication_policy, + config.logging_policy, + DistributedTracingPolicy(**kwargs), + logging_policy, + ] + + +def _build_pipeline(config, transport=None, **kwargs): + # type: (Configuration, HttpTransport, **Any) -> Pipeline + policies = _get_policies(config) + if transport is None: + transport = RequestsTransport(**kwargs) + return Pipeline(transport, policies=policies) + + +class KeyVaultClientBase(object): def __init__(self, vault_url, credential, **kwargs): # type: (str, TokenCredential, **Any) -> None if not credential: @@ -68,42 +64,43 @@ def __init__(self, vault_url, credential, **kwargs): raise ValueError("vault_url must be the URL of an Azure Key Vault") self._vault_url = vault_url.strip(" /") - client = kwargs.get("generated_client") if client: # caller provided a configured client -> nothing left to initialize self._client = client return - config = self._create_config(credential, **kwargs) - transport = kwargs.pop("transport", None) - pipeline = kwargs.pop("pipeline", None) or self._build_pipeline(config, transport=transport, **kwargs) - self._client = KeyVaultClient(credential, pipeline=pipeline, aio=False) - - # pylint:disable=no-self-use - def _build_pipeline(self, config, transport, **kwargs): - # type: (Configuration, HttpTransport, **Any) -> Pipeline - logging_policy = HttpLoggingPolicy(**kwargs) - logging_policy.allowed_header_names.add("x-ms-keyvault-network-info") - policies = [ - config.headers_policy, - config.user_agent_policy, - config.proxy_policy, - ContentDecodePolicy(), - config.redirect_policy, - config.retry_policy, - config.authentication_policy, - config.logging_policy, - DistributedTracingPolicy(**kwargs), - logging_policy, - ] - - if transport is None: - transport = RequestsTransport(**kwargs) - - return Pipeline(transport, policies=policies) + api_version = kwargs.pop("api_version", None) + generated = load_generated_api(api_version) + + pipeline = kwargs.pop("pipeline", None) + if not pipeline: + config = generated.config_cls(credential, **kwargs) + config.authentication_policy = ChallengeAuthPolicy(credential) + pipeline = _build_pipeline(config, **kwargs) + + # generated clients don't use their credentials parameter + self._client = generated.client_cls(credentials="", pipeline=pipeline) + self._models = generated.models @property def vault_url(self): # type: () -> str return self._vault_url + + def __enter__(self): + # type: () -> KeyVaultClientBase + self._client.__enter__() + return self + + def __exit__(self, *args): + # type: (*Any) -> None + self._client.__exit__(*args) + + def close(self): + # type: () -> None + """Close sockets opened by the client. + + Calling this method is unnecessary when using the client as a context manager. + """ + self._client.__exit__() diff --git a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/multi_api.py b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/multi_api.py new file mode 100644 index 0000000000000..8c8b343047fe8 --- /dev/null +++ b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/_shared/multi_api.py @@ -0,0 +1,43 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +from collections import namedtuple + +from ._generated.v7_0.version import VERSION as V7_0_VERSION +from ._generated.v2016_10_01.version import VERSION as V2016_10_01_VERSION + +SUPPORTED_VERSIONS = (V7_0_VERSION, V2016_10_01_VERSION) +DEFAULT_VERSION = V7_0_VERSION + +GeneratedApi = namedtuple("GeneratedApi", ("models", "client_cls", "config_cls")) + + +def load_generated_api(api_version, aio=False): + # type: (str, bool) -> GeneratedApi + api_version = api_version or DEFAULT_VERSION + if api_version == V7_0_VERSION: + from ._generated.v7_0 import models + + if aio: + from ._generated.v7_0.aio import KeyVaultClient + from ._generated.v7_0.aio._configuration_async import KeyVaultClientConfiguration + else: + from ._generated.v7_0 import KeyVaultClient # type: ignore + from ._generated.v7_0._configuration import KeyVaultClientConfiguration # type: ignore + elif api_version == V2016_10_01_VERSION: + from ._generated.v2016_10_01 import models # type: ignore + + if aio: + from ._generated.v2016_10_01.aio import KeyVaultClient # type: ignore + from ._generated.v2016_10_01.aio._configuration_async import KeyVaultClientConfiguration # type: ignore + else: + from ._generated.v2016_10_01 import KeyVaultClient # type: ignore + from ._generated.v2016_10_01._configuration import KeyVaultClientConfiguration # type: ignore + else: + raise NotImplementedError( + "This package doesn't support API version '{}'. ".format(api_version) + + "Supported versions: {}".format(", ".join(SUPPORTED_VERSIONS)) + ) + + return GeneratedApi(models=models, client_cls=KeyVaultClient, config_cls=KeyVaultClientConfiguration) diff --git a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/aio/_client.py b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/aio/_client.py index f928f091ca387..eb671cbe8cf08 100644 --- a/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/aio/_client.py +++ b/sdk/keyvault/azure-keyvault-certificates/azure/keyvault/certificates/aio/_client.py @@ -87,7 +87,7 @@ async def create_certificate( tags = kwargs.pop("tags", None) if enabled is not None: - attributes = self._client.models.CertificateAttributes(enabled=enabled) + attributes = self._models.CertificateAttributes(enabled=enabled) else: attributes = None cert_bundle = await self._client.create_certificate( @@ -327,7 +327,7 @@ async def import_certificate( policy = kwargs.pop("policy", None) if enabled is not None: - attributes = self._client.models.CertificateAttributes(enabled=enabled) + attributes = self._models.CertificateAttributes(enabled=enabled) else: attributes = None base64_encoded_certificate = base64.b64encode(certificate_bytes).decode("utf-8") @@ -410,7 +410,7 @@ async def update_certificate_properties( enabled = kwargs.pop("enabled", None) if enabled is not None: - attributes = self._client.models.CertificateAttributes(enabled=enabled) + attributes = self._models.CertificateAttributes(enabled=enabled) else: attributes = None @@ -731,7 +731,7 @@ async def merge_certificate( enabled = kwargs.pop("enabled", None) if enabled is not None: - attributes = self._client.models.CertificateAttributes(enabled=enabled) + attributes = self._models.CertificateAttributes(enabled=enabled) else: attributes = None bundle = await self._client.merge_certificate( @@ -801,12 +801,12 @@ async def create_issuer(self, issuer_name: str, provider: str, **kwargs: "Any") admin_contacts = kwargs.pop("admin_contacts", None) if account_id or password: - issuer_credentials = self._client.models.IssuerCredentials(account_id=account_id, password=password) + issuer_credentials = self._models.IssuerCredentials(account_id=account_id, password=password) else: issuer_credentials = None if admin_contacts: admin_details = [ - self._client.models.AdministratorDetails( + self._models.AdministratorDetails( first_name=contact.first_name, last_name=contact.last_name, email_address=contact.email, @@ -817,13 +817,13 @@ async def create_issuer(self, issuer_name: str, provider: str, **kwargs: "Any") else: admin_details = None if organization_id or admin_details: - organization_details = self._client.models.OrganizationDetails( + organization_details = self._models.OrganizationDetails( id=organization_id, admin_details=admin_details ) else: organization_details = None if enabled is not None: - issuer_attributes = self._client.models.IssuerAttributes(enabled=enabled) + issuer_attributes = self._models.IssuerAttributes(enabled=enabled) else: issuer_attributes = None issuer_bundle = await self._client.set_certificate_issuer( @@ -864,12 +864,12 @@ async def update_issuer(self, issuer_name: str, **kwargs: "Any") -> CertificateI admin_contacts = kwargs.pop("admin_contacts", None) if account_id or password: - issuer_credentials = self._client.models.IssuerCredentials(account_id=account_id, password=password) + issuer_credentials = self._models.IssuerCredentials(account_id=account_id, password=password) else: issuer_credentials = None if admin_contacts: admin_details = list( - self._client.models.AdministratorDetails( + self._models.AdministratorDetails( first_name=contact.first_name, last_name=contact.last_name, email_address=contact.email, @@ -880,13 +880,13 @@ async def update_issuer(self, issuer_name: str, **kwargs: "Any") -> CertificateI else: admin_details = None if organization_id or admin_details: - organization_details = self._client.models.OrganizationDetails( + organization_details = self._models.OrganizationDetails( id=organization_id, admin_details=admin_details ) else: organization_details = None if enabled is not None: - issuer_attributes = self._client.models.IssuerAttributes(enabled=enabled) + issuer_attributes = self._models.IssuerAttributes(enabled=enabled) else: issuer_attributes = None issuer_bundle = await self._client.update_certificate_issuer( diff --git a/sdk/keyvault/azure-keyvault-certificates/samples/backup_restore_operations_async.py b/sdk/keyvault/azure-keyvault-certificates/samples/backup_restore_operations_async.py index 5b6beca79033b..6a21737b4ed92 100644 --- a/sdk/keyvault/azure-keyvault-certificates/samples/backup_restore_operations_async.py +++ b/sdk/keyvault/azure-keyvault-certificates/samples/backup_restore_operations_async.py @@ -74,6 +74,8 @@ async def run_sample(): finally: print("\nrun_sample done") + await credential.close() + await client.close() if __name__ == "__main__": diff --git a/sdk/keyvault/azure-keyvault-certificates/samples/contacts_async.py b/sdk/keyvault/azure-keyvault-certificates/samples/contacts_async.py index abe4a6b7fcbc7..804a138af9bb7 100644 --- a/sdk/keyvault/azure-keyvault-certificates/samples/contacts_async.py +++ b/sdk/keyvault/azure-keyvault-certificates/samples/contacts_async.py @@ -61,6 +61,8 @@ async def run_sample(): finally: print("\nrun_sample done") + await credential.close() + await client.close() if __name__ == "__main__": diff --git a/sdk/keyvault/azure-keyvault-certificates/samples/hello_world_async.py b/sdk/keyvault/azure-keyvault-certificates/samples/hello_world_async.py index c2239975ebfbc..753aadb90baab 100644 --- a/sdk/keyvault/azure-keyvault-certificates/samples/hello_world_async.py +++ b/sdk/keyvault/azure-keyvault-certificates/samples/hello_world_async.py @@ -103,6 +103,8 @@ async def run_sample(): finally: print("\nrun_sample done") + await credential.close() + await client.close() if __name__ == "__main__": diff --git a/sdk/keyvault/azure-keyvault-certificates/samples/issuers_async.py b/sdk/keyvault/azure-keyvault-certificates/samples/issuers_async.py index dc228c2b30d2e..21dbe1f92b399 100644 --- a/sdk/keyvault/azure-keyvault-certificates/samples/issuers_async.py +++ b/sdk/keyvault/azure-keyvault-certificates/samples/issuers_async.py @@ -95,6 +95,8 @@ async def run_sample(): finally: print("\nrun_sample done") + await credential.close() + await client.close() if __name__ == "__main__": diff --git a/sdk/keyvault/azure-keyvault-certificates/samples/list_operations_async.py b/sdk/keyvault/azure-keyvault-certificates/samples/list_operations_async.py index 1b5eebacf8f9c..75932b9c96050 100644 --- a/sdk/keyvault/azure-keyvault-certificates/samples/list_operations_async.py +++ b/sdk/keyvault/azure-keyvault-certificates/samples/list_operations_async.py @@ -110,6 +110,8 @@ async def run_sample(): finally: print("\nrun_sample done") + await credential.close() + await client.close() if __name__ == "__main__": diff --git a/sdk/keyvault/azure-keyvault-certificates/samples/recover_purge_operations_async.py b/sdk/keyvault/azure-keyvault-certificates/samples/recover_purge_operations_async.py index 82fe3d5ceb130..f61c5dab62bc3 100644 --- a/sdk/keyvault/azure-keyvault-certificates/samples/recover_purge_operations_async.py +++ b/sdk/keyvault/azure-keyvault-certificates/samples/recover_purge_operations_async.py @@ -92,6 +92,8 @@ async def run_sample(): finally: print("\nrun_sample done") + await credential.close() + await client.close() if __name__ == "__main__": diff --git a/sdk/keyvault/azure-keyvault-certificates/setup.py b/sdk/keyvault/azure-keyvault-certificates/setup.py index 7ad88f820b477..c0bba02f8bdc1 100644 --- a/sdk/keyvault/azure-keyvault-certificates/setup.py +++ b/sdk/keyvault/azure-keyvault-certificates/setup.py @@ -79,7 +79,7 @@ "azure.keyvault", ] ), - install_requires=["azure-core<2.0.0,>=1.2.1", "azure-common~=1.1", "msrest>=0.6.0"], + install_requires=["azure-core<2.0.0,>=1.2.1", "msrest>=0.6.0"], extras_require={ ":python_version<'3.0'": ["azure-keyvault-nspkg"], ":python_version<'3.4'": ["enum34>=1.0.4"], diff --git a/sdk/keyvault/azure-keyvault-certificates/tests/_shared/test_case_async.py b/sdk/keyvault/azure-keyvault-certificates/tests/_shared/test_case_async.py index 1522264fdce09..4462e8aa86675 100644 --- a/sdk/keyvault/azure-keyvault-certificates/tests/_shared/test_case_async.py +++ b/sdk/keyvault/azure-keyvault-certificates/tests/_shared/test_case_async.py @@ -35,7 +35,11 @@ def await_prepared_test(test_fn): @functools.wraps(test_fn) def run(test_class_instance, *args, **kwargs): loop = asyncio.get_event_loop() - return loop.run_until_complete(test_fn(test_class_instance, *args, **kwargs)) + client = kwargs.get("client") + result = loop.run_until_complete(test_fn(test_class_instance, *args, **kwargs)) + if client: + loop.run_until_complete(client.close()) + return result return run diff --git a/sdk/keyvault/azure-keyvault-certificates/tests/test_context_manager.py b/sdk/keyvault/azure-keyvault-certificates/tests/test_context_manager.py new file mode 100644 index 0000000000000..fbfdbc6afee4d --- /dev/null +++ b/sdk/keyvault/azure-keyvault-certificates/tests/test_context_manager.py @@ -0,0 +1,25 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +from azure.keyvault.certificates import CertificateClient + +from _shared.helpers import mock + + +def test_close(): + transport = mock.MagicMock() + client = CertificateClient(vault_url="https://localhost", credential=object(), transport=transport) + client.close() + assert transport.__enter__.call_count == 0 + assert transport.__exit__.call_count == 1 + + +def test_context_manager(): + transport = mock.MagicMock() + client = CertificateClient(vault_url="https://localhost", credential=object(), transport=transport) + + with client: + assert transport.__enter__.call_count == 1 + assert transport.__enter__.call_count == 1 + assert transport.__exit__.call_count == 1 diff --git a/sdk/keyvault/azure-keyvault-certificates/tests/test_context_manager_async.py b/sdk/keyvault/azure-keyvault-certificates/tests/test_context_manager_async.py new file mode 100644 index 0000000000000..8e4cfa7dd5e9a --- /dev/null +++ b/sdk/keyvault/azure-keyvault-certificates/tests/test_context_manager_async.py @@ -0,0 +1,29 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +from azure.keyvault.certificates.aio import CertificateClient +import pytest + +from _shared.helpers_async import AsyncMockTransport + + +@pytest.mark.asyncio +async def test_close(): + transport = AsyncMockTransport() + client = CertificateClient(vault_url="https://localhost", credential=object(), transport=transport) + + await client.close() + assert transport.__aenter__.call_count == 0 + assert transport.__aexit__.call_count == 1 + + +@pytest.mark.asyncio +async def test_context_manager(): + transport = AsyncMockTransport() + client = CertificateClient(vault_url="https://localhost", credential=object(), transport=transport) + + async with client: + assert transport.__aenter__.call_count == 1 + assert transport.__aenter__.call_count == 1 + assert transport.__aexit__.call_count == 1 diff --git a/sdk/keyvault/azure-keyvault-keys/CHANGELOG.md b/sdk/keyvault/azure-keyvault-keys/CHANGELOG.md index 2d28514ceba7b..1901bba054afa 100644 --- a/sdk/keyvault/azure-keyvault-keys/CHANGELOG.md +++ b/sdk/keyvault/azure-keyvault-keys/CHANGELOG.md @@ -1,7 +1,9 @@ # Release History ## 4.0.2 (Unreleased) - +- `KeyClient` instances have a `close` method which closes opened sockets. Used +as a context manager, a `KeyClient` closes opened sockets on exit. +([#9906](https://github.com/Azure/azure-sdk-for-python/pull/9906)) ## 4.0.1 (2020-02-11) - `azure.keyvault.keys` defines `__version__` diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_client.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_client.py index f94725b1b0b52..317cc609b659d 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_client.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_client.py @@ -82,7 +82,7 @@ def create_key(self, name, key_type, **kwargs): not_before = kwargs.pop("not_before", None) expires_on = kwargs.pop("expires_on", None) if enabled is not None or not_before is not None or expires_on is not None: - attributes = self._client.models.KeyAttributes(enabled=enabled, not_before=not_before, expires=expires_on) + attributes = self._models.KeyAttributes(enabled=enabled, not_before=not_before, expires=expires_on) else: attributes = None @@ -438,7 +438,7 @@ def update_key_properties(self, name, version=None, **kwargs): not_before = kwargs.pop("not_before", None) expires_on = kwargs.pop("expires_on", None) if enabled is not None or not_before is not None or expires_on is not None: - attributes = self._client.models.KeyAttributes(enabled=enabled, not_before=not_before, expires=expires_on) + attributes = self._models.KeyAttributes(enabled=enabled, not_before=not_before, expires=expires_on) else: attributes = None bundle = self._client.update_key( @@ -529,7 +529,7 @@ def import_key(self, name, key, **kwargs): not_before = kwargs.pop("not_before", None) expires_on = kwargs.pop("expires_on", None) if enabled is not None or not_before is not None or expires_on is not None: - attributes = self._client.models.KeyAttributes(enabled=enabled, not_before=not_before, expires=expires_on) + attributes = self._models.KeyAttributes(enabled=enabled, not_before=not_before, expires=expires_on) else: attributes = None bundle = self._client.import_key( diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/_generated/__init__.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/_generated/__init__.py index efc5f67755a24..b74cfa3b899cc 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/_generated/__init__.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/_generated/__init__.py @@ -2,6 +2,3 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -from .key_vault_client import KeyVaultClient - -__all__ = ["KeyVaultClient"] diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/_generated/key_vault_client.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/_generated/key_vault_client.py deleted file mode 100644 index fb29f4364313b..0000000000000 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/_generated/key_vault_client.py +++ /dev/null @@ -1,199 +0,0 @@ -# ------------------------------------ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -# ------------------------------------ -from azure.profiles import KnownProfiles, ProfileDefinition -from azure.profiles.multiapiclient import MultiApiClientMixin - -from .v7_0.version import VERSION as V7_0_VERSION -from .v2016_10_01.version import VERSION as V2016_10_01_VERSION - - -class KeyVaultClient(MultiApiClientMixin): - """The key vault client performs cryptographic key operations and vault operations against the Key Vault service. - Implementation depends on the API version: - - * 2016-10-01: :class:`v2016_10_01.KeyVaultClient` - * 7.0: :class:`v7_0.KeyVaultClient` - - :param credentials: Credentials needed for the client to connect to Azure. - :type credentials: :mod:`A msrestazure Credentials - object` - - :param str api_version: API version to use if no profile is provided, or if - missing in profile. - :param profile: A profile definition, from KnownProfiles to dict. - :type profile: azure.profiles.KnownProfiles - """ - - DEFAULT_API_VERSION = V7_0_VERSION - _PROFILE_TAG = "azure.keyvault.KeyVaultClient" - LATEST_PROFILE = ProfileDefinition({_PROFILE_TAG: {None: DEFAULT_API_VERSION}}, _PROFILE_TAG + " latest") - - _init_complete = False - - def __init__(self, credentials, pipeline=None, api_version=None, aio=False, profile=KnownProfiles.default): - self._client_impls = {} - self._pipeline = pipeline - self._entered = False - self._aio = aio - super(KeyVaultClient, self).__init__(api_version=api_version, profile=profile) - - self._credentials = credentials - self._init_complete = True - - @staticmethod - def get_configuration_class(api_version, aio=False): - """ - Get the versioned configuration implementation corresponding to the current profile. - :return: The versioned configuration implementation. - """ - if api_version == V7_0_VERSION: - if aio: - from .v7_0.aio._configuration_async import KeyVaultClientConfiguration as ImplConfig - else: - from .v7_0._configuration import KeyVaultClientConfiguration as ImplConfig - elif api_version == V2016_10_01_VERSION: - if aio: - from .v2016_10_01.aio._configuration_async import KeyVaultClientConfiguration as ImplConfig - else: - from .v2016_10_01._configuration import KeyVaultClientConfiguration as ImplConfig - else: - raise NotImplementedError("API version {} is not available".format(api_version)) - return ImplConfig - - @property - def models(self): - """Module depends on the API version: - * 2016-10-01: :mod:`v2016_10_01.models` - * 7.0: :mod:`v7_0.models` - """ - api_version = self._get_api_version(None) - - if api_version == V7_0_VERSION: - from .v7_0 import models as impl_models - elif api_version == V2016_10_01_VERSION: - from .v2016_10_01 import models as impl_models - else: - raise NotImplementedError("APIVersion {} is not available".format(api_version)) - return impl_models - - def _get_client_impl(self): - """ - Get the versioned client implementation corresponding to the current profile. - :return: The versioned client implementation. - """ - api_version = self._get_api_version(None) - if api_version not in self._client_impls: - self._create_client_impl(api_version) - return self._client_impls[api_version] - - def _create_client_impl(self, api_version): - """ - Creates the client implementation corresponding to the specified api_version. - :param api_version: - :return: - """ - if api_version == V7_0_VERSION: - if self._aio: - from .v7_0.aio import KeyVaultClient as ImplClient - else: - from .v7_0 import KeyVaultClient as ImplClient - elif api_version == V2016_10_01_VERSION: - if self._aio: - from .v2016_10_01.aio import KeyVaultClient as ImplClient - else: - from .v2016_10_01 import KeyVaultClient as ImplClient - else: - raise NotImplementedError("API version {} is not available".format(api_version)) - - impl = ImplClient(credentials=self._credentials, pipeline=self._pipeline) - - # if __enter__ has previously been called and the impl client has __enter__ defined we need to call it - if self._entered: - if hasattr(impl, "__enter__"): - impl.__enter__() - elif hasattr(impl, "__aenter__"): - impl.__aenter__() - - self._client_impls[api_version] = impl - return impl - - def __aenter__(self, *args, **kwargs): - """ - Calls __aenter__ on all client implementations which support it - :param args: positional arguments to relay to client implementations of __aenter__ - :param kwargs: keyword arguments to relay to client implementations of __aenter__ - :return: returns the current KeyVaultClient instance - """ - for _, impl in self._client_impls.items(): - if hasattr(impl, "__aenter__"): - impl.__aenter__(*args, **kwargs) - - # mark the current KeyVaultClient as _entered so that client implementations instantiated - # subsequently will also have __aenter__ called on them as appropriate - self._entered = True - return self - - def __enter__(self, *args, **kwargs): - """ - Calls __enter__ on all client implementations which support it - :param args: positional arguments to relay to client implementations of __enter__ - :param kwargs: keyword arguments to relay to client implementations of __enter__ - :return: returns the current KeyVaultClient instance - """ - for _, impl in self._client_impls.items(): - if hasattr(impl, "__enter__"): - impl.__enter__(*args, **kwargs) - - # mark the current KeyVaultClient as _entered so that client implementations instantiated - # subsequently will also have __enter__ called on them as appropriate - self._entered = True - return self - - def __aexit__(self, *args, **kwargs): - """ - Calls __aexit__ on all client implementations which support it - :param args: positional arguments to relay to client implementations of __aexit__ - :param kwargs: keyword arguments to relay to client implementations of __aexit__ - :return: returns the current KeyVaultClient instance - """ - for _, impl in self._client_impls.items(): - if hasattr(impl, "__aexit__"): - impl.__aexit__(*args, **kwargs) - return self - - def __exit__(self, *args, **kwargs): - """ - Calls __exit__ on all client implementations which support it - :param args: positional arguments to relay to client implementations of __enter__ - :param kwargs: keyword arguments to relay to client implementations of __enter__ - :return: returns the current KeyVaultClient instance - """ - for _, impl in self._client_impls.items(): - if hasattr(impl, "__exit__"): - impl.__exit__(*args, **kwargs) - return self - - def __getattr__(self, name): - """ - In the case that the attribute is not defined on the custom KeyVaultClient. Attempt to get - the attribute from the versioned client implementation corresponding to the current profile. - :param name: Name of the attribute retrieve from the current versioned client implementation - :return: The value of the specified attribute on the current client implementation. - """ - impl = self._get_client_impl() - return getattr(impl, name) - - def __setattr__(self, name, attr): - """ - Sets the specified attribute either on the custom KeyVaultClient or the current underlying implementation. - :param name: Name of the attribute to set - :param attr: Value of the attribute to set - :return: None - """ - if self._init_complete and not hasattr(self, name): - impl = self._get_client_impl() - setattr(impl, name, attr) - else: - super(KeyVaultClient, self).__setattr__(name, attr) diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_client_base.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_client_base.py index 6f4a1e563077f..84c827a3a938e 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_client_base.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/async_client_base.py @@ -2,108 +2,80 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -from typing import Any, TYPE_CHECKING +from typing import TYPE_CHECKING -from azure.core.configuration import Configuration from azure.core.pipeline import AsyncPipeline -from azure.core.pipeline.policies import( - ContentDecodePolicy, UserAgentPolicy, DistributedTracingPolicy, HttpLoggingPolicy -) -from azure.core.pipeline.transport import AsyncHttpTransport -from ._generated import KeyVaultClient from . import AsyncChallengeAuthPolicy -from .._user_agent import USER_AGENT - +from .client_base import _get_policies +from .multi_api import load_generated_api if TYPE_CHECKING: try: # pylint:disable=unused-import - from azure.core.credentials import TokenCredential + from typing import Any + from azure.core.configuration import Configuration + from azure.core.pipeline.transport import AsyncHttpTransport + from azure.core.credentials_async import AsyncTokenCredential except ImportError: - # TokenCredential is a typing_extensions.Protocol; we don't depend on that package + # AsyncTokenCredential is a typing_extensions.Protocol; we don't depend on that package pass -class AsyncKeyVaultClientBase: - """Base class for async Key Vault clients""" - - @staticmethod - def _create_config(credential: "TokenCredential", api_version: str = None, **kwargs: "Any") -> Configuration: - if api_version is None: - api_version = KeyVaultClient.DEFAULT_API_VERSION - config = KeyVaultClient.get_configuration_class(api_version, aio=True)(credential, **kwargs) - config.authentication_policy = AsyncChallengeAuthPolicy(credential) - - # replace the autorest-generated UserAgentPolicy and its hard-coded user agent - # https://github.com/Azure/azure-sdk-for-python/issues/6637 - config.user_agent_policy = UserAgentPolicy(base_user_agent=USER_AGENT, **kwargs) - - # Override config policies if found in kwargs - # TODO: should be unnecessary after next regeneration (written 2019-08-02) - if "user_agent_policy" in kwargs: - config.user_agent_policy = kwargs["user_agent_policy"] - if "headers_policy" in kwargs: - config.headers_policy = kwargs["headers_policy"] - if "proxy_policy" in kwargs: - config.proxy_policy = kwargs["proxy_policy"] - if "logging_policy" in kwargs: - config.logging_policy = kwargs["logging_policy"] - if "retry_policy" in kwargs: - config.retry_policy = kwargs["retry_policy"] - if "custom_hook_policy" in kwargs: - config.custom_hook_policy = kwargs["custom_hook_policy"] - if "redirect_policy" in kwargs: - config.redirect_policy = kwargs["redirect_policy"] - - return config - - def __init__(self, vault_url: str, credential: "TokenCredential", **kwargs: "Any") -> None: +def _build_pipeline(config: "Configuration", transport: "AsyncHttpTransport" = None, **kwargs: "Any") -> AsyncPipeline: + policies = _get_policies(config, **kwargs) + if transport is None: + from azure.core.pipeline.transport import AioHttpTransport + + transport = AioHttpTransport(**kwargs) + + return AsyncPipeline(transport, policies=policies) + + +class AsyncKeyVaultClientBase(object): + def __init__(self, vault_url: str, credential: "AsyncTokenCredential", **kwargs: "Any") -> None: if not credential: raise ValueError( - "credential should be an object supporting the TokenCredential protocol, " + "credential should be an object supporting the AsyncTokenCredential protocol, " "such as a credential from azure-identity" ) if not vault_url: raise ValueError("vault_url must be the URL of an Azure Key Vault") self._vault_url = vault_url.strip(" /") - client = kwargs.get("generated_client") if client: # caller provided a configured client -> nothing left to initialize self._client = client return - config = self._create_config(credential, **kwargs) - transport = kwargs.pop("transport", None) - pipeline = kwargs.pop("pipeline", None) or self._build_pipeline(config, transport=transport, **kwargs) - self._client = KeyVaultClient(credential, pipeline=pipeline, aio=True) - - @staticmethod - def _build_pipeline(config: Configuration, transport: AsyncHttpTransport, **kwargs: "Any") -> AsyncPipeline: - logging_policy = HttpLoggingPolicy(**kwargs) - logging_policy.allowed_header_names.add("x-ms-keyvault-network-info") - policies = [ - config.headers_policy, - config.user_agent_policy, - config.proxy_policy, - ContentDecodePolicy(), - config.redirect_policy, - config.retry_policy, - config.authentication_policy, - config.logging_policy, - DistributedTracingPolicy(**kwargs), - logging_policy, - ] - - if transport is None: - from azure.core.pipeline.transport import AioHttpTransport - - transport = AioHttpTransport(**kwargs) - - return AsyncPipeline(transport, policies=policies) + api_version = kwargs.pop("api_version", None) + generated = load_generated_api(api_version, aio=True) + + pipeline = kwargs.pop("pipeline", None) + if not pipeline: + config = generated.config_cls(credential, **kwargs) + config.authentication_policy = AsyncChallengeAuthPolicy(credential) + pipeline = _build_pipeline(config, **kwargs) + + # generated clients don't use their credentials parameter + self._client = generated.client_cls(credentials="", pipeline=pipeline) + self._models = generated.models @property def vault_url(self) -> str: return self._vault_url + + async def __aenter__(self) -> "AsyncKeyVaultClientBase": + await self._client.__aenter__() + return self + + async def __aexit__(self, *args: "Any") -> None: + await self._client.__aexit__(*args) + + async def close(self) -> None: + """Close sockets opened by the client. + + Calling this method is unnecessary when using the client as a context manager. + """ + await self._client.__aexit__() diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/client_base.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/client_base.py index 1e1b7692d379a..b1e1a2e997d2e 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/client_base.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/client_base.py @@ -5,58 +5,54 @@ from typing import TYPE_CHECKING from azure.core.pipeline import Pipeline -from azure.core.pipeline.policies import( - ContentDecodePolicy, UserAgentPolicy, DistributedTracingPolicy, HttpLoggingPolicy +from azure.core.pipeline.policies import ( + ContentDecodePolicy, + UserAgentPolicy, + DistributedTracingPolicy, + HttpLoggingPolicy, ) from azure.core.pipeline.transport import RequestsTransport -from ._generated import KeyVaultClient + +from .multi_api import load_generated_api from .challenge_auth_policy import ChallengeAuthPolicy from .._user_agent import USER_AGENT if TYPE_CHECKING: # pylint:disable=unused-import - from typing import Any, Optional + from typing import Any from azure.core.credentials import TokenCredential from azure.core.pipeline.transport import HttpTransport from azure.core.configuration import Configuration -KEY_VAULT_SCOPE = "https://vault.azure.net/.default" +def _get_policies(config, **kwargs): + logging_policy = HttpLoggingPolicy(**kwargs) + logging_policy.allowed_header_names.add("x-ms-keyvault-network-info") -class KeyVaultClientBase(object): - """Base class for Key Vault clients""" - - @staticmethod - def _create_config(credential, api_version=None, **kwargs): - # type: (TokenCredential, Optional[str], **Any) -> Configuration - if api_version is None: - api_version = KeyVaultClient.DEFAULT_API_VERSION - config = KeyVaultClient.get_configuration_class(api_version, aio=False)(credential, **kwargs) - config.authentication_policy = ChallengeAuthPolicy(credential) - - # replace the autorest-generated UserAgentPolicy and its hard-coded user agent - # https://github.com/Azure/azure-sdk-for-python/issues/6637 - config.user_agent_policy = UserAgentPolicy(base_user_agent=USER_AGENT, **kwargs) - - # Override config policies if found in kwargs - # TODO: should be unnecessary after next regeneration (written 2019-08-02) - if "user_agent_policy" in kwargs: - config.user_agent_policy = kwargs["user_agent_policy"] - if "headers_policy" in kwargs: - config.headers_policy = kwargs["headers_policy"] - if "proxy_policy" in kwargs: - config.proxy_policy = kwargs["proxy_policy"] - if "logging_policy" in kwargs: - config.logging_policy = kwargs["logging_policy"] - if "retry_policy" in kwargs: - config.retry_policy = kwargs["retry_policy"] - if "custom_hook_policy" in kwargs: - config.custom_hook_policy = kwargs["custom_hook_policy"] - if "redirect_policy" in kwargs: - config.redirect_policy = kwargs["redirect_policy"] - - return config + return [ + config.headers_policy, + UserAgentPolicy(base_user_agent=USER_AGENT, **kwargs), + config.proxy_policy, + ContentDecodePolicy(), + config.redirect_policy, + config.retry_policy, + config.authentication_policy, + config.logging_policy, + DistributedTracingPolicy(**kwargs), + logging_policy, + ] + + +def _build_pipeline(config, transport=None, **kwargs): + # type: (Configuration, HttpTransport, **Any) -> Pipeline + policies = _get_policies(config) + if transport is None: + transport = RequestsTransport(**kwargs) + return Pipeline(transport, policies=policies) + + +class KeyVaultClientBase(object): def __init__(self, vault_url, credential, **kwargs): # type: (str, TokenCredential, **Any) -> None if not credential: @@ -68,42 +64,43 @@ def __init__(self, vault_url, credential, **kwargs): raise ValueError("vault_url must be the URL of an Azure Key Vault") self._vault_url = vault_url.strip(" /") - client = kwargs.get("generated_client") if client: # caller provided a configured client -> nothing left to initialize self._client = client return - config = self._create_config(credential, **kwargs) - transport = kwargs.pop("transport", None) - pipeline = kwargs.pop("pipeline", None) or self._build_pipeline(config, transport=transport, **kwargs) - self._client = KeyVaultClient(credential, pipeline=pipeline, aio=False) - - # pylint:disable=no-self-use - def _build_pipeline(self, config, transport, **kwargs): - # type: (Configuration, HttpTransport, **Any) -> Pipeline - logging_policy = HttpLoggingPolicy(**kwargs) - logging_policy.allowed_header_names.add("x-ms-keyvault-network-info") - policies = [ - config.headers_policy, - config.user_agent_policy, - config.proxy_policy, - ContentDecodePolicy(), - config.redirect_policy, - config.retry_policy, - config.authentication_policy, - config.logging_policy, - DistributedTracingPolicy(**kwargs), - logging_policy, - ] - - if transport is None: - transport = RequestsTransport(**kwargs) - - return Pipeline(transport, policies=policies) + api_version = kwargs.pop("api_version", None) + generated = load_generated_api(api_version) + + pipeline = kwargs.pop("pipeline", None) + if not pipeline: + config = generated.config_cls(credential, **kwargs) + config.authentication_policy = ChallengeAuthPolicy(credential) + pipeline = _build_pipeline(config, **kwargs) + + # generated clients don't use their credentials parameter + self._client = generated.client_cls(credentials="", pipeline=pipeline) + self._models = generated.models @property def vault_url(self): # type: () -> str return self._vault_url + + def __enter__(self): + # type: () -> KeyVaultClientBase + self._client.__enter__() + return self + + def __exit__(self, *args): + # type: (*Any) -> None + self._client.__exit__(*args) + + def close(self): + # type: () -> None + """Close sockets opened by the client. + + Calling this method is unnecessary when using the client as a context manager. + """ + self._client.__exit__() diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/multi_api.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/multi_api.py new file mode 100644 index 0000000000000..8c8b343047fe8 --- /dev/null +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_shared/multi_api.py @@ -0,0 +1,43 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +from collections import namedtuple + +from ._generated.v7_0.version import VERSION as V7_0_VERSION +from ._generated.v2016_10_01.version import VERSION as V2016_10_01_VERSION + +SUPPORTED_VERSIONS = (V7_0_VERSION, V2016_10_01_VERSION) +DEFAULT_VERSION = V7_0_VERSION + +GeneratedApi = namedtuple("GeneratedApi", ("models", "client_cls", "config_cls")) + + +def load_generated_api(api_version, aio=False): + # type: (str, bool) -> GeneratedApi + api_version = api_version or DEFAULT_VERSION + if api_version == V7_0_VERSION: + from ._generated.v7_0 import models + + if aio: + from ._generated.v7_0.aio import KeyVaultClient + from ._generated.v7_0.aio._configuration_async import KeyVaultClientConfiguration + else: + from ._generated.v7_0 import KeyVaultClient # type: ignore + from ._generated.v7_0._configuration import KeyVaultClientConfiguration # type: ignore + elif api_version == V2016_10_01_VERSION: + from ._generated.v2016_10_01 import models # type: ignore + + if aio: + from ._generated.v2016_10_01.aio import KeyVaultClient # type: ignore + from ._generated.v2016_10_01.aio._configuration_async import KeyVaultClientConfiguration # type: ignore + else: + from ._generated.v2016_10_01 import KeyVaultClient # type: ignore + from ._generated.v2016_10_01._configuration import KeyVaultClientConfiguration # type: ignore + else: + raise NotImplementedError( + "This package doesn't support API version '{}'. ".format(api_version) + + "Supported versions: {}".format(", ".join(SUPPORTED_VERSIONS)) + ) + + return GeneratedApi(models=models, client_cls=KeyVaultClient, config_cls=KeyVaultClientConfiguration) diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/aio/_client.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/aio/_client.py index dadc05e363191..45cbdbab30694 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/aio/_client.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/aio/_client.py @@ -81,7 +81,7 @@ async def create_key(self, name: str, key_type: "Union[str, KeyType]", **kwargs: expires_on = kwargs.pop("expires_on", None) if enabled is not None or not_before is not None or expires_on is not None: - attributes = self._client.models.KeyAttributes(enabled=enabled, not_before=not_before, expires=expires_on) + attributes = self._models.KeyAttributes(enabled=enabled, not_before=not_before, expires=expires_on) else: attributes = None @@ -409,7 +409,7 @@ async def update_key_properties(self, name: str, version: "Optional[str]" = None not_before = kwargs.pop("not_before", None) expires_on = kwargs.pop("expires_on", None) if enabled is not None or not_before is not None or expires_on is not None: - attributes = self._client.models.KeyAttributes(enabled=enabled, not_before=not_before, expires=expires_on) + attributes = self._models.KeyAttributes(enabled=enabled, not_before=not_before, expires=expires_on) else: attributes = None bundle = await self._client.update_key( @@ -497,7 +497,7 @@ async def import_key(self, name: str, key: JsonWebKey, **kwargs: "Any") -> KeyVa not_before = kwargs.pop("not_before", None) expires_on = kwargs.pop("expires_on", None) if enabled is not None or not_before is not None or expires_on is not None: - attributes = self._client.models.KeyAttributes(enabled=enabled, not_before=not_before, expires=expires_on) + attributes = self._models.KeyAttributes(enabled=enabled, not_before=not_before, expires=expires_on) else: attributes = None bundle = await self._client.import_key( diff --git a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_client.py b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_client.py index 82e19f5629935..899904e7e783b 100644 --- a/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_client.py +++ b/sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_client.py @@ -64,7 +64,7 @@ def __init__(self, key, credential, **kwargs): self._key = key self._key_id = parse_vault_id(key.id) self._allowed_ops = frozenset(self._key.key_operations) - elif isinstance(key, six.text_type): + elif isinstance(key, six.string_types): self._key = None self._key_id = parse_vault_id(key) self._keys_get_forbidden = None # type: Optional[bool] diff --git a/sdk/keyvault/azure-keyvault-keys/samples/backup_restore_operations_async.py b/sdk/keyvault/azure-keyvault-keys/samples/backup_restore_operations_async.py index 1c9b05cf94c4a..e7fd5f876e0ad 100644 --- a/sdk/keyvault/azure-keyvault-keys/samples/backup_restore_operations_async.py +++ b/sdk/keyvault/azure-keyvault-keys/samples/backup_restore_operations_async.py @@ -63,6 +63,8 @@ async def run_sample(): finally: print("\nrun_sample done") + await credential.close() + await client.close() if __name__ == "__main__": diff --git a/sdk/keyvault/azure-keyvault-keys/samples/hello_world_async.py b/sdk/keyvault/azure-keyvault-keys/samples/hello_world_async.py index c8c1e204bbd98..335682213eded 100644 --- a/sdk/keyvault/azure-keyvault-keys/samples/hello_world_async.py +++ b/sdk/keyvault/azure-keyvault-keys/samples/hello_world_async.py @@ -92,6 +92,8 @@ async def run_sample(): finally: print("\nrun_sample done") + await credential.close() + await client.close() if __name__ == "__main__": diff --git a/sdk/keyvault/azure-keyvault-keys/samples/list_operations_async.py b/sdk/keyvault/azure-keyvault-keys/samples/list_operations_async.py index 1929644d31b15..527e5a49e4f3e 100644 --- a/sdk/keyvault/azure-keyvault-keys/samples/list_operations_async.py +++ b/sdk/keyvault/azure-keyvault-keys/samples/list_operations_async.py @@ -95,6 +95,8 @@ async def run_sample(): finally: print("\nrun_sample done") + await credential.close() + await client.close() if __name__ == "__main__": diff --git a/sdk/keyvault/azure-keyvault-keys/samples/recover_purge_operations_async.py b/sdk/keyvault/azure-keyvault-keys/samples/recover_purge_operations_async.py index c1a0b1250426a..caf6275a53200 100644 --- a/sdk/keyvault/azure-keyvault-keys/samples/recover_purge_operations_async.py +++ b/sdk/keyvault/azure-keyvault-keys/samples/recover_purge_operations_async.py @@ -75,6 +75,8 @@ async def run_sample(): finally: print("\nrun_sample done") + await credential.close() + await client.close() if __name__ == "__main__": diff --git a/sdk/keyvault/azure-keyvault-keys/setup.py b/sdk/keyvault/azure-keyvault-keys/setup.py index 785e7dfd112ce..96b7f303066e3 100644 --- a/sdk/keyvault/azure-keyvault-keys/setup.py +++ b/sdk/keyvault/azure-keyvault-keys/setup.py @@ -79,7 +79,7 @@ "azure.keyvault", ] ), - install_requires=["azure-core<2.0.0,>=1.2.1", "azure-common~=1.1", "cryptography>=2.1.4", "msrest>=0.6.0"], + install_requires=["azure-core<2.0.0,>=1.2.1", "cryptography>=2.1.4", "msrest>=0.6.0"], extras_require={ ":python_version<'3.0'": ["azure-keyvault-nspkg"], ":python_version<'3.4'": ["enum34>=1.0.4"], diff --git a/sdk/keyvault/azure-keyvault-keys/tests/_shared/test_case_async.py b/sdk/keyvault/azure-keyvault-keys/tests/_shared/test_case_async.py index 1522264fdce09..4462e8aa86675 100644 --- a/sdk/keyvault/azure-keyvault-keys/tests/_shared/test_case_async.py +++ b/sdk/keyvault/azure-keyvault-keys/tests/_shared/test_case_async.py @@ -35,7 +35,11 @@ def await_prepared_test(test_fn): @functools.wraps(test_fn) def run(test_class_instance, *args, **kwargs): loop = asyncio.get_event_loop() - return loop.run_until_complete(test_fn(test_class_instance, *args, **kwargs)) + client = kwargs.get("client") + result = loop.run_until_complete(test_fn(test_class_instance, *args, **kwargs)) + if client: + loop.run_until_complete(client.close()) + return result return run diff --git a/sdk/keyvault/azure-keyvault-keys/tests/test_context_manager.py b/sdk/keyvault/azure-keyvault-keys/tests/test_context_manager.py new file mode 100644 index 0000000000000..398096411549b --- /dev/null +++ b/sdk/keyvault/azure-keyvault-keys/tests/test_context_manager.py @@ -0,0 +1,44 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +from azure.keyvault.keys import KeyClient +from azure.keyvault.keys.crypto import CryptographyClient + +from _shared.helpers import mock + + +def test_key_client_close(): + transport = mock.MagicMock() + client = KeyClient(vault_url="https://localhost", credential=object(), transport=transport) + client.close() + assert transport.__enter__.call_count == 0 + assert transport.__exit__.call_count == 1 + + +def test_key_client_context_manager(): + transport = mock.MagicMock() + client = KeyClient(vault_url="https://localhost", credential=object(), transport=transport) + + with client: + assert transport.__enter__.call_count == 1 + assert transport.__enter__.call_count == 1 + assert transport.__exit__.call_count == 1 + + +def test_crypto_client_close(): + transport = mock.MagicMock() + client = CryptographyClient(key="https://localhost/a/b/c", credential=object(), transport=transport) + client.close() + assert transport.__enter__.call_count == 0 + assert transport.__exit__.call_count == 1 + + +def test_crypto_client_context_manager(): + transport = mock.MagicMock() + client = CryptographyClient(key="https://localhost/a/b/c", credential=object(), transport=transport) + + with client: + assert transport.__enter__.call_count == 1 + assert transport.__enter__.call_count == 1 + assert transport.__exit__.call_count == 1 diff --git a/sdk/keyvault/azure-keyvault-keys/tests/test_context_manager_async.py b/sdk/keyvault/azure-keyvault-keys/tests/test_context_manager_async.py new file mode 100644 index 0000000000000..6a564097a704a --- /dev/null +++ b/sdk/keyvault/azure-keyvault-keys/tests/test_context_manager_async.py @@ -0,0 +1,49 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +from azure.keyvault.keys.aio import KeyClient +from azure.keyvault.keys.crypto.aio import CryptographyClient +import pytest + +from _shared.helpers_async import AsyncMockTransport + + +@pytest.mark.asyncio +async def test_key_client_close(): + transport = AsyncMockTransport() + client = KeyClient(vault_url="https://localhost", credential=object(), transport=transport) + + await client.close() + assert transport.__aenter__.call_count == 0 + assert transport.__aexit__.call_count == 1 + + +@pytest.mark.asyncio +async def test_key_client_context_manager(): + transport = AsyncMockTransport() + client = KeyClient(vault_url="https://localhost", credential=object(), transport=transport) + + async with client: + assert transport.__aenter__.call_count == 1 + assert transport.__aenter__.call_count == 1 + assert transport.__aexit__.call_count == 1 + + +@pytest.mark.asyncio +async def test_crypto_client_close(): + transport = AsyncMockTransport() + client = CryptographyClient(key="https://localhost/a/b/c", credential=object(), transport=transport) + await client.close() + assert transport.__aenter__.call_count == 0 + assert transport.__aexit__.call_count == 1 + + +@pytest.mark.asyncio +async def test_crypto_client_context_manager(): + transport = AsyncMockTransport() + client = CryptographyClient(key="https://localhost/a/b/c", credential=object(), transport=transport) + async with client: + assert transport.__aenter__.call_count == 1 + assert transport.__aenter__.call_count == 1 + assert transport.__aexit__.call_count == 1 diff --git a/sdk/keyvault/azure-keyvault-secrets/CHANGELOG.md b/sdk/keyvault/azure-keyvault-secrets/CHANGELOG.md index b4aa41d607a4d..6e1231d772df6 100644 --- a/sdk/keyvault/azure-keyvault-secrets/CHANGELOG.md +++ b/sdk/keyvault/azure-keyvault-secrets/CHANGELOG.md @@ -1,7 +1,9 @@ # Release History ## 4.0.2 (Unreleased) - +- `SecretClient` instances have a `close` method which closes opened sockets. +Used as a context manager, a `SecretClient` closes opened sockets on exit. +([#9906](https://github.com/Azure/azure-sdk-for-python/pull/9906)) ## 4.0.1 (2020-02-11) - `azure.keyvault.secrets` defines `__version__` diff --git a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_client.py b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_client.py index d09bd4351d0cb..b8207e3e5f442 100644 --- a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_client.py +++ b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_client.py @@ -103,7 +103,7 @@ def set_secret(self, name, value, **kwargs): not_before = kwargs.pop("not_before", None) expires_on = kwargs.pop("expires_on", None) if enabled is not None or not_before is not None or expires_on is not None: - attributes = self._client.models.SecretAttributes( + attributes = self._models.SecretAttributes( enabled=enabled, not_before=not_before, expires=expires_on ) else: @@ -152,7 +152,7 @@ def update_secret_properties(self, name, version=None, **kwargs): not_before = kwargs.pop("not_before", None) expires_on = kwargs.pop("expires_on", None) if enabled is not None or not_before is not None or expires_on is not None: - attributes = self._client.models.SecretAttributes( + attributes = self._models.SecretAttributes( enabled=enabled, not_before=not_before, expires=expires_on ) else: diff --git a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/_generated/__init__.py b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/_generated/__init__.py index efc5f67755a24..b74cfa3b899cc 100644 --- a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/_generated/__init__.py +++ b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/_generated/__init__.py @@ -2,6 +2,3 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -from .key_vault_client import KeyVaultClient - -__all__ = ["KeyVaultClient"] diff --git a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/_generated/key_vault_client.py b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/_generated/key_vault_client.py deleted file mode 100644 index fb29f4364313b..0000000000000 --- a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/_generated/key_vault_client.py +++ /dev/null @@ -1,199 +0,0 @@ -# ------------------------------------ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -# ------------------------------------ -from azure.profiles import KnownProfiles, ProfileDefinition -from azure.profiles.multiapiclient import MultiApiClientMixin - -from .v7_0.version import VERSION as V7_0_VERSION -from .v2016_10_01.version import VERSION as V2016_10_01_VERSION - - -class KeyVaultClient(MultiApiClientMixin): - """The key vault client performs cryptographic key operations and vault operations against the Key Vault service. - Implementation depends on the API version: - - * 2016-10-01: :class:`v2016_10_01.KeyVaultClient` - * 7.0: :class:`v7_0.KeyVaultClient` - - :param credentials: Credentials needed for the client to connect to Azure. - :type credentials: :mod:`A msrestazure Credentials - object` - - :param str api_version: API version to use if no profile is provided, or if - missing in profile. - :param profile: A profile definition, from KnownProfiles to dict. - :type profile: azure.profiles.KnownProfiles - """ - - DEFAULT_API_VERSION = V7_0_VERSION - _PROFILE_TAG = "azure.keyvault.KeyVaultClient" - LATEST_PROFILE = ProfileDefinition({_PROFILE_TAG: {None: DEFAULT_API_VERSION}}, _PROFILE_TAG + " latest") - - _init_complete = False - - def __init__(self, credentials, pipeline=None, api_version=None, aio=False, profile=KnownProfiles.default): - self._client_impls = {} - self._pipeline = pipeline - self._entered = False - self._aio = aio - super(KeyVaultClient, self).__init__(api_version=api_version, profile=profile) - - self._credentials = credentials - self._init_complete = True - - @staticmethod - def get_configuration_class(api_version, aio=False): - """ - Get the versioned configuration implementation corresponding to the current profile. - :return: The versioned configuration implementation. - """ - if api_version == V7_0_VERSION: - if aio: - from .v7_0.aio._configuration_async import KeyVaultClientConfiguration as ImplConfig - else: - from .v7_0._configuration import KeyVaultClientConfiguration as ImplConfig - elif api_version == V2016_10_01_VERSION: - if aio: - from .v2016_10_01.aio._configuration_async import KeyVaultClientConfiguration as ImplConfig - else: - from .v2016_10_01._configuration import KeyVaultClientConfiguration as ImplConfig - else: - raise NotImplementedError("API version {} is not available".format(api_version)) - return ImplConfig - - @property - def models(self): - """Module depends on the API version: - * 2016-10-01: :mod:`v2016_10_01.models` - * 7.0: :mod:`v7_0.models` - """ - api_version = self._get_api_version(None) - - if api_version == V7_0_VERSION: - from .v7_0 import models as impl_models - elif api_version == V2016_10_01_VERSION: - from .v2016_10_01 import models as impl_models - else: - raise NotImplementedError("APIVersion {} is not available".format(api_version)) - return impl_models - - def _get_client_impl(self): - """ - Get the versioned client implementation corresponding to the current profile. - :return: The versioned client implementation. - """ - api_version = self._get_api_version(None) - if api_version not in self._client_impls: - self._create_client_impl(api_version) - return self._client_impls[api_version] - - def _create_client_impl(self, api_version): - """ - Creates the client implementation corresponding to the specified api_version. - :param api_version: - :return: - """ - if api_version == V7_0_VERSION: - if self._aio: - from .v7_0.aio import KeyVaultClient as ImplClient - else: - from .v7_0 import KeyVaultClient as ImplClient - elif api_version == V2016_10_01_VERSION: - if self._aio: - from .v2016_10_01.aio import KeyVaultClient as ImplClient - else: - from .v2016_10_01 import KeyVaultClient as ImplClient - else: - raise NotImplementedError("API version {} is not available".format(api_version)) - - impl = ImplClient(credentials=self._credentials, pipeline=self._pipeline) - - # if __enter__ has previously been called and the impl client has __enter__ defined we need to call it - if self._entered: - if hasattr(impl, "__enter__"): - impl.__enter__() - elif hasattr(impl, "__aenter__"): - impl.__aenter__() - - self._client_impls[api_version] = impl - return impl - - def __aenter__(self, *args, **kwargs): - """ - Calls __aenter__ on all client implementations which support it - :param args: positional arguments to relay to client implementations of __aenter__ - :param kwargs: keyword arguments to relay to client implementations of __aenter__ - :return: returns the current KeyVaultClient instance - """ - for _, impl in self._client_impls.items(): - if hasattr(impl, "__aenter__"): - impl.__aenter__(*args, **kwargs) - - # mark the current KeyVaultClient as _entered so that client implementations instantiated - # subsequently will also have __aenter__ called on them as appropriate - self._entered = True - return self - - def __enter__(self, *args, **kwargs): - """ - Calls __enter__ on all client implementations which support it - :param args: positional arguments to relay to client implementations of __enter__ - :param kwargs: keyword arguments to relay to client implementations of __enter__ - :return: returns the current KeyVaultClient instance - """ - for _, impl in self._client_impls.items(): - if hasattr(impl, "__enter__"): - impl.__enter__(*args, **kwargs) - - # mark the current KeyVaultClient as _entered so that client implementations instantiated - # subsequently will also have __enter__ called on them as appropriate - self._entered = True - return self - - def __aexit__(self, *args, **kwargs): - """ - Calls __aexit__ on all client implementations which support it - :param args: positional arguments to relay to client implementations of __aexit__ - :param kwargs: keyword arguments to relay to client implementations of __aexit__ - :return: returns the current KeyVaultClient instance - """ - for _, impl in self._client_impls.items(): - if hasattr(impl, "__aexit__"): - impl.__aexit__(*args, **kwargs) - return self - - def __exit__(self, *args, **kwargs): - """ - Calls __exit__ on all client implementations which support it - :param args: positional arguments to relay to client implementations of __enter__ - :param kwargs: keyword arguments to relay to client implementations of __enter__ - :return: returns the current KeyVaultClient instance - """ - for _, impl in self._client_impls.items(): - if hasattr(impl, "__exit__"): - impl.__exit__(*args, **kwargs) - return self - - def __getattr__(self, name): - """ - In the case that the attribute is not defined on the custom KeyVaultClient. Attempt to get - the attribute from the versioned client implementation corresponding to the current profile. - :param name: Name of the attribute retrieve from the current versioned client implementation - :return: The value of the specified attribute on the current client implementation. - """ - impl = self._get_client_impl() - return getattr(impl, name) - - def __setattr__(self, name, attr): - """ - Sets the specified attribute either on the custom KeyVaultClient or the current underlying implementation. - :param name: Name of the attribute to set - :param attr: Value of the attribute to set - :return: None - """ - if self._init_complete and not hasattr(self, name): - impl = self._get_client_impl() - setattr(impl, name, attr) - else: - super(KeyVaultClient, self).__setattr__(name, attr) diff --git a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_client_base.py b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_client_base.py index 6f4a1e563077f..84c827a3a938e 100644 --- a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_client_base.py +++ b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/async_client_base.py @@ -2,108 +2,80 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ -from typing import Any, TYPE_CHECKING +from typing import TYPE_CHECKING -from azure.core.configuration import Configuration from azure.core.pipeline import AsyncPipeline -from azure.core.pipeline.policies import( - ContentDecodePolicy, UserAgentPolicy, DistributedTracingPolicy, HttpLoggingPolicy -) -from azure.core.pipeline.transport import AsyncHttpTransport -from ._generated import KeyVaultClient from . import AsyncChallengeAuthPolicy -from .._user_agent import USER_AGENT - +from .client_base import _get_policies +from .multi_api import load_generated_api if TYPE_CHECKING: try: # pylint:disable=unused-import - from azure.core.credentials import TokenCredential + from typing import Any + from azure.core.configuration import Configuration + from azure.core.pipeline.transport import AsyncHttpTransport + from azure.core.credentials_async import AsyncTokenCredential except ImportError: - # TokenCredential is a typing_extensions.Protocol; we don't depend on that package + # AsyncTokenCredential is a typing_extensions.Protocol; we don't depend on that package pass -class AsyncKeyVaultClientBase: - """Base class for async Key Vault clients""" - - @staticmethod - def _create_config(credential: "TokenCredential", api_version: str = None, **kwargs: "Any") -> Configuration: - if api_version is None: - api_version = KeyVaultClient.DEFAULT_API_VERSION - config = KeyVaultClient.get_configuration_class(api_version, aio=True)(credential, **kwargs) - config.authentication_policy = AsyncChallengeAuthPolicy(credential) - - # replace the autorest-generated UserAgentPolicy and its hard-coded user agent - # https://github.com/Azure/azure-sdk-for-python/issues/6637 - config.user_agent_policy = UserAgentPolicy(base_user_agent=USER_AGENT, **kwargs) - - # Override config policies if found in kwargs - # TODO: should be unnecessary after next regeneration (written 2019-08-02) - if "user_agent_policy" in kwargs: - config.user_agent_policy = kwargs["user_agent_policy"] - if "headers_policy" in kwargs: - config.headers_policy = kwargs["headers_policy"] - if "proxy_policy" in kwargs: - config.proxy_policy = kwargs["proxy_policy"] - if "logging_policy" in kwargs: - config.logging_policy = kwargs["logging_policy"] - if "retry_policy" in kwargs: - config.retry_policy = kwargs["retry_policy"] - if "custom_hook_policy" in kwargs: - config.custom_hook_policy = kwargs["custom_hook_policy"] - if "redirect_policy" in kwargs: - config.redirect_policy = kwargs["redirect_policy"] - - return config - - def __init__(self, vault_url: str, credential: "TokenCredential", **kwargs: "Any") -> None: +def _build_pipeline(config: "Configuration", transport: "AsyncHttpTransport" = None, **kwargs: "Any") -> AsyncPipeline: + policies = _get_policies(config, **kwargs) + if transport is None: + from azure.core.pipeline.transport import AioHttpTransport + + transport = AioHttpTransport(**kwargs) + + return AsyncPipeline(transport, policies=policies) + + +class AsyncKeyVaultClientBase(object): + def __init__(self, vault_url: str, credential: "AsyncTokenCredential", **kwargs: "Any") -> None: if not credential: raise ValueError( - "credential should be an object supporting the TokenCredential protocol, " + "credential should be an object supporting the AsyncTokenCredential protocol, " "such as a credential from azure-identity" ) if not vault_url: raise ValueError("vault_url must be the URL of an Azure Key Vault") self._vault_url = vault_url.strip(" /") - client = kwargs.get("generated_client") if client: # caller provided a configured client -> nothing left to initialize self._client = client return - config = self._create_config(credential, **kwargs) - transport = kwargs.pop("transport", None) - pipeline = kwargs.pop("pipeline", None) or self._build_pipeline(config, transport=transport, **kwargs) - self._client = KeyVaultClient(credential, pipeline=pipeline, aio=True) - - @staticmethod - def _build_pipeline(config: Configuration, transport: AsyncHttpTransport, **kwargs: "Any") -> AsyncPipeline: - logging_policy = HttpLoggingPolicy(**kwargs) - logging_policy.allowed_header_names.add("x-ms-keyvault-network-info") - policies = [ - config.headers_policy, - config.user_agent_policy, - config.proxy_policy, - ContentDecodePolicy(), - config.redirect_policy, - config.retry_policy, - config.authentication_policy, - config.logging_policy, - DistributedTracingPolicy(**kwargs), - logging_policy, - ] - - if transport is None: - from azure.core.pipeline.transport import AioHttpTransport - - transport = AioHttpTransport(**kwargs) - - return AsyncPipeline(transport, policies=policies) + api_version = kwargs.pop("api_version", None) + generated = load_generated_api(api_version, aio=True) + + pipeline = kwargs.pop("pipeline", None) + if not pipeline: + config = generated.config_cls(credential, **kwargs) + config.authentication_policy = AsyncChallengeAuthPolicy(credential) + pipeline = _build_pipeline(config, **kwargs) + + # generated clients don't use their credentials parameter + self._client = generated.client_cls(credentials="", pipeline=pipeline) + self._models = generated.models @property def vault_url(self) -> str: return self._vault_url + + async def __aenter__(self) -> "AsyncKeyVaultClientBase": + await self._client.__aenter__() + return self + + async def __aexit__(self, *args: "Any") -> None: + await self._client.__aexit__(*args) + + async def close(self) -> None: + """Close sockets opened by the client. + + Calling this method is unnecessary when using the client as a context manager. + """ + await self._client.__aexit__() diff --git a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/client_base.py b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/client_base.py index 1e1b7692d379a..b1e1a2e997d2e 100644 --- a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/client_base.py +++ b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/client_base.py @@ -5,58 +5,54 @@ from typing import TYPE_CHECKING from azure.core.pipeline import Pipeline -from azure.core.pipeline.policies import( - ContentDecodePolicy, UserAgentPolicy, DistributedTracingPolicy, HttpLoggingPolicy +from azure.core.pipeline.policies import ( + ContentDecodePolicy, + UserAgentPolicy, + DistributedTracingPolicy, + HttpLoggingPolicy, ) from azure.core.pipeline.transport import RequestsTransport -from ._generated import KeyVaultClient + +from .multi_api import load_generated_api from .challenge_auth_policy import ChallengeAuthPolicy from .._user_agent import USER_AGENT if TYPE_CHECKING: # pylint:disable=unused-import - from typing import Any, Optional + from typing import Any from azure.core.credentials import TokenCredential from azure.core.pipeline.transport import HttpTransport from azure.core.configuration import Configuration -KEY_VAULT_SCOPE = "https://vault.azure.net/.default" +def _get_policies(config, **kwargs): + logging_policy = HttpLoggingPolicy(**kwargs) + logging_policy.allowed_header_names.add("x-ms-keyvault-network-info") -class KeyVaultClientBase(object): - """Base class for Key Vault clients""" - - @staticmethod - def _create_config(credential, api_version=None, **kwargs): - # type: (TokenCredential, Optional[str], **Any) -> Configuration - if api_version is None: - api_version = KeyVaultClient.DEFAULT_API_VERSION - config = KeyVaultClient.get_configuration_class(api_version, aio=False)(credential, **kwargs) - config.authentication_policy = ChallengeAuthPolicy(credential) - - # replace the autorest-generated UserAgentPolicy and its hard-coded user agent - # https://github.com/Azure/azure-sdk-for-python/issues/6637 - config.user_agent_policy = UserAgentPolicy(base_user_agent=USER_AGENT, **kwargs) - - # Override config policies if found in kwargs - # TODO: should be unnecessary after next regeneration (written 2019-08-02) - if "user_agent_policy" in kwargs: - config.user_agent_policy = kwargs["user_agent_policy"] - if "headers_policy" in kwargs: - config.headers_policy = kwargs["headers_policy"] - if "proxy_policy" in kwargs: - config.proxy_policy = kwargs["proxy_policy"] - if "logging_policy" in kwargs: - config.logging_policy = kwargs["logging_policy"] - if "retry_policy" in kwargs: - config.retry_policy = kwargs["retry_policy"] - if "custom_hook_policy" in kwargs: - config.custom_hook_policy = kwargs["custom_hook_policy"] - if "redirect_policy" in kwargs: - config.redirect_policy = kwargs["redirect_policy"] - - return config + return [ + config.headers_policy, + UserAgentPolicy(base_user_agent=USER_AGENT, **kwargs), + config.proxy_policy, + ContentDecodePolicy(), + config.redirect_policy, + config.retry_policy, + config.authentication_policy, + config.logging_policy, + DistributedTracingPolicy(**kwargs), + logging_policy, + ] + + +def _build_pipeline(config, transport=None, **kwargs): + # type: (Configuration, HttpTransport, **Any) -> Pipeline + policies = _get_policies(config) + if transport is None: + transport = RequestsTransport(**kwargs) + return Pipeline(transport, policies=policies) + + +class KeyVaultClientBase(object): def __init__(self, vault_url, credential, **kwargs): # type: (str, TokenCredential, **Any) -> None if not credential: @@ -68,42 +64,43 @@ def __init__(self, vault_url, credential, **kwargs): raise ValueError("vault_url must be the URL of an Azure Key Vault") self._vault_url = vault_url.strip(" /") - client = kwargs.get("generated_client") if client: # caller provided a configured client -> nothing left to initialize self._client = client return - config = self._create_config(credential, **kwargs) - transport = kwargs.pop("transport", None) - pipeline = kwargs.pop("pipeline", None) or self._build_pipeline(config, transport=transport, **kwargs) - self._client = KeyVaultClient(credential, pipeline=pipeline, aio=False) - - # pylint:disable=no-self-use - def _build_pipeline(self, config, transport, **kwargs): - # type: (Configuration, HttpTransport, **Any) -> Pipeline - logging_policy = HttpLoggingPolicy(**kwargs) - logging_policy.allowed_header_names.add("x-ms-keyvault-network-info") - policies = [ - config.headers_policy, - config.user_agent_policy, - config.proxy_policy, - ContentDecodePolicy(), - config.redirect_policy, - config.retry_policy, - config.authentication_policy, - config.logging_policy, - DistributedTracingPolicy(**kwargs), - logging_policy, - ] - - if transport is None: - transport = RequestsTransport(**kwargs) - - return Pipeline(transport, policies=policies) + api_version = kwargs.pop("api_version", None) + generated = load_generated_api(api_version) + + pipeline = kwargs.pop("pipeline", None) + if not pipeline: + config = generated.config_cls(credential, **kwargs) + config.authentication_policy = ChallengeAuthPolicy(credential) + pipeline = _build_pipeline(config, **kwargs) + + # generated clients don't use their credentials parameter + self._client = generated.client_cls(credentials="", pipeline=pipeline) + self._models = generated.models @property def vault_url(self): # type: () -> str return self._vault_url + + def __enter__(self): + # type: () -> KeyVaultClientBase + self._client.__enter__() + return self + + def __exit__(self, *args): + # type: (*Any) -> None + self._client.__exit__(*args) + + def close(self): + # type: () -> None + """Close sockets opened by the client. + + Calling this method is unnecessary when using the client as a context manager. + """ + self._client.__exit__() diff --git a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/multi_api.py b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/multi_api.py new file mode 100644 index 0000000000000..8c8b343047fe8 --- /dev/null +++ b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/_shared/multi_api.py @@ -0,0 +1,43 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +from collections import namedtuple + +from ._generated.v7_0.version import VERSION as V7_0_VERSION +from ._generated.v2016_10_01.version import VERSION as V2016_10_01_VERSION + +SUPPORTED_VERSIONS = (V7_0_VERSION, V2016_10_01_VERSION) +DEFAULT_VERSION = V7_0_VERSION + +GeneratedApi = namedtuple("GeneratedApi", ("models", "client_cls", "config_cls")) + + +def load_generated_api(api_version, aio=False): + # type: (str, bool) -> GeneratedApi + api_version = api_version or DEFAULT_VERSION + if api_version == V7_0_VERSION: + from ._generated.v7_0 import models + + if aio: + from ._generated.v7_0.aio import KeyVaultClient + from ._generated.v7_0.aio._configuration_async import KeyVaultClientConfiguration + else: + from ._generated.v7_0 import KeyVaultClient # type: ignore + from ._generated.v7_0._configuration import KeyVaultClientConfiguration # type: ignore + elif api_version == V2016_10_01_VERSION: + from ._generated.v2016_10_01 import models # type: ignore + + if aio: + from ._generated.v2016_10_01.aio import KeyVaultClient # type: ignore + from ._generated.v2016_10_01.aio._configuration_async import KeyVaultClientConfiguration # type: ignore + else: + from ._generated.v2016_10_01 import KeyVaultClient # type: ignore + from ._generated.v2016_10_01._configuration import KeyVaultClientConfiguration # type: ignore + else: + raise NotImplementedError( + "This package doesn't support API version '{}'. ".format(api_version) + + "Supported versions: {}".format(", ".join(SUPPORTED_VERSIONS)) + ) + + return GeneratedApi(models=models, client_cls=KeyVaultClient, config_cls=KeyVaultClientConfiguration) diff --git a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/aio/_client.py b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/aio/_client.py index 7d0df561c7a0b..0ce594042349d 100644 --- a/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/aio/_client.py +++ b/sdk/keyvault/azure-keyvault-secrets/azure/keyvault/secrets/aio/_client.py @@ -88,7 +88,7 @@ async def set_secret(self, name: str, value: str, **kwargs: "Any") -> KeyVaultSe not_before = kwargs.pop("not_before", None) expires_on = kwargs.pop("expires_on", None) if enabled is not None or not_before is not None or expires_on is not None: - attributes = self._client.models.SecretAttributes( + attributes = self._models.SecretAttributes( enabled=enabled, not_before=not_before, expires=expires_on ) else: @@ -132,7 +132,7 @@ async def update_secret_properties( not_before = kwargs.pop("not_before", None) expires_on = kwargs.pop("expires_on", None) if enabled is not None or not_before is not None or expires_on is not None: - attributes = self._client.models.SecretAttributes( + attributes = self._models.SecretAttributes( enabled=enabled, not_before=not_before, expires=expires_on ) else: diff --git a/sdk/keyvault/azure-keyvault-secrets/samples/backup_restore_operations_async.py b/sdk/keyvault/azure-keyvault-secrets/samples/backup_restore_operations_async.py index f7fdc6d798a80..171ffba38046e 100644 --- a/sdk/keyvault/azure-keyvault-secrets/samples/backup_restore_operations_async.py +++ b/sdk/keyvault/azure-keyvault-secrets/samples/backup_restore_operations_async.py @@ -65,6 +65,8 @@ async def run_sample(): finally: print("\nrun_sample done") + await credential.close() + await client.close() if __name__ == "__main__": diff --git a/sdk/keyvault/azure-keyvault-secrets/samples/hello_world_async.py b/sdk/keyvault/azure-keyvault-secrets/samples/hello_world_async.py index 795021749e83e..21437433cbaee 100644 --- a/sdk/keyvault/azure-keyvault-secrets/samples/hello_world_async.py +++ b/sdk/keyvault/azure-keyvault-secrets/samples/hello_world_async.py @@ -85,6 +85,8 @@ async def run_sample(): finally: print("\nrun_sample done") + await credential.close() + await client.close() if __name__ == "__main__": diff --git a/sdk/keyvault/azure-keyvault-secrets/samples/list_operations_async.py b/sdk/keyvault/azure-keyvault-secrets/samples/list_operations_async.py index 96af51284483e..4b52797527b14 100644 --- a/sdk/keyvault/azure-keyvault-secrets/samples/list_operations_async.py +++ b/sdk/keyvault/azure-keyvault-secrets/samples/list_operations_async.py @@ -99,6 +99,8 @@ async def run_sample(): finally: print("\nrun_sample done") + await credential.close() + await client.close() if __name__ == "__main__": diff --git a/sdk/keyvault/azure-keyvault-secrets/samples/recover_purge_operations_async.py b/sdk/keyvault/azure-keyvault-secrets/samples/recover_purge_operations_async.py index 9c424d89c448b..8b5c5adf42706 100644 --- a/sdk/keyvault/azure-keyvault-secrets/samples/recover_purge_operations_async.py +++ b/sdk/keyvault/azure-keyvault-secrets/samples/recover_purge_operations_async.py @@ -77,6 +77,8 @@ async def run_sample(): finally: print("\nrun_sample done") + await credential.close() + await client.close() if __name__ == "__main__": diff --git a/sdk/keyvault/azure-keyvault-secrets/setup.py b/sdk/keyvault/azure-keyvault-secrets/setup.py index 4ad6cb2d5b753..6eabf6bf8874c 100644 --- a/sdk/keyvault/azure-keyvault-secrets/setup.py +++ b/sdk/keyvault/azure-keyvault-secrets/setup.py @@ -79,7 +79,7 @@ "azure.keyvault", ] ), - install_requires=["azure-core<2.0.0,>=1.2.1", "azure-common~=1.1", "msrest>=0.6.0"], + install_requires=["azure-core<2.0.0,>=1.2.1", "msrest>=0.6.0"], extras_require={ ":python_version<'3.0'": ["azure-keyvault-nspkg"], ":python_version<'3.4'": ["enum34>=1.0.4"], diff --git a/sdk/keyvault/azure-keyvault-secrets/tests/_shared/test_case_async.py b/sdk/keyvault/azure-keyvault-secrets/tests/_shared/test_case_async.py index 1522264fdce09..4462e8aa86675 100644 --- a/sdk/keyvault/azure-keyvault-secrets/tests/_shared/test_case_async.py +++ b/sdk/keyvault/azure-keyvault-secrets/tests/_shared/test_case_async.py @@ -35,7 +35,11 @@ def await_prepared_test(test_fn): @functools.wraps(test_fn) def run(test_class_instance, *args, **kwargs): loop = asyncio.get_event_loop() - return loop.run_until_complete(test_fn(test_class_instance, *args, **kwargs)) + client = kwargs.get("client") + result = loop.run_until_complete(test_fn(test_class_instance, *args, **kwargs)) + if client: + loop.run_until_complete(client.close()) + return result return run diff --git a/sdk/keyvault/azure-keyvault-secrets/tests/test_context_manager.py b/sdk/keyvault/azure-keyvault-secrets/tests/test_context_manager.py new file mode 100644 index 0000000000000..5d31db1e7946b --- /dev/null +++ b/sdk/keyvault/azure-keyvault-secrets/tests/test_context_manager.py @@ -0,0 +1,25 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +from azure.keyvault.secrets import SecretClient + +from _shared.helpers import mock + + +def test_close(): + transport = mock.MagicMock() + client = SecretClient(vault_url="https://localhost", credential=object(), transport=transport) + client.close() + assert transport.__enter__.call_count == 0 + assert transport.__exit__.call_count == 1 + + +def test_context_manager(): + transport = mock.MagicMock() + client = SecretClient(vault_url="https://localhost", credential=object(), transport=transport) + + with client: + assert transport.__enter__.call_count == 1 + assert transport.__enter__.call_count == 1 + assert transport.__exit__.call_count == 1 diff --git a/sdk/keyvault/azure-keyvault-secrets/tests/test_context_manager_async.py b/sdk/keyvault/azure-keyvault-secrets/tests/test_context_manager_async.py new file mode 100644 index 0000000000000..ed01bc121b261 --- /dev/null +++ b/sdk/keyvault/azure-keyvault-secrets/tests/test_context_manager_async.py @@ -0,0 +1,29 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +from azure.keyvault.secrets.aio import SecretClient +import pytest + +from _shared.helpers_async import AsyncMockTransport + + +@pytest.mark.asyncio +async def test_close(): + transport = AsyncMockTransport() + client = SecretClient(vault_url="https://localhost", credential=object(), transport=transport) + + await client.close() + assert transport.__aenter__.call_count == 0 + assert transport.__aexit__.call_count == 1 + + +@pytest.mark.asyncio +async def test_context_manager(): + transport = AsyncMockTransport() + client = SecretClient(vault_url="https://localhost", credential=object(), transport=transport) + + async with client: + assert transport.__aenter__.call_count == 1 + assert transport.__aenter__.call_count == 1 + assert transport.__aexit__.call_count == 1