diff --git a/scripts/download_tdr_parquet.py b/scripts/download_tdr_parquet.py
new file mode 100644
index 000000000..41f07a147
--- /dev/null
+++ b/scripts/download_tdr_parquet.py
@@ -0,0 +1,103 @@
+"""
+Export Parquet files from TDR and download them to local storage.
+"""
+from argparse import (
+    ArgumentParser,
+)
+import logging
+from pathlib import (
+    Path,
+)
+import sys
+from typing import (
+    Iterator,
+)
+from uuid import (
+    UUID,
+)
+
+import attrs
+from furl import (
+    furl,
+)
+
+from azul import (
+    cached_property,
+    config,
+    reject,
+)
+from azul.http import (
+    HasCachedHttpClient,
+)
+from azul.logging import (
+    configure_script_logging,
+)
+from azul.terra import (
+    TDRClient,
+    TerraStatusException,
+)
+
+log = logging.getLogger(__name__)
+
+
+@attrs.frozen
+class ParquetDownloader(HasCachedHttpClient):
+    snapshot_id: str
+
+    @cached_property
+    def tdr(self) -> TDRClient:
+        return TDRClient.for_indexer()
+
+    def get_download_urls(self) -> dict[str, list[furl]]:
+        urls = self.tdr.export_parquet_urls(self.snapshot_id)
+        reject(urls is None,
+               'No Parquet access information is available for snapshot %r', self.snapshot_id)
+        return urls
+
+    def get_data(self, parquet_urls: list[furl]) -> Iterator[bytes]:
+        for url in parquet_urls:
+            response = self._http_client.request('GET', str(url))
+            if response.status != 200:
+                raise TerraStatusException(url, response)
+            if response.headers.get('x-ms-resource-type') == 'directory':
+                log.info('Skipping Azure directory URL')
+            else:
+                yield response.data
+
+    def download_table(self,
+                       table_name: str,
+                       download_urls: list[furl],
+                       location: Path):
+        data = None
+        for i, data in enumerate(self.get_data(download_urls)):
+            output_path = location / f'{self.snapshot_id}_{table_name}_{i}.parquet'
+            log.info('Writing to %s', output_path)
+            with open(output_path, 'wb') as f:
+                f.write(data)
+        reject(data is None,
+               'No Parquet files found for snapshot %r. Tried URLs: %r',
+               self.snapshot_id, download_urls)
+
+
+def main(argv):
+    parser = ArgumentParser(add_help=True, description=__doc__)
+    parser.add_argument('snapshot_id',
+                        type=UUID,
+                        help='The UUID of the snapshot')
+    parser.add_argument('-O',
+                        '--output-dir',
+                        type=Path,
+                        default=Path(config.project_root) / 'parquet',
+                        help='Where to save the downloaded files')
+    args = parser.parse_args(argv)
+
+    downloader = ParquetDownloader(args.snapshot_id)
+
+    urls_by_table = downloader.get_download_urls()
+    for table_name, urls in urls_by_table.items():
+        downloader.download_table(table_name, urls, args.output_dir)
+
+
+if __name__ == '__main__':
+    configure_script_logging(log)
+    main(sys.argv[1:])
diff --git a/src/azul/terra.py b/src/azul/terra.py
index a1080330d..f2e6c2574 100644
--- a/src/azul/terra.py
+++ b/src/azul/terra.py
@@ -11,6 +11,7 @@
 )
 import json
 import logging
+import time
 from time import (
     sleep,
 )
@@ -411,15 +412,19 @@ class TDRSource:
     @cache
     def lookup_source(self, source_spec: TDRSourceSpec) -> TDRSource:
         source = self._lookup_source(source_spec)
+        region = self._get_region(source, 'bigquery')
+        return self.TDRSource(project=source['dataProject'],
+                              id=source['id'],
+                              location=region)
+
+    def _get_region(self, source: JSON, resource: str) -> str:
         storage = one(
             storage
             for dataset in (s['dataset'] for s in source['source'])
             for storage in dataset['storage']
-            if storage['cloudResource'] == 'bigquery'
+            if storage['cloudResource'] == resource
         )
-        return self.TDRSource(project=source['dataProject'],
-                              id=source['id'],
-                              location=storage['region'])
+        return storage['region']
 
     def _retrieve_source(self, source: SourceRef) -> MutableJSON:
         endpoint = self._repository_endpoint('snapshots', source.id)
@@ -531,7 +536,7 @@ def _check_response(self,
                         endpoint: furl,
                         response: urllib3.HTTPResponse
                         ) -> MutableJSON:
-        if response.status == 200:
+        if response.status in (200, 202):
             return json.loads(response.data)
         # FIXME: Azul sometimes conflates 401 and 403
         #        https://github.com/DataBiosphere/azul/issues/4463
@@ -648,3 +653,48 @@ def get_duos(self, source: SourceRef) -> Optional[MutableJSON]:
             return None
         else:
             return self._check_response(url, response)
+
+    def export_parquet_urls(self,
+                            snapshot_id: str
+                            ) -> Optional[dict[str, list[mutable_furl]]]:
+        """
+        Obtain URLs of Parquet files for the data tables of the specified
+        snapshot. This is a time-consuming operation that usually takes on the
+        order of 1 minute to complete.
+
+        :param snapshot_id: The UUID of the snapshot.
+
+        :return: A mapping of table names to lists of Parquet file download
+                 URLs, or `None` if no Parquet downloads are available for
+                 the specified snapshot. The URLs are typically expiring signed
+                 URLs pointing to a cloud storage service such as GCS or Azure.
+ """ + url = self._repository_endpoint('snapshots', snapshot_id, 'export') + # Required for Azure-backed snapshots + url.args.add('validatePrimaryKeyUniqueness', False) + while True: + response = self._request('GET', url) + response_body = self._check_response(url, response) + jobs_status = response_body['job_status'] + job_id = response_body['id'] + if jobs_status == 'running': + url = self._repository_endpoint('jobs', job_id) + log.info('Waiting for job %r ...', job_id) + time.sleep(2) + elif jobs_status == 'succeeded': + break + else: + raise TerraStatusException(url, response) + url = self._repository_endpoint('jobs', job_id, 'result') + response = self._request('GET', url) + response_body = self._check_response(url, response) + parquet = response_body['format'].get('parquet') + if parquet is not None: + region = self._get_region(response_body['snapshot'], 'bucket') + require(config.tdr_source_location == region, + config.tdr_source_location, region) + parquet = { + table['name']: list(map(furl, table['paths'])) + for table in parquet['location']['tables'] + } + return parquet