Skip to content

Commit

Permalink
Merge pull request Azure#11 from kristapratico/pipeline_ownership
Browse files Browse the repository at this point in the history
pipeline ownership for queues and files
  • Loading branch information
kristapratico authored Oct 17, 2019
2 parents 952c125 + 6c73352 commit e09e3ed
Show file tree
Hide file tree
Showing 16 changed files with 228 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from azure.core import Configuration
from azure.core.exceptions import HttpResponseError
from azure.core.pipeline import Pipeline
from azure.core.pipeline.transport import RequestsTransport
from azure.core.pipeline.transport import RequestsTransport, HttpTransport
from azure.core.pipeline.policies.distributed_tracing import DistributedTracingPolicy
from azure.core.pipeline.policies import RedirectPolicy, ContentDecodePolicy, BearerTokenCredentialPolicy, ProxyPolicy

Expand Down Expand Up @@ -216,6 +216,27 @@ def _batch_send(
process_storage_error(error)


class TransportWrapper(HttpTransport):

def __init__(self, transport):
self._transport = transport

def send(self, request, **kwargs):
return self._transport.send(request, **kwargs)

def open(self):
pass

def close(self):
pass

def __enter__(self, *args): # pylint: disable=arguments-differ
pass

def __exit__(self, *args): # pylint: disable=arguments-differ
pass


def format_shared_key_credential(account, credential):
if isinstance(credential, six.string_types):
if len(account) < 2:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
AsyncBearerTokenCredentialPolicy,
AsyncRedirectPolicy)

from azure.core.pipeline.transport import AsyncHttpTransport
from .constants import STORAGE_OAUTH_SCOPE, DEFAULT_SOCKET_TIMEOUT
from .authentication import SharedKeyCredentialPolicy
from .base_client import create_configuration
Expand Down Expand Up @@ -122,3 +123,24 @@ async def _batch_send(
return response.parts() # Return an AsyncIterator
except StorageErrorException as error:
process_storage_error(error)


class AsyncTransportWrapper(AsyncHttpTransport):

def __init__(self, async_transport):
self._transport = async_transport

async def send(self, request, **kwargs):
return await self._transport.send(request, **kwargs)

async def open(self):
pass

async def close(self):
pass

async def __aenter__(self, *args): # pylint: disable=arguments-differ
pass

async def __aexit__(self, *args): # pylint: disable=arguments-differ
pass
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from azure.core.polling import async_poller
from azure.core.async_paging import AsyncItemPaged

from azure.core.pipeline import AsyncPipeline
from azure.core.tracing.decorator import distributed_trace
from azure.core.tracing.decorator_async import distributed_trace_async
from .._parser import _get_file_permission, _datetime_to_str
Expand All @@ -20,7 +20,7 @@
from .._generated.aio import AzureFileStorage
from .._generated.version import VERSION
from .._generated.models import StorageErrorException
from .._shared.base_client_async import AsyncStorageAccountHostsMixin
from .._shared.base_client_async import AsyncStorageAccountHostsMixin, AsyncTransportWrapper
from .._shared.policies_async import ExponentialRetry
from .._shared.request_handlers import add_metadata_headers
from .._shared.response_handlers import return_response_headers, process_storage_error
Expand Down Expand Up @@ -112,10 +112,15 @@ def get_file_client(self, file_name, **kwargs):
"""
if self.directory_path:
file_name = self.directory_path.rstrip('/') + "/" + file_name

_pipeline = AsyncPipeline(
transport=AsyncTransportWrapper(self._pipeline._transport), # pylint: disable = protected-access
policies=self._pipeline._impl_policies # pylint: disable = protected-access
)
return FileClient(
self.url, file_path=file_name, share_name=self.share_name, snapshot=self.snapshot,
credential=self.credential, _hosts=self._hosts, _configuration=self._config,
_pipeline=self._pipeline, _location_mode=self._location_mode, loop=self._loop, **kwargs)
_pipeline=_pipeline, _location_mode=self._location_mode, loop=self._loop, **kwargs)

def get_subdirectory_client(self, directory_name, **kwargs):
# type: (str, Any) -> DirectoryClient
Expand All @@ -138,10 +143,15 @@ def get_subdirectory_client(self, directory_name, **kwargs):
:caption: Gets the subdirectory client.
"""
directory_path = self.directory_path.rstrip('/') + "/" + directory_name

_pipeline = AsyncPipeline(
transport=AsyncTransportWrapper(self._pipeline._transport), # pylint: disable = protected-access
policies=self._pipeline._impl_policies # pylint: disable = protected-access
)
return DirectoryClient(
self.url, share_name=self.share_name, directory_path=directory_path, snapshot=self.snapshot,
credential=self.credential, _hosts=self._hosts, _configuration=self._config,
_pipeline=self._pipeline, _location_mode=self._location_mode, loop=self._loop, **kwargs)
_pipeline=_pipeline, _location_mode=self._location_mode, loop=self._loop, **kwargs)

