Skip to content

Commit

Permalink
Feature/sas token final (#13140)
Browse files Browse the repository at this point in the history
Note: this PR replaces the previous [Feature/sas token
merge](#12877) because the original
PR branch was damaged beyond repair. All the comments on the earlier
PR are responded to there and addressed in the code for this one.

This PR is to enable `hail-az/https` Azure file references to contain
SAS tokens to enable bearer-auth style file access to Azure storage.
Basic summary of the changes:
- Updated the `AzureAsyncFS` URL parsing function to look for and separate
out a SAS-token-like query string. Note: the detection is deliberately
specific to SAS tokens, because generic detection of query string syntax
interferes with glob support and with '?' characters in file names
- Added `generate_sas_token` convenience function to `AzureAsyncFS`.
Adds new azure-mgmt-storage package requirement.
- Updated `AzureAsyncFS` to use `(account, container, credential)` tuple
as internal `BlobServiceClient` cache key
- Updated `AzureAsyncFSURL` and `AzureFileListEntry` to track the token
separately from the name, and extend the base classes to allow returning
url with or without a token
- Updated the `RouterFS.ls` function and associated listfiles function to
allow for trailing query strings during path traversal
- Updated the `AsyncFS.open_from` function to handle query-string URLs in
the zero-length (`length == 0`) case
- Change to existing behavior: `LocalAsyncFSURL.__str__` no longer
returns 'file:' prefix. Done to make `str()` output be appropriate for
input to `fs` functions across all subclasses
- Updated `InputResource` to not include the SAS token as part of the
destination file name
- Updated `inter_cloud/test_fs.py` to generically use
query-string-friendly file path building functions to respect the new
model, where it is no longer safe to extend URLs by just appending new
segments with `+ "/"` because there may be a query string, and added
`'sas/azure-https'` test case to the fixture. Running tests for the SAS
case requires some new test variables to allow the test code to generate
SAS tokens (`build.yaml/test_hail_python_fs`):
```
      # Required for SAS testing on Azure
      export HAIL_TEST_AZURE_RESGRP=haildev
      export HAIL_TEST_AZURE_SUBID=12ab51c6-da79-4a99-8dec-3d2decc97343
```

---------

Co-authored-by: Greg Smith <gregsmi@microsoft.com>
Co-authored-by: Daniel Goldstein <danielgold95@gmail.com>
  • Loading branch information
3 people authored Jun 8, 2023
1 parent 1a26393 commit d647b6a
Show file tree
Hide file tree
Showing 12 changed files with 317 additions and 185 deletions.
3 changes: 3 additions & 0 deletions build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1221,6 +1221,9 @@ steps:
export HAIL_TEST_AZURE_ACCOUNT=hailtest
export HAIL_TEST_AZURE_CONTAINER=hail-test-4nxei
# Required for SAS testing on Azure
export HAIL_TEST_AZURE_RESGRP=hail-dev
export HAIL_TEST_AZURE_SUBID=22cd45fe-f996-4c51-af67-ef329d977519
export AZURE_APPLICATION_CREDENTIALS=/test-azure-key/credentials.json
python3 -m pytest \
Expand Down
7 changes: 4 additions & 3 deletions hail/python/hailtop/aiocloud/aioaws/fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,6 @@ def name(self) -> str:
async def url(self) -> str:
return f's3://{self._bucket}/{self._key}'

def url_maybe_trailing_slash(self) -> str:
return f's3://{self._bucket}/{self._key}'

async def is_file(self) -> bool:
return self._item is not None

Expand Down Expand Up @@ -284,6 +281,10 @@ def bucket_parts(self) -> List[str]:
def path(self) -> str:
return self._path

@property
def query(self) -> Optional[str]:
return None

@property
def scheme(self) -> str:
return 's3'
Expand Down
98 changes: 75 additions & 23 deletions hail/python/hailtop/aiocloud/aioazure/fs.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from typing import Any, AsyncContextManager, AsyncIterator, Dict, List, Optional, Set, Tuple, Type
from typing import Any, AsyncContextManager, AsyncIterator, Dict, List, Optional, Set, Tuple, Type, Union
from types import TracebackType

import abc
import re
import asyncio
from functools import wraps
import secrets
import logging
import datetime
from datetime import datetime, timedelta

from azure.storage.blob import BlobProperties
from azure.mgmt.storage.aio import StorageManagementClient
from azure.storage.blob import BlobProperties, ResourceTypes, generate_account_sas
from azure.storage.blob.aio import BlobClient, ContainerClient, BlobServiceClient, StorageStreamDownloader
from azure.storage.blob.aio._list_blobs_helper import BlobPrefix
import azure.core.exceptions
Expand Down Expand Up @@ -228,9 +230,9 @@ def name(self) -> str:
return self._url.path

async def url(self) -> str:
return str(self._url)
return self._url.base

def url_maybe_trailing_slash(self) -> str:
async def url_full(self) -> str:
return str(self._url)

async def is_file(self) -> bool:
Expand All @@ -256,25 +258,26 @@ async def size(self) -> int:
assert isinstance(size, int)
return size

def time_created(self) -> datetime.datetime:
def time_created(self) -> datetime:
ct = self.blob_props.creation_time
assert isinstance(ct, datetime.datetime)
assert isinstance(ct, datetime)
return ct

def time_modified(self) -> datetime.datetime:
def time_modified(self) -> datetime:
lm = self.blob_props.last_modified
assert isinstance(lm, datetime.datetime)
assert isinstance(lm, datetime)
return lm

async def __getitem__(self, key: str) -> Any:
return self.blob_props.__dict__[key]


class AzureAsyncFSURL(AsyncFSURL):
def __init__(self, account: str, container: str, path: str):
def __init__(self, account: str, container: str, path: str, query: Optional[str]):
self._account = account
self._container = container
self._path = path
self._query = query

@property
def bucket_parts(self) -> List[str]:
Expand All @@ -292,16 +295,29 @@ def account(self) -> str:
def container(self) -> str:
return self._container

@property
def query(self) -> Optional[str]:
return self._query

@property
@abc.abstractmethod
def base(self) -> str:
pass

def with_path(self, path) -> 'AzureAsyncFSURL':
return self.__class__(self._account, self._container, path)
return self.__class__(self._account, self._container, path, self._query)

def __str__(self) -> str:
return self.base if not self._query else f'{self.base}?{self._query}'


class AzureAsyncFSHailAzURL(AzureAsyncFSURL):
@property
def scheme(self) -> str:
return 'hail-az'

def __str__(self) -> str:
@property
def base(self) -> str:
return f'hail-az://{self._account}/{self._container}/{self._path}'


Expand All @@ -310,7 +326,8 @@ class AzureAsyncFSHttpsURL(AzureAsyncFSURL):
def scheme(self) -> str:
return 'https'

def __str__(self) -> str:
@property
def base(self) -> str:
return f'https://{self._account}.blob.core.windows.net/{self._container}/{self._path}'


Expand All @@ -319,13 +336,13 @@ def __str__(self) -> str:
# that container going forward.
def handle_public_access_error(fun):
@wraps(fun)
async def wrapped(self, url, *args, **kwargs):
async def wrapped(self: 'AzureAsyncFS', url, *args, **kwargs):
try:
return await fun(self, url, *args, **kwargs)
except azure.core.exceptions.ClientAuthenticationError:
fs_url = self.parse_url(url)
anon_client = BlobServiceClient(f'https://{fs_url.account}.blob.core.windows.net', credential=None)
self._blob_service_clients[(fs_url.account, fs_url.container)] = anon_client
self._blob_service_clients[(fs_url.account, fs_url.container, fs_url.query)] = anon_client
return await fun(self, url, *args, **kwargs)
return wrapped

Expand All @@ -345,7 +362,7 @@ def __init__(self, *, credential_file: Optional[str] = None, credentials: Option
raise ValueError('credential and credential_file cannot both be defined')

self._credential = credentials.credential
self._blob_service_clients: Dict[Tuple[str, str], BlobServiceClient] = {}
self._blob_service_clients: Dict[Tuple[str, str, Union[AzureCredentials, str, None]], BlobServiceClient] = {}

@staticmethod
def valid_url(url: str) -> bool:
Expand All @@ -358,6 +375,26 @@ def valid_url(url: str) -> bool:
return suffix == 'blob.core.windows.net'
return url.startswith('hail-az://')

async def generate_sas_token(
    self,
    subscription_id: str,
    resource_group: str,
    account: str,
    permissions: str = "rw",
    valid_interval: timedelta = timedelta(hours=1)
) -> str:
    """Generate an account SAS token for an Azure storage account.

    The storage account's access keys are listed through the storage
    management API using this filesystem's credential, and the first key
    returned is used to sign the token.

    Args:
        subscription_id: Subscription passed to the management client.
        resource_group: Resource group passed to ``list_keys``.
        account: Storage account name; used both to list the keys and as
            the account the token is generated for.
        permissions: Permission string handed to ``generate_account_sas``.
        valid_interval: How long after the current UTC time the token
            expires.

    Returns:
        The token string produced by ``generate_account_sas``. It is
        scoped to the container and object resource types.
    """
    # The aio management client owns an HTTP transport; use it as an async
    # context manager so the transport is closed even if list_keys raises.
    async with StorageManagementClient(self._credential, subscription_id) as mgmt_client:
        storage_keys = await mgmt_client.storage_accounts.list_keys(resource_group, account)
    storage_key = storage_keys.keys[0].value

    # NOTE(review): the expiry is a naive datetime holding UTC time; the
    # Azure SDK documents naive values as being interpreted as UTC -- confirm
    # if the SDK version changes.
    return generate_account_sas(
        account,
        storage_key,
        resource_types=ResourceTypes(container=True, object=True),
        permission=permissions,
        expiry=datetime.utcnow() + valid_interval)

def parse_url(self, url: str) -> AzureAsyncFSURL:
colon_index = url.find(':')
if colon_index == -1:
Expand Down Expand Up @@ -386,27 +423,42 @@ def parse_url(self, url: str) -> AzureAsyncFSURL:
assert name[0] == '/'
name = name[1:]

name, token = AzureAsyncFS.get_name_parts(name)

if scheme == 'hail-az':
account = authority
return AzureAsyncFSHailAzURL(account, container, name)
return AzureAsyncFSHailAzURL(account, container, name, token)

assert scheme == 'https'
assert len(authority) > len('.blob.core.windows.net')
account = authority[:-len('.blob.core.windows.net')]
return AzureAsyncFSHttpsURL(account, container, name)
return AzureAsyncFSHttpsURL(account, container, name, token)

def get_blob_service_client(self, account: str, container: str) -> BlobServiceClient:
k = account, container
@staticmethod
def get_name_parts(name: str) -> Tuple[str, str]:
    """Split a trailing SAS-token-like query string off a blob name.

    Everything after the last '?' is treated as a token only when its
    first '&'-separated segment has the form 'k=v': exactly one '=' with a
    non-empty key and a non-empty value. Otherwise the '?' is considered
    part of the name.

    Returns:
        ``(name_without_token, token)`` when a token is recognised,
        otherwise ``(name, '')``.
    """
    path, question_mark, candidate = name.rpartition('?')
    if not question_mark:
        # No '?' anywhere in the name, so there is nothing to split off.
        return (name, '')

    leading_pair = candidate.partition('&')[0].split('=')
    if len(leading_pair) != 2 or '' in leading_pair:
        # Not of the form 'k=v'; keep the '?' as part of the name.
        return (name, '')

    return (path, candidate)

def get_blob_service_client(self, account: str, container: str, token: Optional[str]) -> BlobServiceClient:
credential = token if token else self._credential
k = account, container, token
if k not in self._blob_service_clients:
self._blob_service_clients[k] = BlobServiceClient(f'https://{account}.blob.core.windows.net', credential=self._credential)
self._blob_service_clients[k] = BlobServiceClient(f'https://{account}.blob.core.windows.net', credential=credential)
return self._blob_service_clients[k]

def get_blob_client(self, url: AzureAsyncFSURL) -> BlobClient:
blob_service_client = self.get_blob_service_client(url.account, url.container)
blob_service_client = self.get_blob_service_client(url.account, url.container, url.query)
return blob_service_client.get_blob_client(url.container, url.path)

def get_container_client(self, url: AzureAsyncFSURL) -> ContainerClient:
return self.get_blob_service_client(url.account, url.container).get_container_client(url.container)
return self.get_blob_service_client(url.account, url.container, url.query).get_container_client(url.container)

@handle_public_access_error
async def open(self, url: str) -> ReadableStream:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -450,9 +450,6 @@ def name(self) -> str:
async def url(self) -> str:
return f'gs://{self._bucket}/{self._name}'

def url_maybe_trailing_slash(self) -> str:
return f'gs://{self._bucket}/{self._name}'

async def is_file(self) -> bool:
return self._items is not None

Expand Down Expand Up @@ -570,6 +567,10 @@ def bucket_parts(self) -> List[str]:
def path(self) -> str:
return self._path

@property
def query(self) -> Optional[str]:
return None

@property
def scheme(self) -> str:
return 'gs'
Expand Down
2 changes: 1 addition & 1 deletion hail/python/hailtop/aiotools/fs/copier.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ async def files_iterator() -> AsyncIterator[FileListEntry]:
raise NotADirectoryError(full_dest)

async def copy_source(srcentry: FileListEntry) -> None:
srcfile = srcentry.url_maybe_trailing_slash()
srcfile = await srcentry.url_maybe_trailing_slash()
assert srcfile.startswith(src)

# skip files with empty names
Expand Down
30 changes: 20 additions & 10 deletions hail/python/hailtop/aiotools/fs/fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,11 @@ def name(self) -> str:
async def url(self) -> str:
pass

@abc.abstractmethod
def url_maybe_trailing_slash(self) -> str:
pass
async def url_maybe_trailing_slash(self) -> str:
return await self.url()

async def url_full(self) -> str:
return await self.url()

@abc.abstractmethod
async def is_file(self) -> bool:
Expand Down Expand Up @@ -91,6 +93,11 @@ def bucket_parts(self) -> List[str]:
def path(self) -> str:
pass

@property
@abc.abstractmethod
def query(self) -> Optional[str]:
pass

@property
@abc.abstractmethod
def scheme(self) -> str:
Expand All @@ -101,7 +108,9 @@ def with_path(self, path) -> 'AsyncFSURL':
pass

def with_new_path_component(self, new_path_component) -> 'AsyncFSURL':
return self.with_path(self.path + '/' + new_path_component)
prefix = self.path if self.path.endswith('/') else self.path + '/'
suffix = new_path_component[1:] if new_path_component.startswith('/') else new_path_component
return self.with_path(prefix + suffix)

@abc.abstractmethod
def __str__(self) -> str:
Expand Down Expand Up @@ -132,12 +141,13 @@ async def open(self, url: str) -> ReadableStream:

async def open_from(self, url: str, start: int, *, length: Optional[int] = None) -> ReadableStream:
if length == 0:
if url.endswith('/'):
file_url = url.rstrip('/')
dir_url = url
fs_url = self.parse_url(url)
if fs_url.path.endswith('/'):
file_url = str(fs_url.with_path(fs_url.path.rstrip('/')))
dir_url = str(fs_url)
else:
file_url = url
dir_url = url + '/'
file_url = str(fs_url)
dir_url = str(fs_url.with_path(fs_url.path + '/'))
isfile, isdir = await asyncio.gather(self.isfile(file_url), self.isdir(dir_url))
if isfile:
if isdir:
Expand Down Expand Up @@ -244,7 +254,7 @@ async def rmtree(self,
async def rm(entry: FileListEntry):
assert listener is not None
listener(1)
await self._remove_doesnt_exist_ok(await entry.url())
await self._remove_doesnt_exist_ok(await entry.url_full())
listener(-1)

try:
Expand Down
8 changes: 6 additions & 2 deletions hail/python/hailtop/aiotools/local_fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ async def url(self) -> str:
trailing_slash = "/" if await self.is_dir() else ""
return f'{self._base_url}{self._entry.name}{trailing_slash}'

def url_maybe_trailing_slash(self) -> str:
async def url_maybe_trailing_slash(self) -> str:
return f'{self._base_url}{self._entry.name}'

async def is_file(self) -> bool:
Expand Down Expand Up @@ -107,6 +107,10 @@ def bucket_parts(self) -> List[str]:
def path(self) -> str:
return self._path

@property
def query(self) -> Optional[str]:
return None

@property
def scheme(self) -> str:
return 'file'
Expand All @@ -115,7 +119,7 @@ def with_path(self, path) -> 'LocalAsyncFSURL':
return LocalAsyncFSURL(path)

def __str__(self) -> str:
return 'file:' + self._path
return self._path


class TruncatedReadableBinaryIO(BinaryIO):
Expand Down
Loading

0 comments on commit d647b6a

Please sign in to comment.