Skip to content

Commit

Permalink
copy shared code
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell committed Feb 26, 2020
1 parent e49004b commit c236792
Show file tree
Hide file tree
Showing 8 changed files with 168 additions and 158 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from ._shared import KeyVaultClientBase
from ._shared.exceptions import error_map as _error_map
from ._shared._polling import DeletePollingMethod, RecoverDeletedPollingMethod, KeyVaultOperationPoller
from ._shared._polling import DeleteRecoverPollingMethod, KeyVaultOperationPoller
from ._models import (
KeyVaultCertificate,
CertificateProperties,
Expand Down Expand Up @@ -222,17 +222,16 @@ def begin_delete_certificate(self, certificate_name, **kwargs):
vault_base_url=self.vault_url, certificate_name=certificate_name, error_map=_error_map, **kwargs
)
deleted_cert = DeletedCertificate._from_deleted_certificate_bundle(deleted_cert_bundle)
sd_disabled = deleted_cert.recovery_id is None
command = partial(self.get_deleted_certificate, certificate_name=certificate_name, **kwargs)
delete_cert_polling_method = DeletePollingMethod(
command=command,

polling_method = DeleteRecoverPollingMethod(
# no recovery ID means soft-delete is disabled, in which case we initialize the poller as finished
finished=deleted_cert.recovery_id is None,
command=partial(self.get_deleted_certificate, certificate_name=certificate_name, **kwargs),
final_resource=deleted_cert,
initial_status="deleting",
finished_status="deleted",
sd_disabled=sd_disabled,
interval=polling_interval,
)
return KeyVaultOperationPoller(delete_cert_polling_method)

return KeyVaultOperationPoller(polling_method)

@distributed_trace
def get_deleted_certificate(self, certificate_name, **kwargs):
Expand Down Expand Up @@ -315,19 +314,18 @@ def begin_recover_deleted_certificate(self, certificate_name, **kwargs):
polling_interval = kwargs.pop("_polling_interval", None)
if polling_interval is None:
polling_interval = 2

recovered_cert_bundle = self._client.recover_deleted_certificate(
vault_base_url=self.vault_url, certificate_name=certificate_name, error_map=_error_map, **kwargs
)
recovered_certificate = KeyVaultCertificate._from_certificate_bundle(recovered_cert_bundle)
command = partial(self.get_certificate, certificate_name=certificate_name, **kwargs)
recover_cert_polling_method = RecoverDeletedPollingMethod(
command=command,
final_resource=recovered_certificate,
initial_status="recovering",
finished_status="recovered",
interval=polling_interval,
polling_method = DeleteRecoverPollingMethod(
finished=False, command=command, final_resource=recovered_certificate, interval=polling_interval
)
return KeyVaultOperationPoller(recover_cert_polling_method)

return KeyVaultOperationPoller(polling_method)


@distributed_trace
def import_certificate(self, certificate_name, certificate_bytes, **kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from azure.core.exceptions import ResourceNotFoundError, HttpResponseError

try:
from urlparse import urlparse # type: ignore # pylint: disable=unused-import
from urlparse import urlparse # type: ignore # pylint: disable=unused-import
except ImportError:
from urllib.parse import urlparse

Expand Down Expand Up @@ -61,8 +61,8 @@ def wait(self, timeout=None):
if not self._polling_method.finished():
self._done = threading.Event()
self._thread = threading.Thread(
target=with_current_context(self._start),
name="KeyVaultOperationPoller({})".format(uuid.uuid4()))
target=with_current_context(self._start), name="KeyVaultOperationPoller({})".format(uuid.uuid4())
)
self._thread.daemon = True
self._thread.start()

Expand All @@ -72,29 +72,41 @@ def wait(self, timeout=None):
try:
# Let's handle possible None in forgiveness here
raise self._exception # type: ignore
except TypeError: # Was None
except TypeError: # Was None
pass

class RecoverDeletedPollingMethod(PollingMethod):
def __init__(self, command, final_resource, initial_status, finished_status, interval=2):

class DeleteRecoverPollingMethod(PollingMethod):
"""Poller for deleting resources, and recovering deleted resources, in vaults with soft-delete enabled.
This works by polling for the existence of the deleted or recovered resource. When a resource is deleted, Key Vault
immediately removes it from its collection. However, the resource will not immediately appear in the deleted
collection. Key Vault will therefore respond 404 to GET requests for the deleted resource; when it responds 2xx
or 403, the resource exists in the deleted collection, i.e. its deletion is complete.
Similarly, while recovering a deleted resource, Key Vault will respond 404 to GET requests for the non-deleted
resource; when it responds 2xx or 403, the resource exists in the non-deleted collection, i.e. its rec+
(403 indicates completion of these operations because Key Vault responds 403 when a resource exists but the client
lacks permission to access it.)
"""
def __init__(self, command, final_resource, finished, interval=2):
self._command = command
self._resource = final_resource
self._polling_interval = interval
self._status = initial_status
self._finished_status = finished_status
self._finished = finished

def _update_status(self):
# type: () -> None
try:
self._command()
self._status = self._finished_status
self._finished = True
except ResourceNotFoundError:
pass
except HttpResponseError as e:
# If we are polling on get_deleted_* and we don't have get permissions, we will get
# ResourceNotFoundError until the resource is recovered, at which point we'll get a 403.
if e.status_code == 403:
self._status = self._finished_status
self._finished = True
else:
raise

Expand All @@ -106,35 +118,20 @@ def run(self):
try:
while not self.finished():
self._update_status()
time.sleep(self._polling_interval)
if not self.finished():
time.sleep(self._polling_interval)
except Exception as e:
logger.warning(str(e))
raise

def finished(self):
# type: () -> bool
return self._status == self._finished_status
return self._finished

def resource(self):
# type: () -> Any
return self._resource

def status(self):
# type: () -> str
return self._status


class DeletePollingMethod(RecoverDeletedPollingMethod):
def __init__(self, command, final_resource, initial_status, finished_status, sd_disabled, interval=2):
self._sd_disabled = sd_disabled
super(DeletePollingMethod, self).__init__(
command=command,
final_resource=final_resource,
initial_status=initial_status,
finished_status=finished_status,
interval=interval
)

def finished(self):
# type: () -> bool
return self._sd_disabled or self._status == self._finished_status
return "finished" if self._finished else "polling"
Original file line number Diff line number Diff line change
Expand Up @@ -9,61 +9,66 @@

from azure.core.polling import AsyncPollingMethod
from azure.core.exceptions import ResourceNotFoundError, HttpResponseError

if TYPE_CHECKING:
# pylint:disable=ungrouped-imports
from typing import Any, Callable, Union

logger = logging.getLogger(__name__)

class RecoverDeletedAsyncPollingMethod(AsyncPollingMethod):
def __init__(self, initial_status, finished_status, interval=2):
self._command = None
self._resource = None

class AsyncDeleteRecoverPollingMethod(AsyncPollingMethod):
"""Poller for deleting resources, and recovering deleted resources, in vaults with soft-delete enabled.
This works by polling for the existence of the deleted or recovered resource. When a resource is deleted, Key Vault
immediately removes it from its collection. However, the resource will not immediately appear in the deleted
collection. Key Vault will therefore respond 404 to GET requests for the deleted resource; when it responds 2xx
or 403, the resource exists in the deleted collection, i.e. its deletion is complete.
Similarly, while recovering a deleted resource, Key Vault will respond 404 to GET requests for the non-deleted
resource; when it responds 2xx or 403, the resource exists in the non-deleted collection, i.e. its rec+
(403 indicates completion of these operations because Key Vault responds 403 when a resource exists but the client
lacks permission to access it.)
"""

def __init__(self, command, final_resource, finished, interval=2):
self._command = command
self._resource = final_resource
self._polling_interval = interval
self._status = initial_status
self._finished_status = finished_status
self._finished = finished

def initialize(self, client, initial_response, deserialization_callback):
pass

async def _update_status(self) -> None:
try:
await self._command()
self._status = self._finished_status
self._finished = True
except ResourceNotFoundError:
pass
except HttpResponseError as e:
# If we are polling on get_deleted_* and we don't have get permissions, we will get
# ResourceNotFoundError until the resource is recovered, at which point we'll get a 403.
if e.status_code == 403:
self._status = self._finished_status
self._finished = True
else:
raise

def initialize(self, client: "Any", initial_response: str, _: "Callable") -> None:
self._command = client
self._resource = initial_response

async def run(self) -> None:
try:
while not self.finished():
await self._update_status()
await asyncio.sleep(self._polling_interval)
if not self.finished():
await asyncio.sleep(self._polling_interval)
except Exception as e:
logger.warning(str(e))
raise

def finished(self) -> bool:
return self._status == self._finished_status
return self._finished

def resource(self) -> "Any":
return self._resource

def status(self) -> str:
return self._status


class DeleteAsyncPollingMethod(RecoverDeletedAsyncPollingMethod):
def __init__(self, initial_status, finished_status, sd_disabled, interval=2):
self._sd_disabled = sd_disabled
super(DeleteAsyncPollingMethod, self).__init__(initial_status, finished_status, interval)

def finished(self) -> bool:
return self._sd_disabled or self._status == self._finished_status
return "finished" if self._finished else "polling"
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
)
from ._polling_async import CreateCertificatePollerAsync
from .._shared import AsyncKeyVaultClientBase
from .._shared._polling_async import DeleteAsyncPollingMethod, RecoverDeletedAsyncPollingMethod
from .._shared._polling_async import AsyncDeleteRecoverPollingMethod
from .._shared.exceptions import error_map as _error_map


Expand Down Expand Up @@ -204,13 +204,17 @@ async def delete_certificate(self, certificate_name: str, **kwargs: "Any") -> De
vault_base_url=self.vault_url, certificate_name=certificate_name, error_map=_error_map, **kwargs
)
deleted_certificate = DeletedCertificate._from_deleted_certificate_bundle(deleted_cert_bundle)
sd_disabled = deleted_certificate.recovery_id is None
command = partial(self.get_deleted_certificate, certificate_name=certificate_name, **kwargs)

