diff --git a/airflow/contrib/hooks/gcs_hook.py b/airflow/contrib/hooks/gcs_hook.py index 6d205308f3a54..43f5df75265c5 100644 --- a/airflow/contrib/hooks/gcs_hook.py +++ b/airflow/contrib/hooks/gcs_hook.py @@ -91,9 +91,9 @@ def copy(self, source_bucket, source_object, destination_bucket=None, raise ValueError('source_bucket and source_object cannot be empty.') client = self.get_conn() - source_bucket = client.get_bucket(source_bucket) + source_bucket = client.bucket(source_bucket) source_object = source_bucket.blob(source_object) - destination_bucket = client.get_bucket(destination_bucket) + destination_bucket = client.bucket(destination_bucket) destination_object = source_bucket.copy_blob( blob=source_object, destination_bucket=destination_bucket, @@ -133,9 +133,9 @@ def rewrite(self, source_bucket, source_object, destination_bucket, raise ValueError('source_bucket and source_object cannot be empty.') client = self.get_conn() - source_bucket = client.get_bucket(source_bucket) + source_bucket = client.bucket(source_bucket) source_object = source_bucket.blob(blob_name=source_object) - destination_bucket = client.get_bucket(destination_bucket) + destination_bucket = client.bucket(destination_bucket) token, bytes_rewritten, total_bytes = destination_bucket.blob( blob_name=destination_object).rewrite( @@ -169,7 +169,7 @@ def download(self, bucket_name, object_name, filename=None): :type filename: str """ client = self.get_conn() - bucket = client.get_bucket(bucket_name) + bucket = client.bucket(bucket_name) blob = bucket.blob(blob_name=object_name) if filename: @@ -204,7 +204,7 @@ def upload(self, bucket_name, object_name, filename, filename = filename_gz client = self.get_conn() - bucket = client.get_bucket(bucket_name) + bucket = client.bucket(bucket_name) blob = bucket.blob(blob_name=object_name) blob.upload_from_filename(filename=filename, content_type=mime_type) @@ -224,7 +224,7 @@ def exists(self, bucket_name, object_name): :type object_name: str """ client = self.get_conn() - bucket = client.get_bucket(bucket_name) + bucket = client.bucket(bucket_name) blob = bucket.blob(blob_name=object_name) return blob.exists() @@ -241,9 +241,12 @@ def is_updated_after(self, bucket_name, object_name, ts): :type ts: datetime.datetime """ client = self.get_conn() - bucket = storage.Bucket(client=client, name=bucket_name) + bucket = client.bucket(bucket_name) blob = bucket.get_blob(blob_name=object_name) - blob.reload() + + if blob is None: + raise ValueError("Object ({}) not found in Bucket ({})".format( + object_name, bucket_name)) blob_update_time = blob.updated @@ -270,7 +273,7 @@ def delete(self, bucket_name, object_name): :type object_name: str """ client = self.get_conn() - bucket = client.get_bucket(bucket_name) + bucket = client.bucket(bucket_name) blob = bucket.blob(blob_name=object_name) blob.delete() @@ -294,7 +297,7 @@ def list(self, bucket_name, versions=None, max_results=None, prefix=None, delimi :return: a stream of object names matching the filtering criteria """ client = self.get_conn() - bucket = client.get_bucket(bucket_name) + bucket = client.bucket(bucket_name) ids = [] page_token = None @@ -338,9 +341,8 @@ def get_size(self, bucket_name, object_name): object_name, bucket_name) client = self.get_conn() - bucket = client.get_bucket(bucket_name) + bucket = client.bucket(bucket_name) blob = bucket.get_blob(blob_name=object_name) - blob.reload() blob_size = blob.size self.log.info('The file size of %s is %s bytes.', object_name, blob_size) return blob_size @@ -358,9 +360,8 @@ def get_crc32c(self, bucket_name, object_name): self.log.info('Retrieving the crc32c checksum of ' 'object_name: %s in bucket_name: %s', object_name, bucket_name) client = self.get_conn() - bucket = client.get_bucket(bucket_name) + bucket = client.bucket(bucket_name) blob = bucket.get_blob(blob_name=object_name) - blob.reload() blob_crc32c = blob.crc32c self.log.info('The crc32c checksum of %s is %s', object_name, blob_crc32c) return blob_crc32c @@ -378,9 +379,8 @@ def get_md5hash(self, bucket_name, object_name): self.log.info('Retrieving the MD5 hash of ' 'object: %s in bucket: %s', object_name, bucket_name) client = self.get_conn() - bucket = client.get_bucket(bucket_name) + bucket = client.bucket(bucket_name) blob = bucket.get_blob(blob_name=object_name) - blob.reload() blob_md5hash = blob.md5_hash self.log.info('The md5Hash of %s is %s', object_name, blob_md5hash) return blob_md5hash @@ -550,7 +550,7 @@ def compose(self, bucket_name, source_objects, destination_object): self.log.info("Composing %s to %s in the bucket %s", source_objects, destination_object, bucket_name) client = self.get_conn() - bucket = client.get_bucket(bucket_name) + bucket = client.bucket(bucket_name) destination_blob = bucket.blob(destination_object) destination_blob.compose( sources=[ diff --git a/tests/contrib/hooks/test_gcs_hook.py b/tests/contrib/hooks/test_gcs_hook.py index 1714a228a1213..23b0f9e2125e0 100644 --- a/tests/contrib/hooks/test_gcs_hook.py +++ b/tests/contrib/hooks/test_gcs_hook.py @@ -20,7 +20,9 @@ import os import tempfile import unittest +from datetime import datetime +import dateutil from google.cloud import storage from google.cloud import exceptions @@ -69,14 +71,25 @@ def setUp(self): self.gcs_hook = gcs_hook.GoogleCloudStorageHook( google_cloud_storage_conn_id='test') + def test_storage_client_creation(self): + with mock.patch('google.cloud.storage.Client') as mock_client: + gcs_hook_1 = gcs_hook.GoogleCloudStorageHook() + gcs_hook_1.get_conn() + + # test that Storage Client is called with required arguments + mock_client.assert_called_once_with( + client_info=mock.ANY, + credentials=mock.ANY, + project=mock.ANY) + @mock.patch(GCS_STRING.format('GoogleCloudStorageHook.get_conn')) def test_exists(self, mock_service): test_bucket = 'test_bucket' test_object = 'test_object' # Given - get_bucket_mock = mock_service.return_value.get_bucket - blob_object = get_bucket_mock.return_value.blob + bucket_mock = mock_service.return_value.bucket + blob_object = bucket_mock.return_value.blob exists_method = blob_object.return_value.exists exists_method.return_value = True @@ -85,7 +98,7 @@ def test_exists(self, mock_service): # Then self.assertTrue(response) - get_bucket_mock.assert_called_once_with(test_bucket) + bucket_mock.assert_called_once_with(test_bucket) blob_object.assert_called_once_with(blob_name=test_object) exists_method.assert_called_once_with() @@ -95,8 +108,8 @@ def test_exists_nonexisting_object(self, mock_service): test_object = 'test_object' # Given - get_bucket_mock = mock_service.return_value.get_bucket - blob_object = get_bucket_mock.return_value.blob + bucket_mock = mock_service.return_value.bucket + blob_object = bucket_mock.return_value.blob exists_method = blob_object.return_value.exists exists_method.return_value = False @@ -106,6 +119,24 @@ def test_exists_nonexisting_object(self, mock_service): # Then self.assertFalse(response) + @mock.patch(GCS_STRING.format('GoogleCloudStorageHook.get_conn')) + def test_is_updated_after(self, mock_service): + test_bucket = 'test_bucket' + test_object = 'test_object' + + # Given + mock_service.return_value.bucket.return_value.get_blob\ + .return_value.updated = datetime(2019, 8, 28, 14, 7, 20, 700000, dateutil.tz.tzutc()) + + # When + response = self.gcs_hook.is_updated_after( + bucket_name=test_bucket, object_name=test_object, + ts=datetime(2018, 1, 1, 1, 1, 1) + ) + + # Then + self.assertTrue(response) + @mock.patch('google.cloud.storage.Bucket') @mock.patch(GCS_STRING.format('GoogleCloudStorageHook.get_conn')) def test_copy(self, mock_service, mock_bucket): @@ -121,9 +152,9 @@ def test_copy(self, mock_service, mock_bucket): name=destination_object) # Given - get_bucket_mock = mock_service.return_value.get_bucket - get_bucket_mock.return_value = mock_bucket - copy_method = get_bucket_mock.return_value.copy_blob + bucket_mock = mock_service.return_value.bucket + bucket_mock.return_value = mock_bucket + copy_method = bucket_mock.return_value.copy_blob copy_method.return_value = destination_blob # When @@ -206,9 +237,9 @@ def test_rewrite(self, mock_service, mock_bucket): source_blob = mock_bucket.blob(source_object) # Given - get_bucket_mock = mock_service.return_value.get_bucket - get_bucket_mock.return_value = mock_bucket - get_blob_method = get_bucket_mock.return_value.blob + bucket_mock = mock_service.return_value.bucket + bucket_mock.return_value = mock_bucket + get_blob_method = bucket_mock.return_value.blob rewrite_method = get_blob_method.return_value.rewrite rewrite_method.side_effect = [(None, mock.ANY, mock.ANY), (mock.ANY, mock.ANY, mock.ANY)] @@ -280,8 +311,8 @@ def test_delete_nonexisting_object(self, mock_service): test_bucket = 'test_bucket' test_object = 'test_object' - get_bucket_method = mock_service.return_value.get_bucket - blob = get_bucket_method.return_value.blob + bucket_method = mock_service.return_value.bucket + blob = bucket_method.return_value.blob delete_method = blob.return_value.delete delete_method.side_effect = exceptions.NotFound(message="Not Found") @@ -294,15 +325,14 @@ def test_object_get_size(self, mock_service): test_object = 'test_object' returned_file_size = 1200 - get_bucket_method = mock_service.return_value.get_bucket - get_blob_method = get_bucket_method.return_value.get_blob + bucket_method = mock_service.return_value.bucket + get_blob_method = bucket_method.return_value.get_blob get_blob_method.return_value.size = returned_file_size response = self.gcs_hook.get_size(bucket_name=test_bucket, object_name=test_object) self.assertEqual(response, returned_file_size) - get_blob_method.return_value.reload.assert_called_once_with() @mock.patch(GCS_STRING.format('GoogleCloudStorageHook.get_conn')) def test_object_get_crc32c(self, mock_service): @@ -310,8 +340,8 @@ def test_object_get_crc32c(self, mock_service): test_object = 'test_object' returned_file_crc32c = "xgdNfQ==" - get_bucket_method = mock_service.return_value.get_bucket - get_blob_method = get_bucket_method.return_value.get_blob + bucket_method = mock_service.return_value.bucket + get_blob_method = bucket_method.return_value.get_blob get_blob_method.return_value.crc32c = returned_file_crc32c response = self.gcs_hook.get_crc32c(bucket_name=test_bucket, @@ -319,17 +349,14 @@ def test_object_get_crc32c(self, mock_service): self.assertEqual(response, returned_file_crc32c) - # Check that reload method is called - get_blob_method.return_value.reload.assert_called_once_with() - @mock.patch(GCS_STRING.format('GoogleCloudStorageHook.get_conn')) def test_object_get_md5hash(self, mock_service): test_bucket = 'test_bucket' test_object = 'test_object' returned_file_md5hash = "leYUJBUWrRtks1UeUFONJQ==" - get_bucket_method = mock_service.return_value.get_bucket - get_blob_method = get_bucket_method.return_value.get_blob + bucket_method = mock_service.return_value.bucket + get_blob_method = bucket_method.return_value.get_blob get_blob_method.return_value.md5_hash = returned_file_md5hash response = self.gcs_hook.get_md5hash(bucket_name=test_bucket, @@ -337,9 +364,6 @@ def test_object_get_md5hash(self, mock_service): self.assertEqual(response, returned_file_md5hash) - # Check that reload method is called - get_blob_method.return_value.reload.assert_called_once_with() - @mock.patch('google.cloud.storage.Bucket') @mock.patch(GCS_STRING.format('GoogleCloudStorageHook.get_conn')) def test_create_bucket(self, mock_service, mock_bucket): @@ -416,9 +440,9 @@ def test_compose(self, mock_service, mock_blob): test_source_objects = ['test_object_1', 'test_object_2', 'test_object_3'] test_destination_object = 'test_object_composed' - mock_service.return_value.get_bucket.return_value\ + mock_service.return_value.bucket.return_value\ .blob.return_value = mock_blob(blob_name=mock.ANY) - method = mock_service.return_value.get_bucket.return_value.blob\ + method = mock_service.return_value.bucket.return_value.blob\ .return_value.compose self.gcs_hook.compose( @@ -492,7 +516,7 @@ def test_download_as_string(self, mock_service): test_object = 'test_object' test_object_bytes = io.BytesIO(b"input") - download_method = mock_service.return_value.get_bucket.return_value \ + download_method = mock_service.return_value.bucket.return_value \ .blob.return_value.download_as_string download_method.return_value = test_object_bytes @@ -510,11 +534,11 @@ def test_download_to_file(self, mock_service): test_object_bytes = io.BytesIO(b"input") test_file = 'test_file' - download_filename_method = mock_service.return_value.get_bucket.return_value \ + download_filename_method = mock_service.return_value.bucket.return_value \ .blob.return_value.download_to_filename download_filename_method.return_value = None - download_as_a_string_method = mock_service.return_value.get_bucket.return_value \ + download_as_a_string_method = mock_service.return_value.bucket.return_value \ .blob.return_value.download_as_string download_as_a_string_method.return_value = test_object_bytes @@ -546,7 +570,7 @@ def test_upload(self, mock_service): test_bucket = 'test_bucket' test_object = 'test_object' - upload_method = mock_service.return_value.get_bucket.return_value\ + upload_method = mock_service.return_value.bucket.return_value\ .blob.return_value.upload_from_filename upload_method.return_value = None