@distributed_trace_async
async def create_directory(self, **kwargs): # type: ignore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@

from azure.core.async_paging import AsyncItemPaged
from azure.core.tracing.decorator import distributed_trace
from azure.core.pipeline import AsyncPipeline
from azure.core.tracing.decorator_async import distributed_trace_async

from .._shared.base_client_async import AsyncStorageAccountHostsMixin
from .._shared.base_client_async import AsyncStorageAccountHostsMixin, AsyncTransportWrapper
from .._shared.response_handlers import process_storage_error
from .._shared.policies_async import ExponentialRetry
from .._generated.aio import AzureFileStorage
Expand Down Expand Up @@ -314,6 +315,11 @@ def get_share_client(self, share, snapshot=None):
share_name = share.name
except AttributeError:
share_name = share

_pipeline = AsyncPipeline(
transport=AsyncTransportWrapper(self._pipeline._transport), # pylint: disable = protected-access
policies=self._pipeline._impl_policies # pylint: disable = protected-access
)
return ShareClient(
self.url, share_name=share_name, snapshot=snapshot, credential=self.credential, _hosts=self._hosts,
_configuration=self._config, _pipeline=self._pipeline, _location_mode=self._location_mode, loop=self._loop)
_configuration=self._config, _pipeline=_pipeline, _location_mode=self._location_mode, loop=self._loop)
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@

from azure.core.tracing.decorator import distributed_trace
from azure.core.tracing.decorator_async import distributed_trace_async

from azure.core.pipeline import AsyncPipeline
from .._shared.policies_async import ExponentialRetry
from .._shared.base_client_async import AsyncStorageAccountHostsMixin
from .._shared.base_client_async import AsyncStorageAccountHostsMixin, AsyncTransportWrapper
from .._shared.request_handlers import add_metadata_headers, serialize_iso
from .._shared.response_handlers import (
return_response_headers,
Expand Down Expand Up @@ -102,9 +102,14 @@ def get_directory_client(self, directory_path=None):
:returns: A Directory Client.
:rtype: ~azure.storage.file.aio.directory_client_async.DirectoryClient
"""
_pipeline = AsyncPipeline(
transport=AsyncTransportWrapper(self._pipeline._transport), # pylint: disable = protected-access
policies=self._pipeline._impl_policies # pylint: disable = protected-access
)

return DirectoryClient(
self.url, share_name=self.share_name, directory_path=directory_path or "", snapshot=self.snapshot,
credential=self.credential, _hosts=self._hosts, _configuration=self._config, _pipeline=self._pipeline,
credential=self.credential, _hosts=self._hosts, _configuration=self._config, _pipeline=_pipeline,
_location_mode=self._location_mode, loop=self._loop)

def get_file_client(self, file_path):
Expand All @@ -117,10 +122,15 @@ def get_file_client(self, file_path):
:returns: A File Client.
:rtype: ~azure.storage.file.aio.file_client_async.FileClient
"""
_pipeline = AsyncPipeline(
transport=AsyncTransportWrapper(self._pipeline._transport), # pylint: disable = protected-access
policies=self._pipeline._impl_policies # pylint: disable = protected-access
)

return FileClient(
self.url, share_name=self.share_name, file_path=file_path, snapshot=self.snapshot,
credential=self.credential, _hosts=self._hosts, _configuration=self._config,
_pipeline=self._pipeline, _location_mode=self._location_mode, loop=self._loop)
_pipeline=_pipeline, _location_mode=self._location_mode, loop=self._loop)

@distributed_trace_async
async def create_share(self, **kwargs): # type: ignore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@
import six
from azure.core.polling import LROPoller
from azure.core.paging import ItemPaged
from azure.core.pipeline import Pipeline
from azure.core.tracing.decorator import distributed_trace