delete_certificate_poller = DeleteAsyncPollingMethod(
initial_status="deleting", finished_status="deleted", sd_disabled=sd_disabled, interval=polling_interval
polling_method = AsyncDeleteRecoverPollingMethod(
# no recovery ID means soft-delete is disabled, in which case we initialize the poller as finished
finished=deleted_certificate.recovery_id is None,
command=partial(self.get_deleted_certificate, certificate_name=certificate_name, **kwargs),
final_resource=deleted_certificate,
interval=polling_interval,
)
return await async_poller(command, deleted_certificate, None, delete_certificate_poller)
await polling_method.run()

return polling_method.resource()

@distributed_trace_async
async def get_deleted_certificate(self, certificate_name: str, **kwargs: "Any") -> DeletedCertificate:
Expand Down Expand Up @@ -289,12 +293,14 @@ async def recover_deleted_certificate(self, certificate_name: str, **kwargs: "An
vault_base_url=self.vault_url, certificate_name=certificate_name, error_map=_error_map, **kwargs
)
recovered_certificate = KeyVaultCertificate._from_certificate_bundle(recovered_cert_bundle)
command = partial(self.get_certificate, certificate_name=certificate_name, **kwargs)

recover_cert_poller = RecoverDeletedAsyncPollingMethod(
initial_status="recovering", finished_status="recovered", interval=polling_interval
command = partial(self.get_certificate, certificate_name=certificate_name, **kwargs)
polling_method = AsyncDeleteRecoverPollingMethod(
command=command, final_resource=recovered_certificate, finished=False, interval=polling_interval
)
return await async_poller(command, recovered_certificate, None, recover_cert_poller)
await polling_method.run()

return polling_method.resource()

@distributed_trace_async
async def import_certificate(
Expand Down
25 changes: 11 additions & 14 deletions sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from ._shared import KeyVaultClientBase
from ._shared.exceptions import error_map as _error_map
from ._shared._polling import DeletePollingMethod, RecoverDeletedPollingMethod, KeyVaultOperationPoller
from ._shared._polling import DeleteRecoverPollingMethod, KeyVaultOperationPoller
from ._models import KeyVaultKey, KeyProperties, DeletedKey

try:
Expand Down Expand Up @@ -198,17 +198,16 @@ def begin_delete_key(self, name, **kwargs):
deleted_key = DeletedKey._from_deleted_key_bundle(
self._client.delete_key(self.vault_url, name, error_map=_error_map, **kwargs)
)
sd_disabled = deleted_key.recovery_id is None

command = partial(self.get_deleted_key, name=name, **kwargs)
delete_key_polling_method = DeletePollingMethod(
polling_method = DeleteRecoverPollingMethod(
# no recovery ID means soft-delete is disabled, in which case we initialize the poller as finished
finished=deleted_key.recovery_id is None,
command=command,
final_resource=deleted_key,
initial_status="deleting",
finished_status="deleted",
sd_disabled=sd_disabled,
interval=polling_interval,
)
return KeyVaultOperationPoller(delete_key_polling_method)
return KeyVaultOperationPoller(polling_method)

@distributed_trace
def get_key(self, name, version=None, **kwargs):
Expand Down Expand Up @@ -397,14 +396,12 @@ def begin_recover_deleted_key(self, name, **kwargs):
)
)
command = partial(self.get_key, name=name, **kwargs)
recover_key_polling_method = RecoverDeletedPollingMethod(
command=command,
final_resource=recovered_key,
initial_status="recovering",
finished_status="recovered",
interval=polling_interval,
polling_method = DeleteRecoverPollingMethod(
finished=False, command=command, final_resource=recovered_key, interval=polling_interval,
)
return KeyVaultOperationPoller(recover_key_polling_method)

return KeyVaultOperationPoller(polling_method)


@distributed_trace
def update_key_properties(self, name, version=None, **kwargs):
Expand Down
Loading

0 comments on commit c236792

Please sign in to comment.