Add AwsToAwsBaseOperator (#30044)
* Add `AwsToAwsBaseOperator`
Follow-up to #29452 (comment).
This PR preserves all current behavior but adds the interface needed by other transfer operators.
eladkal authored Mar 14, 2023
1 parent 848a396 commit 4effd6f
Showing 6 changed files with 163 additions and 33 deletions.
69 changes: 69 additions & 0 deletions airflow/providers/amazon/aws/transfers/base.py
@@ -0,0 +1,69 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module contains base AWS to AWS transfer operator"""
from __future__ import annotations

import warnings
from typing import Sequence

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.utils.types import NOTSET, ArgNotSet

_DEPRECATION_MSG = (
"The aws_conn_id parameter has been deprecated. Use the source_aws_conn_id parameter instead."
)


class AwsToAwsBaseOperator(BaseOperator):
"""
Base class for AWS to AWS transfer operators

:param source_aws_conn_id: The Airflow connection used for AWS credentials
to access DynamoDB. If this is None or empty then the default boto3
behaviour is used. If running Airflow in a distributed manner and
source_aws_conn_id is None or empty, then default boto3 configuration
would be used (and must be maintained on each worker node).
:param dest_aws_conn_id: The Airflow connection used for AWS credentials
to access S3. If this is not set then the source_aws_conn_id connection is used.
:param aws_conn_id: The Airflow connection used for AWS credentials (deprecated; use source_aws_conn_id).
"""

template_fields: Sequence[str] = (
"source_aws_conn_id",
"dest_aws_conn_id",
)

def __init__(
self,
*,
source_aws_conn_id: str | None = AwsBaseHook.default_conn_name,
dest_aws_conn_id: str | None | ArgNotSet = NOTSET,
aws_conn_id: str | None | ArgNotSet = NOTSET,
**kwargs,
) -> None:
super().__init__(**kwargs)
if not isinstance(aws_conn_id, ArgNotSet):
warnings.warn(_DEPRECATION_MSG, DeprecationWarning, stacklevel=3)
self.source_aws_conn_id = aws_conn_id
else:
self.source_aws_conn_id = source_aws_conn_id
self.dest_aws_conn_id = (
self.source_aws_conn_id if isinstance(dest_aws_conn_id, ArgNotSet) else dest_aws_conn_id
)
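
The point of the new base class is that other AWS-to-AWS transfer operators can inherit the connection handling above instead of re-implementing it. A minimal sketch of such a subclass follows; the operator name, extra fields, and execute body are illustrative assumptions, not part of this commit:

from __future__ import annotations

from typing import Any, Sequence

from airflow.providers.amazon.aws.transfers.base import AwsToAwsBaseOperator


class ExampleAwsTransferOperator(AwsToAwsBaseOperator):
    """Hypothetical AWS-to-AWS transfer operator built on the new base class."""

    # Extend the base template fields with operator-specific ones.
    template_fields: Sequence[str] = (
        *AwsToAwsBaseOperator.template_fields,
        "source_name",
        "dest_name",
    )

    def __init__(self, *, source_name: str, dest_name: str, **kwargs: Any) -> None:
        # source_aws_conn_id, dest_aws_conn_id and the deprecated aws_conn_id
        # are consumed by AwsToAwsBaseOperator.__init__.
        super().__init__(**kwargs)
        self.source_name = source_name
        self.dest_name = dest_name

    def execute(self, context) -> None:
        # Build the source hook from self.source_aws_conn_id and the
        # destination hook from self.dest_aws_conn_id, then copy the data.
        ...
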
35 changes: 4 additions & 31 deletions airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py
@@ -22,29 +22,22 @@
from __future__ import annotations

import json
import warnings
from copy import copy
from decimal import Decimal
from os.path import getsize
from tempfile import NamedTemporaryFile
from typing import IO, TYPE_CHECKING, Any, Callable, Sequence
from uuid import uuid4

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.hooks.dynamodb import DynamoDBHook
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.utils.types import NOTSET, ArgNotSet
from airflow.providers.amazon.aws.transfers.base import AwsToAwsBaseOperator

if TYPE_CHECKING:
from airflow.utils.context import Context


_DEPRECATION_MSG = (
"The aws_conn_id parameter has been deprecated. Use the source_aws_conn_id parameter instead."
)


class JSONEncoder(json.JSONEncoder):
"""Custom json encoder implementation"""

@@ -74,7 +67,7 @@ def _upload_file_to_s3(
)


class DynamoDBToS3Operator(BaseOperator):
class DynamoDBToS3Operator(AwsToAwsBaseOperator):
"""
Replicates records from a DynamoDB table to S3.
It scans a DynamoDB table and writes the received records to a file
@@ -89,29 +82,20 @@ class DynamoDBToS3Operator(BaseOperator):
:ref:`howto/transfer:DynamoDBToS3Operator`
:param dynamodb_table_name: Dynamodb table to replicate data from
:param source_aws_conn_id: The Airflow connection used for AWS credentials
to access DynamoDB. If this is None or empty then the default boto3
behaviour is used. If running Airflow in a distributed manner and
source_aws_conn_id is None or empty, then default boto3 configuration
would be used (and must be maintained on each worker node).
:param s3_bucket_name: S3 bucket to replicate data to
:param file_size: Flush file to s3 if file size >= file_size
:param dynamodb_scan_kwargs: kwargs passed to <https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb.html#DynamoDB.Table.scan>
:param s3_key_prefix: Prefix of s3 object key
:param process_func: How we transform a DynamoDB item to bytes. By default we dump the JSON
:param dest_aws_conn_id: The Airflow connection used for AWS credentials
to access S3. If this is not set then the source_aws_conn_id connection is used.
:param aws_conn_id: The Airflow connection used for AWS credentials (deprecated; use source_aws_conn_id).
""" # noqa: E501

template_fields: Sequence[str] = (
"source_aws_conn_id",
"dest_aws_conn_id",
*AwsToAwsBaseOperator.template_fields,
"s3_bucket_name",
"s3_key_prefix",
"dynamodb_table_name",
)

template_fields_renderers = {
"dynamodb_scan_kwargs": "json",
}
@@ -120,14 +104,11 @@ def __init__(
self,
*,
dynamodb_table_name: str,
source_aws_conn_id: str | None = AwsBaseHook.default_conn_name,
s3_bucket_name: str,
file_size: int,
dynamodb_scan_kwargs: dict[str, Any] | None = None,
s3_key_prefix: str = "",
process_func: Callable[[dict[str, Any]], bytes] = _convert_item_to_json_bytes,
dest_aws_conn_id: str | None | ArgNotSet = NOTSET,
aws_conn_id: str | None | ArgNotSet = NOTSET,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -137,14 +118,6 @@ def __init__(
self.dynamodb_scan_kwargs = dynamodb_scan_kwargs
self.s3_bucket_name = s3_bucket_name
self.s3_key_prefix = s3_key_prefix
if not isinstance(aws_conn_id, ArgNotSet):
warnings.warn(_DEPRECATION_MSG, DeprecationWarning, stacklevel=3)
self.source_aws_conn_id = aws_conn_id
else:
self.source_aws_conn_id = source_aws_conn_id
self.dest_aws_conn_id = (
self.source_aws_conn_id if isinstance(dest_aws_conn_id, ArgNotSet) else dest_aws_conn_id
)

def execute(self, context: Context) -> None:
hook = DynamoDBHook(aws_conn_id=self.source_aws_conn_id)
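
After the refactor, DynamoDBToS3Operator keeps its public signature, so existing DAGs continue to work unchanged. A hedged usage sketch follows; the DAG id, connection ids, table, bucket, and prefix are assumed values, not taken from this commit:

from __future__ import annotations

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.transfers.dynamodb_to_s3 import DynamoDBToS3Operator

with DAG("example_dynamodb_to_s3", start_date=datetime(2023, 3, 14), schedule=None) as dag:
    backup_table = DynamoDBToS3Operator(
        task_id="backup_table",
        dynamodb_table_name="my_table",      # assumed source table
        s3_bucket_name="my-backup-bucket",   # assumed destination bucket
        s3_key_prefix="dynamodb/backups/",   # assumed object key prefix
        file_size=4000,                      # flush the local file to S3 once it reaches this many bytes
        source_aws_conn_id="aws_default",    # credentials used for DynamoDB
        dest_aws_conn_id="aws_s3_account",   # credentials used for S3; defaults to the source connection
    )
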
4 changes: 3 additions & 1 deletion airflow/providers/amazon/provider.yaml
@@ -567,7 +567,9 @@ transfers:
target-integration-name: Common SQL
how-to-guide: /docs/apache-airflow-providers-amazon/operators/transfer/s3_to_sql.rst
python-module: airflow.providers.amazon.aws.transfers.s3_to_sql

- source-integration-name: Amazon Web Services
target-integration-name: Amazon Web Services
python-module: airflow.providers.amazon.aws.transfers.base

extra-links:
- airflow.providers.amazon.aws.links.batch.BatchJobDefinitionLink
1 change: 1 addition & 0 deletions tests/always/test_project_structure.py
@@ -393,6 +393,7 @@ class TestAmazonProviderProjectStructure(ExampleCoverageTest):
"airflow.providers.amazon.aws.operators.ecs.EcsBaseOperator",
"airflow.providers.amazon.aws.sensors.ecs.EcsBaseSensor",
"airflow.providers.amazon.aws.sensors.eks.EksBaseSensor",
"airflow.providers.amazon.aws.transfers.base.AwsToAwsBaseOperator",
}

MISSING_EXAMPLES_FOR_CLASSES = {
58 changes: 58 additions & 0 deletions tests/providers/amazon/aws/transfers/test_base.py
@@ -0,0 +1,58 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import pytest

from airflow import DAG
from airflow.models import DagRun, TaskInstance
from airflow.providers.amazon.aws.transfers.base import AwsToAwsBaseOperator
from airflow.utils import timezone

DEFAULT_DATE = timezone.datetime(2020, 1, 1)


class TestAwsToAwsBaseOperator:
def setup_method(self):
args = {"owner": "airflow", "start_date": DEFAULT_DATE}
self.dag = DAG("test_dag_id", default_args=args)

def test_render_template(self):
operator = AwsToAwsBaseOperator(
task_id="dynamodb_to_s3_test_render",
dag=self.dag,
source_aws_conn_id="{{ ds }}",
dest_aws_conn_id="{{ ds }}",
)
ti = TaskInstance(operator, run_id="something")
ti.dag_run = DagRun(run_id="something", execution_date=timezone.datetime(2020, 1, 1))
ti.render_templates()
assert "2020-01-01" == getattr(operator, "source_aws_conn_id")
assert "2020-01-01" == getattr(operator, "dest_aws_conn_id")

def test_deprecation(self):
with pytest.warns(
DeprecationWarning,
match="The aws_conn_id parameter has been deprecated."
" Use the source_aws_conn_id parameter instead.",
):
AwsToAwsBaseOperator(
task_id="transfer",
dag=self.dag,
aws_conn_id="my_conn",
)
29 changes: 28 additions & 1 deletion tests/providers/amazon/aws/transfers/test_dynamodb_to_s3.py
@@ -18,16 +18,20 @@
from __future__ import annotations

import json
from datetime import datetime
from decimal import Decimal
from unittest.mock import MagicMock, patch

import pytest

from airflow import DAG
from airflow.models import DagRun, TaskInstance
from airflow.providers.amazon.aws.transfers.base import _DEPRECATION_MSG
from airflow.providers.amazon.aws.transfers.dynamodb_to_s3 import (
_DEPRECATION_MSG,
DynamoDBToS3Operator,
JSONEncoder,
)
from airflow.utils import timezone


class TestJSONEncoder:
@@ -288,3 +292,26 @@ def test_dynamodb_to_s3_with_just_dest_aws_conn_id(self, mock_aws_dynamodb_hook,

mock_aws_dynamodb_hook.assert_called_with(aws_conn_id="aws_default")
mock_s3_hook.assert_called_with(aws_conn_id=s3_aws_conn_id)

def test_render_template(self):
dag = DAG("test_render_template_dag_id", start_date=datetime(2020, 1, 1))
operator = DynamoDBToS3Operator(
task_id="dynamodb_to_s3_test_render",
dag=dag,
dynamodb_table_name="{{ ds }}",
s3_key_prefix="{{ ds }}",
s3_bucket_name="{{ ds }}",
file_size=4000,
source_aws_conn_id="{{ ds }}",
dest_aws_conn_id="{{ ds }}",
)
ti = TaskInstance(operator, run_id="something")
ti.dag_run = DagRun(
dag_id=dag.dag_id, run_id="something", execution_date=timezone.datetime(2020, 1, 1)
)
ti.render_templates()
assert "2020-01-01" == getattr(operator, "source_aws_conn_id")
assert "2020-01-01" == getattr(operator, "dest_aws_conn_id")
assert "2020-01-01" == getattr(operator, "s3_bucket_name")
assert "2020-01-01" == getattr(operator, "dynamodb_table_name")
assert "2020-01-01" == getattr(operator, "s3_key_prefix")