from ._generated import AzureFileStorage
from ._generated.version import VERSION
from ._generated.models import StorageErrorException
from ._shared.base_client import StorageAccountHostsMixin, parse_connection_str, parse_query
from ._shared.base_client import StorageAccountHostsMixin, TransportWrapper, parse_connection_str, parse_query
from ._shared.request_handlers import add_metadata_headers
from ._shared.response_handlers import return_response_headers, process_storage_error
from ._shared.parser import _str
Expand Down Expand Up @@ -217,10 +218,15 @@ def get_file_client(self, file_name, **kwargs):
"""
if self.directory_path:
file_name = self.directory_path.rstrip('/') + "/" + file_name

_pipeline = Pipeline(
transport=TransportWrapper(self._pipeline._transport), # pylint: disable = protected-access
policies=self._pipeline._impl_policies # pylint: disable = protected-access
)
return FileClient(
self.url, file_path=file_name, share_name=self.share_name, napshot=self.snapshot,
credential=self.credential, _hosts=self._hosts, _configuration=self._config,
_pipeline=self._pipeline, _location_mode=self._location_mode, **kwargs)
_pipeline=_pipeline, _location_mode=self._location_mode, **kwargs)

def get_subdirectory_client(self, directory_name, **kwargs):
# type: (str, Any) -> DirectoryClient
Expand All @@ -243,9 +249,14 @@ def get_subdirectory_client(self, directory_name, **kwargs):
:caption: Gets the subdirectory client.
"""
directory_path = self.directory_path.rstrip('/') + "/" + directory_name

_pipeline = Pipeline(
transport=TransportWrapper(self._pipeline._transport), # pylint: disable = protected-access
policies=self._pipeline._impl_policies # pylint: disable = protected-access
)
return DirectoryClient(
self.url, share_name=self.share_name, directory_path=directory_path, snapshot=self.snapshot,
credential=self.credential, _hosts=self._hosts, _configuration=self._config, _pipeline=self._pipeline,
credential=self.credential, _hosts=self._hosts, _configuration=self._config, _pipeline=_pipeline,
_location_mode=self._location_mode, **kwargs)

@distributed_trace
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@

from azure.core.paging import ItemPaged
from azure.core.tracing.decorator import distributed_trace

from azure.core.pipeline import Pipeline
from ._shared.shared_access_signature import SharedAccessSignature
from ._shared.models import Services
from ._shared.base_client import StorageAccountHostsMixin, parse_connection_str, parse_query
from ._shared.base_client import StorageAccountHostsMixin, TransportWrapper, parse_connection_str, parse_query
from ._shared.response_handlers import process_storage_error
from ._generated import AzureFileStorage
from ._generated.models import StorageErrorException, StorageServiceProperties
Expand Down Expand Up @@ -431,6 +431,11 @@ def get_share_client(self, share, snapshot=None):
share_name = share.name
except AttributeError:
share_name = share

_pipeline = Pipeline(
transport=TransportWrapper(self._pipeline._transport), # pylint: disable = protected-access
policies=self._pipeline._impl_policies # pylint: disable = protected-access
)
return ShareClient(
self.url, share_name=share_name, snapshot=snapshot, credential=self.credential, _hosts=self._hosts,
_configuration=self._config, _pipeline=self._pipeline, _location_mode=self._location_mode)
_configuration=self._config, _pipeline=_pipeline, _location_mode=self._location_mode)
17 changes: 14 additions & 3 deletions sdk/storage/azure-storage-file/azure/storage/file/share_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@

import six
from azure.core.tracing.decorator import distributed_trace
from ._shared.base_client import StorageAccountHostsMixin, parse_connection_str, parse_query
from azure.core.pipeline import Pipeline
from ._shared.base_client import StorageAccountHostsMixin, TransportWrapper, parse_connection_str, parse_query
from ._shared.request_handlers import add_metadata_headers, serialize_iso
from ._shared.response_handlers import (
return_response_headers,
Expand Down Expand Up @@ -297,9 +298,14 @@ def get_directory_client(self, directory_path=None):
:returns: A Directory Client.
:rtype: ~azure.storage.file.DirectoryClient
"""
_pipeline = Pipeline(
transport=TransportWrapper(self._pipeline._transport), # pylint: disable = protected-access
policies=self._pipeline._impl_policies # pylint: disable = protected-access
)

return DirectoryClient(
self.url, share_name=self.share_name, directory_path=directory_path or "", snapshot=self.snapshot,
credential=self.credential, _hosts=self._hosts, _configuration=self._config, _pipeline=self._pipeline,
credential=self.credential, _hosts=self._hosts, _configuration=self._config, _pipeline=_pipeline,
_location_mode=self._location_mode)

def get_file_client(self, file_path):
Expand All @@ -312,10 +318,15 @@ def get_file_client(self, file_path):
:returns: A File Client.
:rtype: ~azure.storage.file.FileClient
"""
_pipeline = Pipeline(
transport=TransportWrapper(self._pipeline._transport), # pylint: disable = protected-access
policies=self._pipeline._impl_policies # pylint: disable = protected-access
)

return FileClient(
self.url, share_name=self.share_name, file_path=file_path, snapshot=self.snapshot,
credential=self.credential, _hosts=self._hosts, _configuration=self._config,
_pipeline=self._pipeline, _location_mode=self._location_mode)
_pipeline=_pipeline, _location_mode=self._location_mode)

@distributed_trace
def create_share(self, **kwargs): # type: ignore
Expand Down
13 changes: 13 additions & 0 deletions sdk/storage/azure-storage-file/tests/test_share.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import pytest
import requests
from azure.core.pipeline.transport import RequestsTransport
from azure.core.exceptions import (
HttpResponseError,
ResourceNotFoundError,
Expand Down Expand Up @@ -762,6 +763,18 @@ def test_create_permission_for_share(self):
# server returned permission
self.assertEquals(permission_key, permission_key2)

@record
def test_transport_closed_only_once(self):
transport = RequestsTransport()
url = self.get_file_url()
credential = self.get_shared_key_credential()
share = self._get_share_reference()
with FileServiceClient(url, credential=credential, transport=transport) as fsc:
assert transport.session is not None
with fsc.get_share_client(share.share_name) as fc:
assert transport.session is not None
assert transport.session is not None # Right now it's None

# ------------------------------------------------------------------------------
if __name__ == '__main__':
unittest.main()
12 changes: 12 additions & 0 deletions sdk/storage/azure-storage-file/tests/test_share_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import pytest
import requests
from azure.core.pipeline.transport import AioHttpTransport
from azure.core.pipeline.transport import AsyncioRequestsTransport
from multidict import CIMultiDict, CIMultiDictProxy
from azure.core.exceptions import (
HttpResponseError,
Expand Down Expand Up @@ -918,6 +919,17 @@ def test_create_permission_for_share_async(self):
loop = asyncio.get_event_loop()
loop.run_until_complete(self._test_create_permission_for_share())

async def test_transport_closed_only_once_async(self):
transport = AsyncioRequestsTransport()
url = self.get_file_url()
credential = self.get_shared_key_credential()
share = self._get_share_reference()
async with FileServiceClient(url, credential=credential, transport=transport) as fsc:
assert transport.session is not None
async with fsc.get_share_client(share.share_name) as fc:
assert transport.session is not None
assert transport.session is not None # Right now it's None

# ------------------------------------------------------------------------------
if __name__ == '__main__':
unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from azure.core import Configuration
from azure.core.exceptions import HttpResponseError
from azure.core.pipeline import Pipeline
from azure.core.pipeline.transport import RequestsTransport
from azure.core.pipeline.transport import RequestsTransport, HttpTransport
from azure.core.pipeline.policies.distributed_tracing import DistributedTracingPolicy
from azure.core.pipeline.policies import RedirectPolicy, ContentDecodePolicy, BearerTokenCredentialPolicy, ProxyPolicy

Expand Down Expand Up @@ -216,6 +216,27 @@ def _batch_send(
process_storage_error(error)


class TransportWrapper(HttpTransport):

def __init__(self, transport):
self._transport = transport

def send(self, request, **kwargs):
return self._transport.send(request, **kwargs)

def open(self):
pass

def close(self):
pass

def __enter__(self, *args): # pylint: disable=arguments-differ
pass

def __exit__(self, *args): # pylint: disable=arguments-differ
pass


def format_shared_key_credential(account, credential):
if isinstance(credential, six.string_types):
if len(account) < 2:
Expand Down
Loading

0 comments on commit e09e3ed

Please sign in to comment.