diff --git a/airflow/__init__.py b/airflow/__init__.py index db3fcd611c740..1ed188cc45886 100644 --- a/airflow/__init__.py +++ b/airflow/__init__.py @@ -31,7 +31,7 @@ from airflow.models import DAG from flask_admin import BaseView from importlib import import_module -from airflow.utils import AirflowException +from airflow.exceptions import AirflowException DAGS_FOLDER = os.path.expanduser(conf.get('core', 'DAGS_FOLDER')) if DAGS_FOLDER not in sys.path: diff --git a/airflow/bin/cli.py b/airflow/bin/cli.py index d9a8d0667e44a..1531c18aaa5dc 100755 --- a/airflow/bin/cli.py +++ b/airflow/bin/cli.py @@ -14,11 +14,14 @@ import json import airflow -from airflow import jobs, settings, utils +from airflow import jobs, settings from airflow import configuration as conf from airflow.executors import DEFAULT_EXECUTOR from airflow.models import DagModel, DagBag, TaskInstance, DagPickle, DagRun -from airflow.utils import AirflowException, State +from airflow.utils import db as db_utils +from airflow.utils import logging as logging_utils +from airflow.utils.state import State +from airflow.exceptions import AirflowException DAGS_FOLDER = os.path.expanduser(conf.get('core', 'DAGS_FOLDER')) @@ -78,7 +81,8 @@ def backfill(args, dag=None): mark_success=args.mark_success, include_adhoc=args.include_adhoc, local=args.local, - donot_pickle=(args.donot_pickle or conf.getboolean('core', 'donot_pickle')), + donot_pickle=(args.donot_pickle or + conf.getboolean('core', 'donot_pickle')), ignore_dependencies=args.ignore_dependencies, pool=args.pool) @@ -133,7 +137,7 @@ def set_is_paused(is_paused, args, dag=None): def run(args, dag=None): - utils.pessimistic_connection_handling() + db_utils.pessimistic_connection_handling() if dag: args.dag_id = dag.dag_id @@ -236,10 +240,10 @@ def run(args, dag=None): remote_log_location = filename.replace(log_base, remote_base) # S3 if remote_base.startswith('s3:/'): - utils.S3Log().write(log, remote_log_location) + logging_utils.S3Log().write(log, remote_log_location) # GCS elif remote_base.startswith('gs:/'): - utils.GCSLog().write( + logging_utils.GCSLog().write( log, remote_log_location, append=True) @@ -401,7 +405,7 @@ def worker(args): def initdb(args): # noqa print("DB: " + repr(settings.engine.url)) - utils.initdb() + db_utils.initdb() print("Done.") @@ -412,14 +416,14 @@ def resetdb(args): "Proceed? 
(y/n)").upper() == "Y": logging.basicConfig(level=settings.LOGGING_LEVEL, format=settings.SIMPLE_LOG_FORMAT) - utils.resetdb() + db_utils.resetdb() else: print("Bail.") def upgradedb(args): # noqa print("DB: " + repr(settings.engine.url)) - utils.upgradedb() + db_utils.upgradedb() def version(args): # noqa diff --git a/airflow/configuration.py b/airflow/configuration.py index 47f2262e72bf7..5a50b828d96da 100644 --- a/airflow/configuration.py +++ b/airflow/configuration.py @@ -16,6 +16,7 @@ from collections import OrderedDict from configparser import ConfigParser + class AirflowConfigException(Exception): pass @@ -612,6 +613,7 @@ def test_mode(): def get(section, key, **kwargs): return conf.get(section, key, **kwargs) + def getboolean(section, key): return conf.getboolean(section, key) @@ -644,5 +646,6 @@ def set(section, option, value): # noqa ######################## # convenience method to access config entries + def get_dags_folder(): return os.path.expanduser(get('core', 'DAGS_FOLDER')) diff --git a/airflow/contrib/executors/mesos_executor.py b/airflow/contrib/executors/mesos_executor.py index 3b82306f2bd5a..45a474dc3a1a3 100644 --- a/airflow/contrib/executors/mesos_executor.py +++ b/airflow/contrib/executors/mesos_executor.py @@ -11,8 +11,8 @@ from airflow import configuration from airflow.executors.base_executor import BaseExecutor from airflow.settings import Session -from airflow.utils import State -from airflow.utils import AirflowException +from airflow.utils.state import State +from airflow.exceptions import AirflowException DEFAULT_FRAMEWORK_NAME = 'Airflow' diff --git a/airflow/contrib/hooks/__init__.py b/airflow/contrib/hooks/__init__.py index 1507b09fef276..46c6a806dd099 100644 --- a/airflow/contrib/hooks/__init__.py +++ b/airflow/contrib/hooks/__init__.py @@ -1,6 +1,6 @@ # Imports the hooks dynamically while keeping the package API clean, # abstracting the underlying modules -from airflow.utils import import_module_attrs as _import_module_attrs +from airflow.utils.helpers import import_module_attrs as _import_module_attrs _hooks = { 'ftp_hook': ['FTPHook'], diff --git a/airflow/contrib/hooks/gc_base_hook.py b/airflow/contrib/hooks/gc_base_hook.py index b17d37fed5fe0..6af01e79eeffb 100644 --- a/airflow/contrib/hooks/gc_base_hook.py +++ b/airflow/contrib/hooks/gc_base_hook.py @@ -2,7 +2,7 @@ import logging from airflow.hooks.base_hook import BaseHook -from airflow.utils import AirflowException +from airflow.exceptions import AirflowException from oauth2client.client import SignedJwtAssertionCredentials, GoogleCredentials class GoogleCloudBaseHook(BaseHook): diff --git a/airflow/contrib/hooks/qubole_hook.py b/airflow/contrib/hooks/qubole_hook.py index 7e2fdb32dcbcc..c36b9f5838f5e 100755 --- a/airflow/contrib/hooks/qubole_hook.py +++ b/airflow/contrib/hooks/qubole_hook.py @@ -3,7 +3,7 @@ import datetime import logging -from airflow.utils import AirflowException +from airflow.exceptions import AirflowException from airflow.hooks.base_hook import BaseHook from airflow import configuration @@ -151,4 +151,4 @@ def create_cmd_args(self): else: args += inplace_args.split(' ') - return args \ No newline at end of file + return args diff --git a/airflow/contrib/hooks/ssh_hook.py b/airflow/contrib/hooks/ssh_hook.py index bc590fe2789f8..c5aff848bd865 100644 --- a/airflow/contrib/hooks/ssh_hook.py +++ b/airflow/contrib/hooks/ssh_hook.py @@ -20,7 +20,7 @@ from contextlib import contextmanager from airflow.hooks.base_hook import BaseHook -from airflow import AirflowException +from 
airflow.exceptions import AirflowException import logging diff --git a/airflow/contrib/operators/__init__.py b/airflow/contrib/operators/__init__.py index f178e392f2870..3598490a8eeda 100644 --- a/airflow/contrib/operators/__init__.py +++ b/airflow/contrib/operators/__init__.py @@ -1,6 +1,6 @@ # Imports the operators dynamically while keeping the package API clean, # abstracting the underlying modules -from airflow.utils import import_module_attrs as _import_module_attrs +from airflow.utils.helpers import import_module_attrs as _import_module_attrs _operators = { 'ssh_execute_operator': ['SSHExecuteOperator'], diff --git a/airflow/contrib/operators/bigquery_check_operator.py b/airflow/contrib/operators/bigquery_check_operator.py index 69b69c37f36da..218de5a10ec45 100644 --- a/airflow/contrib/operators/bigquery_check_operator.py +++ b/airflow/contrib/operators/bigquery_check_operator.py @@ -1,6 +1,6 @@ from airflow.contrib.hooks.bigquery_hook import BigQueryHook from airflow.operators import CheckOperator, ValueCheckOperator, IntervalCheckOperator -from airflow.utils import apply_defaults +from airflow.utils.decorators import apply_defaults class BigQueryCheckOperator(CheckOperator): diff --git a/airflow/contrib/operators/bigquery_operator.py b/airflow/contrib/operators/bigquery_operator.py index 78edde2ca963c..2f60ac60ab06d 100644 --- a/airflow/contrib/operators/bigquery_operator.py +++ b/airflow/contrib/operators/bigquery_operator.py @@ -2,7 +2,7 @@ from airflow.contrib.hooks.bigquery_hook import BigQueryHook from airflow.models import BaseOperator -from airflow.utils import apply_defaults +from airflow.utils.decorators import apply_defaults class BigQueryOperator(BaseOperator): """ diff --git a/airflow/contrib/operators/bigquery_to_bigquery.py b/airflow/contrib/operators/bigquery_to_bigquery.py index ccb9b0715f894..56023f584041b 100644 --- a/airflow/contrib/operators/bigquery_to_bigquery.py +++ b/airflow/contrib/operators/bigquery_to_bigquery.py @@ -2,7 +2,7 @@ from airflow.contrib.hooks.bigquery_hook import BigQueryHook from airflow.models import BaseOperator -from airflow.utils import apply_defaults +from airflow.utils.decorators import apply_defaults class BigQueryToBigQueryOperator(BaseOperator): """ diff --git a/airflow/contrib/operators/bigquery_to_gcs.py b/airflow/contrib/operators/bigquery_to_gcs.py index 012b3bb4a6ffb..3a543fdb3052c 100644 --- a/airflow/contrib/operators/bigquery_to_gcs.py +++ b/airflow/contrib/operators/bigquery_to_gcs.py @@ -2,7 +2,7 @@ from airflow.contrib.hooks.bigquery_hook import BigQueryHook from airflow.models import BaseOperator -from airflow.utils import apply_defaults +from airflow.utils.decorators import apply_defaults class BigQueryToCloudStorageOperator(BaseOperator): """ diff --git a/airflow/contrib/operators/gcs_download_operator.py b/airflow/contrib/operators/gcs_download_operator.py index ef917e85de219..8de6d1728af55 100644 --- a/airflow/contrib/operators/gcs_download_operator.py +++ b/airflow/contrib/operators/gcs_download_operator.py @@ -2,7 +2,7 @@ from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook from airflow.models import BaseOperator -from airflow.utils import apply_defaults +from airflow.utils.decorators import apply_defaults class GoogleCloudStorageDownloadOperator(BaseOperator): """ diff --git a/airflow/contrib/operators/gcs_to_bq.py b/airflow/contrib/operators/gcs_to_bq.py index bcb418c636ac7..ec3a459238bb1 100644 --- a/airflow/contrib/operators/gcs_to_bq.py +++ b/airflow/contrib/operators/gcs_to_bq.py @@ -4,7 +4,7 @@ 
from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook from airflow.contrib.hooks.bigquery_hook import BigQueryHook from airflow.models import BaseOperator -from airflow.utils import apply_defaults +from airflow.utils.decorators import apply_defaults class GoogleCloudStorageToBigQueryOperator(BaseOperator): """ diff --git a/airflow/contrib/operators/mysql_to_gcs.py b/airflow/contrib/operators/mysql_to_gcs.py index cee634ee12871..0eb368e3fa95e 100644 --- a/airflow/contrib/operators/mysql_to_gcs.py +++ b/airflow/contrib/operators/mysql_to_gcs.py @@ -5,7 +5,7 @@ from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook from airflow.hooks import MySqlHook from airflow.models import BaseOperator -from airflow.utils import apply_defaults +from airflow.utils.decorators import apply_defaults from collections import OrderedDict from datetime import date, datetime from decimal import Decimal diff --git a/airflow/contrib/operators/qubole_operator.py b/airflow/contrib/operators/qubole_operator.py index 9c05479827b6b..2ed94e1023402 100755 --- a/airflow/contrib/operators/qubole_operator.py +++ b/airflow/contrib/operators/qubole_operator.py @@ -1,5 +1,5 @@ from airflow.models import BaseOperator -from airflow.utils import apply_defaults +from airflow.utils.decorators import apply_defaults from airflow.contrib.hooks import QuboleHook diff --git a/airflow/contrib/operators/ssh_execute_operator.py b/airflow/contrib/operators/ssh_execute_operator.py index 0c20719660cba..c55f0d177c3b2 100644 --- a/airflow/contrib/operators/ssh_execute_operator.py +++ b/airflow/contrib/operators/ssh_execute_operator.py @@ -4,8 +4,8 @@ from subprocess import STDOUT from airflow.models import BaseOperator -from airflow.utils import apply_defaults -from airflow.utils import AirflowException +from airflow.utils.decorators import apply_defaults +from airflow.exceptions import AirflowException class SSHTempFileContent(): diff --git a/airflow/contrib/operators/vertica_operator.py b/airflow/contrib/operators/vertica_operator.py index 08003114d5fde..9e5248f03fcb3 100644 --- a/airflow/contrib/operators/vertica_operator.py +++ b/airflow/contrib/operators/vertica_operator.py @@ -2,7 +2,7 @@ from airflow.contrib.hooks import VerticaHook from airflow.models import BaseOperator -from airflow.utils import apply_defaults +from airflow.utils.decorators import apply_defaults class VerticaOperator(BaseOperator): diff --git a/airflow/contrib/operators/vertica_to_hive.py b/airflow/contrib/operators/vertica_to_hive.py index 17a59680b76d8..35a489a9beba3 100644 --- a/airflow/contrib/operators/vertica_to_hive.py +++ b/airflow/contrib/operators/vertica_to_hive.py @@ -7,7 +7,7 @@ from airflow.hooks import HiveCliHook from airflow.contrib.hooks import VerticaHook from airflow.models import BaseOperator -from airflow.utils import apply_defaults +from airflow.utils.decorators import apply_defaults class VerticaToHiveTransfer(BaseOperator): """ diff --git a/airflow/example_dags/example_short_circuit_operator.py b/airflow/example_dags/example_short_circuit_operator.py index 3a4c6315eae9b..967c65edc3f5c 100644 --- a/airflow/example_dags/example_short_circuit_operator.py +++ b/airflow/example_dags/example_short_circuit_operator.py @@ -1,6 +1,6 @@ from airflow.operators import ShortCircuitOperator, DummyOperator from airflow.models import DAG -import airflow.utils +import airflow.utils.helpers from datetime import datetime, timedelta seven_days_ago = datetime.combine(datetime.today() - timedelta(7), @@ -21,5 +21,5 @@ ds_true = 
[DummyOperator(task_id='true_' + str(i), dag=dag) for i in [1, 2]] ds_false = [DummyOperator(task_id='false_' + str(i), dag=dag) for i in [1, 2]] -airflow.utils.chain(cond_true, *ds_true) -airflow.utils.chain(cond_false, *ds_false) +airflow.utils.helpers.chain(cond_true, *ds_true) +airflow.utils.helpers.chain(cond_false, *ds_false) diff --git a/airflow/example_dags/example_trigger_controller_dag.py b/airflow/example_dags/example_trigger_controller_dag.py index 3b463a4d23c87..657672e305198 100644 --- a/airflow/example_dags/example_trigger_controller_dag.py +++ b/airflow/example_dags/example_trigger_controller_dag.py @@ -1,3 +1,4 @@ + """This example illustrates the use of the TriggerDagRunOperator. There are 2 entities at work in this scenario: 1. The Controller DAG - the DAG that conditionally executes the trigger @@ -14,6 +15,7 @@ state is then made available to the TargetDag 2. A Target DAG : c.f. example_trigger_target_dag.py """ + from airflow import DAG from airflow.operators import TriggerDagRunOperator from datetime import datetime @@ -35,8 +37,8 @@ def conditionally_trigger(context, dag_run_obj): # Define the DAG dag = DAG(dag_id='example_trigger_controller_dag', - default_args={"owner" : "me", - "start_date":datetime.now()}, + default_args={"owner": "me", + "start_date": datetime.now()}, schedule_interval='@once') diff --git a/airflow/example_dags/example_trigger_target_dag.py b/airflow/example_dags/example_trigger_target_dag.py index 9d548813b10f0..172003f05fc2a 100644 --- a/airflow/example_dags/example_trigger_target_dag.py +++ b/airflow/example_dags/example_trigger_target_dag.py @@ -34,7 +34,7 @@ def run_this_func(ds, **kwargs): - print( "Remotely received value of {} for key=message".format(kwargs['dag_run'].conf['message'])) + print("Remotely received value of {} for key=message".format(kwargs['dag_run'].conf['message'])) run_this = PythonOperator( task_id='run_this', diff --git a/airflow/exceptions.py b/airflow/exceptions.py new file mode 100644 index 0000000000000..2468643ee47b7 --- /dev/null +++ b/airflow/exceptions.py @@ -0,0 +1,10 @@ +class AirflowException(Exception): + pass + + +class AirflowSensorTimeout(Exception): + pass + + +class AirflowTaskTimeout(Exception): + pass diff --git a/airflow/executors/__init__.py b/airflow/executors/__init__.py index 695ef5ce7f24b..31635a1374725 100644 --- a/airflow/executors/__init__.py +++ b/airflow/executors/__init__.py @@ -10,7 +10,7 @@ except: pass -from airflow.utils import AirflowException +from airflow.exceptions import AirflowException _EXECUTOR = configuration.get('core', 'EXECUTOR') diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py index db68860d66ab1..a0c26ebcbb237 100644 --- a/airflow/executors/base_executor.py +++ b/airflow/executors/base_executor.py @@ -1,7 +1,8 @@ from builtins import range from airflow import configuration -from airflow.utils import State, LoggingMixin +from airflow.utils.state import State +from airflow.utils.logging import LoggingMixin PARALLELISM = configuration.getint('core', 'PARALLELISM') diff --git a/airflow/executors/celery_executor.py b/airflow/executors/celery_executor.py index 5b3cd9da98818..088cb0b488ed6 100644 --- a/airflow/executors/celery_executor.py +++ b/airflow/executors/celery_executor.py @@ -6,7 +6,7 @@ from celery import Celery from celery import states as celery_states -from airflow.utils import AirflowException +from airflow.exceptions import AirflowException from airflow.executors.base_executor import BaseExecutor from airflow import 
configuration diff --git a/airflow/executors/local_executor.py b/airflow/executors/local_executor.py index 19ada6a799946..15a89e169cd64 100644 --- a/airflow/executors/local_executor.py +++ b/airflow/executors/local_executor.py @@ -6,7 +6,8 @@ from airflow import configuration from airflow.executors.base_executor import BaseExecutor -from airflow.utils import State, LoggingMixin +from airflow.utils.state import State +from airflow.utils.logging import LoggingMixin PARALLELISM = configuration.get('core', 'PARALLELISM') diff --git a/airflow/executors/sequential_executor.py b/airflow/executors/sequential_executor.py index 4684226a1aa5f..53d9f0a626ea5 100644 --- a/airflow/executors/sequential_executor.py +++ b/airflow/executors/sequential_executor.py @@ -1,9 +1,8 @@ from builtins import str -import logging import subprocess from airflow.executors.base_executor import BaseExecutor -from airflow.utils import State +from airflow.utils.state import State class SequentialExecutor(BaseExecutor): diff --git a/airflow/hooks/S3_hook.py b/airflow/hooks/S3_hook.py index 00b6a0cdcacdd..40ac1fb98f8ce 100644 --- a/airflow/hooks/S3_hook.py +++ b/airflow/hooks/S3_hook.py @@ -15,7 +15,7 @@ boto.set_stream_logger('boto') logging.getLogger("boto").setLevel(logging.INFO) -from airflow.utils import AirflowException +from airflow.exceptions import AirflowException from airflow.hooks.base_hook import BaseHook diff --git a/airflow/hooks/__init__.py b/airflow/hooks/__init__.py index d4c4c27141e2f..58fac177e84d8 100644 --- a/airflow/hooks/__init__.py +++ b/airflow/hooks/__init__.py @@ -1,6 +1,7 @@ # Imports the hooks dynamically while keeping the package API clean, # abstracting the underlying modules -from airflow.utils import import_module_attrs as _import_module_attrs + +from airflow.utils.helpers import import_module_attrs as _import_module_attrs from airflow.hooks.base_hook import BaseHook # noqa to expose in package _hooks = { diff --git a/airflow/hooks/base_hook.py b/airflow/hooks/base_hook.py index 15fadad83e5b1..2a6cb734d7bda 100644 --- a/airflow/hooks/base_hook.py +++ b/airflow/hooks/base_hook.py @@ -10,7 +10,7 @@ from airflow import settings from airflow.models import Connection -from airflow.utils import AirflowException +from airflow.exceptions import AirflowException CONN_ENV_PREFIX = 'AIRFLOW_CONN_' diff --git a/airflow/hooks/dbapi_hook.py b/airflow/hooks/dbapi_hook.py index 1c98deaf1cb19..10c5acd0b94df 100644 --- a/airflow/hooks/dbapi_hook.py +++ b/airflow/hooks/dbapi_hook.py @@ -6,7 +6,7 @@ import logging from airflow.hooks.base_hook import BaseHook -from airflow.utils import AirflowException +from airflow.exceptions import AirflowException class DbApiHook(BaseHook): diff --git a/airflow/hooks/druid_hook.py b/airflow/hooks/druid_hook.py index 3c216d9d798a4..f01b1e39ffc19 100644 --- a/airflow/hooks/druid_hook.py +++ b/airflow/hooks/druid_hook.py @@ -7,7 +7,7 @@ import requests from airflow.hooks.base_hook import BaseHook -from airflow.utils import AirflowException +from airflow.exceptions import AirflowException LOAD_CHECK_INTERVAL = 5 diff --git a/airflow/hooks/hdfs_hook.py b/airflow/hooks/hdfs_hook.py index f02cc7cffc8d8..3885bbd05c45a 100644 --- a/airflow/hooks/hdfs_hook.py +++ b/airflow/hooks/hdfs_hook.py @@ -7,7 +7,7 @@ except ImportError: snakebite_imported = False -from airflow.utils import AirflowException +from airflow.exceptions import AirflowException class HDFSHookException(AirflowException): diff --git a/airflow/hooks/hive_hooks.py b/airflow/hooks/hive_hooks.py index 
6e3318abce82d..15d1c98b33057 100644 --- a/airflow/hooks/hive_hooks.py +++ b/airflow/hooks/hive_hooks.py @@ -22,9 +22,9 @@ import subprocess from tempfile import NamedTemporaryFile -from airflow.utils import AirflowException +from airflow.exceptions import AirflowException from airflow.hooks.base_hook import BaseHook -from airflow.utils import TemporaryDirectory +from airflow.utils.file import TemporaryDirectory from airflow import configuration import airflow.security.utils as utils @@ -411,7 +411,7 @@ def table_exists(self, table_name, db='default'): class HiveServer2Hook(BaseHook): """ - Wrapper around the impala library + Wrapper around the impyla library Note that the default authMechanism is PLAIN, to override it you can specify it in the ``extra`` of your connection in the UI as in diff --git a/airflow/hooks/http_hook.py b/airflow/hooks/http_hook.py index 4d0eb71790ddf..07cf9f264931a 100644 --- a/airflow/hooks/http_hook.py +++ b/airflow/hooks/http_hook.py @@ -4,7 +4,7 @@ import requests from airflow.hooks.base_hook import BaseHook -from airflow.utils import AirflowException +from airflow.exceptions import AirflowException class HttpHook(BaseHook): diff --git a/airflow/hooks/pig_hook.py b/airflow/hooks/pig_hook.py index 4d63fee7833a6..5b40e52536c39 100644 --- a/airflow/hooks/pig_hook.py +++ b/airflow/hooks/pig_hook.py @@ -3,9 +3,9 @@ import subprocess from tempfile import NamedTemporaryFile -from airflow.utils import AirflowException +from airflow.exceptions import AirflowException from airflow.hooks.base_hook import BaseHook -from airflow.utils import TemporaryDirectory +from airflow.utils.file import TemporaryDirectory from airflow import configuration diff --git a/airflow/hooks/webhdfs_hook.py b/airflow/hooks/webhdfs_hook.py index 83e6eaa54b39c..79a23bc38cee2 100644 --- a/airflow/hooks/webhdfs_hook.py +++ b/airflow/hooks/webhdfs_hook.py @@ -11,7 +11,7 @@ except ImportError: logging.error("Could not load the Kerberos extension for the WebHDFSHook.") raise -from airflow.utils import AirflowException +from airflow.exceptions import AirflowException class AirflowWebHDFSHookException(AirflowException): diff --git a/airflow/jobs.py b/airflow/jobs.py index 5a939ed072d41..9ab47ba02c3c7 100644 --- a/airflow/jobs.py +++ b/airflow/jobs.py @@ -33,9 +33,14 @@ from sqlalchemy import Column, Integer, String, DateTime, func, Index, or_ from sqlalchemy.orm.session import make_transient -from airflow import executors, models, settings, utils +from airflow import executors, models, settings from airflow import configuration as conf -from airflow.utils import AirflowException, State, LoggingMixin +from airflow.exceptions import AirflowException +from airflow.utils.state import State +from airflow.utils.db import provide_session, pessimistic_connection_handling +from airflow.utils.email import send_email +from airflow.utils.logging import LoggingMixin +from airflow.utils import asciiart Base = models.Base @@ -233,7 +238,7 @@ def __init__( self.heartrate = conf.getint('scheduler', 'SCHEDULER_HEARTBEAT_SEC') - @utils.provide_session + @provide_session def manage_slas(self, dag, session=None): """ Finding all tasks that have SLAs defined, and sending alert emails @@ -322,12 +327,11 @@ def manage_slas(self, dag, session=None): self.logger.info(' --------------> ABOUT TO CALL SLA MISS CALL BACK ') dag.sla_miss_callback(dag, task_list, blocking_task_list, slas, blocking_tis) notification_sent = True - from airflow import ascii email_content = """\ Here's a list of tasks thas missed their SLAs:
{task_list}\n
Blocking tasks:
-{blocking_task_list}\n{ascii.bug}
+{blocking_task_list}\n{asciiart.bug}
""".format(**locals()) emails = [] for t in dag.tasks: @@ -340,7 +344,7 @@ def manage_slas(self, dag, session=None): if email not in emails: emails.append(email) if emails and len(slas): - utils.send_email( + send_email( emails, "[airflow] SLA miss on DAG=" + dag.dag_id, email_content) @@ -516,7 +520,7 @@ def process_dag(self, dag, executor): session.close() - @utils.provide_session + @provide_session def prioritize_queued(self, session, executor, dagbag): # Prioritizing queued task instances @@ -608,7 +612,7 @@ def signal_handler(signum, frame): sys.exit(1) signal.signal(signal.SIGINT, signal_handler) - utils.pessimistic_connection_handling() + pessimistic_connection_handling() logging.basicConfig(level=logging.DEBUG) self.logger.info("Starting the scheduler") diff --git a/airflow/models.py b/airflow/models.py index 69c6d930fc78c..0a4ceb45970eb 100644 --- a/airflow/models.py +++ b/airflow/models.py @@ -53,9 +53,17 @@ from airflow import settings, utils from airflow.executors import DEFAULT_EXECUTOR, LocalExecutor from airflow import configuration -from airflow.utils import ( - AirflowException, State, apply_defaults, provide_session, - is_container, as_tuple, TriggerRule, LoggingMixin) +from airflow.exceptions import AirflowException +from airflow.utils.dates import cron_presets, date_range as utils_date_range +from airflow.utils.db import provide_session +from airflow.utils.decorators import apply_defaults +from airflow.utils.email import send_email +from airflow.utils.helpers import (as_tuple, is_container, is_in, validate_key) +from airflow.utils.logging import LoggingMixin +from airflow.utils.state import State +from airflow.utils.timeout import timeout +from airflow.utils.trigger_rule import TriggerRule + Base = declarative_base() ID_LEN = 250 @@ -211,13 +219,13 @@ def process_file(self, filepath, only_if_updated=True, safe_mode=True): return found_dags if (not only_if_updated or - filepath not in self.file_last_changed or - dttm != self.file_last_changed[filepath]): + filepath not in self.file_last_changed or + dttm != self.file_last_changed[filepath]): try: self.logger.info("Importing " + filepath) if mod_name in sys.modules: del sys.modules[mod_name] - with utils.timeout( + with timeout( configuration.getint('core', "DAGBAG_IMPORT_TIMEOUT")): m = imp.load_source(mod_name, filepath) except Exception as e: @@ -1067,7 +1075,7 @@ def signal_handler(signum, frame): # if it goes beyond result = None if task_copy.execution_timeout: - with utils.timeout(int( + with timeout(int( task_copy.execution_timeout.total_seconds())): result = task_copy.execute(context=context) @@ -1245,7 +1253,7 @@ def email_alert(self, exception, is_retry=False): "Log file: {self.log_filepath}
" "Mark success: Link
" ).format(**locals()) - utils.send_email(task.email, title, body) + send_email(task.email, title, body) def set_duration(self): if self.end_date and self.start_date: @@ -1531,7 +1539,7 @@ def __init__( *args, **kwargs): - utils.validate_key(task_id) + validate_key(task_id) self.dag_id = dag.dag_id if dag else 'adhoc_' + owner self.task_id = task_id self.owner = owner @@ -1836,7 +1844,7 @@ def get_flat_relatives(self, upstream=False, l=None): if not l: l = [] for t in self.get_direct_relatives(upstream): - if not utils.is_in(t, l): + if not is_in(t, l): l.append(t) t.get_flat_relatives(upstream, l) return l @@ -2100,14 +2108,14 @@ def __init__( self.params.update(self.default_args['params']) del self.default_args['params'] - utils.validate_key(dag_id) + validate_key(dag_id) self.tasks = [] self.dag_id = dag_id self.start_date = start_date self.end_date = end_date self.schedule_interval = schedule_interval - if schedule_interval in utils.cron_presets: - self._schedule_interval = utils.cron_presets.get(schedule_interval) + if schedule_interval in cron_presets: + self._schedule_interval = cron_presets.get(schedule_interval) elif schedule_interval == '@once': self._schedule_interval = None else: @@ -2164,7 +2172,7 @@ def __hash__(self): def date_range(self, start_date, num=None, end_date=datetime.now()): if num: end_date = None - return utils.date_range( + return utils_date_range( start_date=start_date, end_date=end_date, num=num, delta=self._schedule_interval) @@ -2379,7 +2387,7 @@ def roots(self): @provide_session def set_dag_runs_state( self, start_date, end_date, state=State.RUNNING, session=None): - dates = utils.date_range(start_date, end_date) + dates = utils_date_range(start_date, end_date) drs = session.query(DagModel).filter_by(dag_id=self.dag_id).all() for dr in drs: dr.state = State.RUNNING @@ -2436,7 +2444,7 @@ def clear( "You are about to delete these {count} tasks:\n" "{ti_list}\n\n" "Are you sure? 
(yes/no): ").format(**locals()) - do_it = utils.ask_yesno(question) + do_it = utils.helpers.ask_yesno(question) if do_it: clear_task_instances(tis, session) @@ -2918,6 +2926,7 @@ def __repr__(self): def id_for_date(klass, date, prefix=ID_FORMAT_PREFIX): return prefix.format(date.isoformat()[:19]) + class Pool(Base): __tablename__ = "slot_pool" diff --git a/airflow/operators/__init__.py b/airflow/operators/__init__.py index 026ee3b13e643..daccbd165a066 100644 --- a/airflow/operators/__init__.py +++ b/airflow/operators/__init__.py @@ -1,6 +1,6 @@ # Imports operators dynamically while keeping the package API clean, # abstracting the underlying modules -from airflow.utils import import_module_attrs as _import_module_attrs +from airflow.utils.helpers import import_module_attrs as _import_module_attrs # These need to be integrated first as other operators depend on them _import_module_attrs(globals(), { diff --git a/airflow/operators/bash_operator.py b/airflow/operators/bash_operator.py index e6925ec3f8cc7..cae53b36cb5d3 100644 --- a/airflow/operators/bash_operator.py +++ b/airflow/operators/bash_operator.py @@ -4,9 +4,10 @@ from subprocess import Popen, STDOUT, PIPE from tempfile import gettempdir, NamedTemporaryFile -from airflow.utils import AirflowException +from airflow.exceptions import AirflowException from airflow.models import BaseOperator -from airflow.utils import apply_defaults, TemporaryDirectory +from airflow.utils.decorators import apply_defaults +from airflow.utils.file import TemporaryDirectory class BashOperator(BaseOperator): diff --git a/airflow/operators/check_operator.py b/airflow/operators/check_operator.py index fc441b30854b1..0624d915041d3 100644 --- a/airflow/operators/check_operator.py +++ b/airflow/operators/check_operator.py @@ -2,10 +2,10 @@ from builtins import str import logging -from airflow.utils import AirflowException +from airflow.exceptions import AirflowException from airflow.hooks import BaseHook from airflow.models import BaseOperator -from airflow.utils import apply_defaults +from airflow.utils.decorators import apply_defaults class CheckOperator(BaseOperator): diff --git a/airflow/operators/dagrun_operator.py b/airflow/operators/dagrun_operator.py index 5d397aad0688e..7f8bb53400ad7 100644 --- a/airflow/operators/dagrun_operator.py +++ b/airflow/operators/dagrun_operator.py @@ -2,7 +2,7 @@ import logging from airflow.models import BaseOperator, DagRun -from airflow.utils import apply_defaults, State +from airflow.utils.decorators import apply_defaults from airflow import settings diff --git a/airflow/operators/docker_operator.py b/airflow/operators/docker_operator.py index c7b1df7f7b6a7..b01d31ac51042 100644 --- a/airflow/operators/docker_operator.py +++ b/airflow/operators/docker_operator.py @@ -1,7 +1,9 @@ import json import logging +from airflow.exceptions import AirflowException from airflow.models import BaseOperator -from airflow.utils import apply_defaults, AirflowException, TemporaryDirectory +from airflow.utils.decorators import apply_defaults +from airflow.utils.file import TemporaryDirectory from docker import Client, tls import ast diff --git a/airflow/operators/dummy_operator.py b/airflow/operators/dummy_operator.py index 6b69115e6b2a9..1392e7d33cc98 100644 --- a/airflow/operators/dummy_operator.py +++ b/airflow/operators/dummy_operator.py @@ -1,5 +1,5 @@ from airflow.models import BaseOperator -from airflow.utils import apply_defaults +from airflow.utils.decorators import apply_defaults class DummyOperator(BaseOperator): diff --git 
a/airflow/operators/email_operator.py b/airflow/operators/email_operator.py index 0cfff08a4bfef..29b18edad0b51 100644 --- a/airflow/operators/email_operator.py +++ b/airflow/operators/email_operator.py @@ -1,6 +1,6 @@ from airflow.models import BaseOperator -from airflow.utils import send_email -from airflow.utils import apply_defaults +from airflow.utils.email import send_email +from airflow.utils.decorators import apply_defaults class EmailOperator(BaseOperator): diff --git a/airflow/operators/generic_transfer.py b/airflow/operators/generic_transfer.py index 7e99d3e334f65..eab9d61c0f224 100644 --- a/airflow/operators/generic_transfer.py +++ b/airflow/operators/generic_transfer.py @@ -1,7 +1,7 @@ import logging from airflow.models import BaseOperator -from airflow.utils import apply_defaults +from airflow.utils.decorators import apply_defaults from airflow.hooks.base_hook import BaseHook diff --git a/airflow/operators/hive_operator.py b/airflow/operators/hive_operator.py index 7c0d299e54d51..9a299e1e02160 100644 --- a/airflow/operators/hive_operator.py +++ b/airflow/operators/hive_operator.py @@ -3,7 +3,7 @@ from airflow.hooks import HiveCliHook from airflow.models import BaseOperator -from airflow.utils import apply_defaults +from airflow.utils.decorators import apply_defaults class HiveOperator(BaseOperator): diff --git a/airflow/operators/hive_stats_operator.py b/airflow/operators/hive_stats_operator.py index 09f85e17af105..aadca4de28755 100644 --- a/airflow/operators/hive_stats_operator.py +++ b/airflow/operators/hive_stats_operator.py @@ -4,10 +4,10 @@ import json import logging -from airflow.utils import AirflowException +from airflow.exceptions import AirflowException from airflow.hooks import PrestoHook, HiveMetastoreHook, MySqlHook from airflow.models import BaseOperator -from airflow.utils import apply_defaults +from airflow.utils.decorators import apply_defaults class HiveStatsCollectionOperator(BaseOperator): diff --git a/airflow/operators/hive_to_druid.py b/airflow/operators/hive_to_druid.py index e518ea430ad2e..1346841e6f7a3 100644 --- a/airflow/operators/hive_to_druid.py +++ b/airflow/operators/hive_to_druid.py @@ -2,7 +2,7 @@ from airflow.hooks import HiveCliHook, DruidHook, HiveMetastoreHook from airflow.models import BaseOperator -from airflow.utils import apply_defaults +from airflow.utils.decorators import apply_defaults class HiveToDruidTransfer(BaseOperator): diff --git a/airflow/operators/hive_to_mysql.py b/airflow/operators/hive_to_mysql.py index bfbe330cfa278..9e27f38516ab5 100644 --- a/airflow/operators/hive_to_mysql.py +++ b/airflow/operators/hive_to_mysql.py @@ -2,7 +2,7 @@ from airflow.hooks import HiveServer2Hook, MySqlHook from airflow.models import BaseOperator -from airflow.utils import apply_defaults +from airflow.utils.decorators import apply_defaults from tempfile import NamedTemporaryFile diff --git a/airflow/operators/hive_to_samba_operator.py b/airflow/operators/hive_to_samba_operator.py index cfa98142ffefd..63881ab981097 100644 --- a/airflow/operators/hive_to_samba_operator.py +++ b/airflow/operators/hive_to_samba_operator.py @@ -3,7 +3,7 @@ from airflow.hooks import HiveServer2Hook, SambaHook from airflow.models import BaseOperator -from airflow.utils import apply_defaults +from airflow.utils.decorators import apply_defaults class Hive2SambaOperator(BaseOperator): diff --git a/airflow/operators/http_operator.py b/airflow/operators/http_operator.py index a9b2ad5e5ee67..87d1415bf625b 100644 --- a/airflow/operators/http_operator.py +++ 
b/airflow/operators/http_operator.py @@ -1,8 +1,9 @@ import logging +from airflow.exceptions import AirflowException from airflow.hooks import HttpHook from airflow.models import BaseOperator -from airflow.utils import apply_defaults, AirflowException +from airflow.utils.decorators import apply_defaults class SimpleHttpOperator(BaseOperator): diff --git a/airflow/operators/jdbc_operator.py b/airflow/operators/jdbc_operator.py index 8793045fba675..5efdaf4e6ba84 100644 --- a/airflow/operators/jdbc_operator.py +++ b/airflow/operators/jdbc_operator.py @@ -4,7 +4,7 @@ from airflow.hooks.jdbc_hook import JdbcHook from airflow.models import BaseOperator -from airflow.utils import apply_defaults +from airflow.utils.decorators import apply_defaults class JdbcOperator(BaseOperator): diff --git a/airflow/operators/mssql_operator.py b/airflow/operators/mssql_operator.py index 3dec7cebaf619..1d5273a49105b 100644 --- a/airflow/operators/mssql_operator.py +++ b/airflow/operators/mssql_operator.py @@ -2,7 +2,7 @@ from airflow.hooks import MsSqlHook from airflow.models import BaseOperator -from airflow.utils import apply_defaults +from airflow.utils.decorators import apply_defaults class MsSqlOperator(BaseOperator): diff --git a/airflow/operators/mssql_to_hive.py b/airflow/operators/mssql_to_hive.py index 60586de7a3db8..6a981b43c8d97 100644 --- a/airflow/operators/mssql_to_hive.py +++ b/airflow/operators/mssql_to_hive.py @@ -8,7 +8,7 @@ from airflow.hooks import HiveCliHook, MsSqlHook from airflow.models import BaseOperator -from airflow.utils import apply_defaults +from airflow.utils.decorators import apply_defaults class MsSqlToHiveTransfer(BaseOperator): diff --git a/airflow/operators/mysql_operator.py b/airflow/operators/mysql_operator.py index b8d56d5097e38..ae6d36f3278af 100644 --- a/airflow/operators/mysql_operator.py +++ b/airflow/operators/mysql_operator.py @@ -2,7 +2,7 @@ from airflow.hooks import MySqlHook from airflow.models import BaseOperator -from airflow.utils import apply_defaults +from airflow.utils.decorators import apply_defaults class MySqlOperator(BaseOperator): diff --git a/airflow/operators/mysql_to_hive.py b/airflow/operators/mysql_to_hive.py index 6e2a8dd58b242..09ec190f77458 100644 --- a/airflow/operators/mysql_to_hive.py +++ b/airflow/operators/mysql_to_hive.py @@ -7,7 +7,7 @@ from airflow.hooks import HiveCliHook, MySqlHook from airflow.models import BaseOperator -from airflow.utils import apply_defaults +from airflow.utils.decorators import apply_defaults class MySqlToHiveTransfer(BaseOperator): diff --git a/airflow/operators/pig_operator.py b/airflow/operators/pig_operator.py index e0d91afd33067..d25795dec73d7 100644 --- a/airflow/operators/pig_operator.py +++ b/airflow/operators/pig_operator.py @@ -3,7 +3,7 @@ from airflow.hooks import PigCliHook from airflow.models import BaseOperator -from airflow.utils import apply_defaults +from airflow.utils.decorators import apply_defaults class PigOperator(BaseOperator): diff --git a/airflow/operators/postgres_operator.py b/airflow/operators/postgres_operator.py index a7302a050b5f0..79fa5e75330de 100644 --- a/airflow/operators/postgres_operator.py +++ b/airflow/operators/postgres_operator.py @@ -2,7 +2,7 @@ from airflow.hooks import PostgresHook from airflow.models import BaseOperator -from airflow.utils import apply_defaults +from airflow.utils.decorators import apply_defaults class PostgresOperator(BaseOperator): diff --git a/airflow/operators/presto_check_operator.py b/airflow/operators/presto_check_operator.py index 
9228a93a544c3..e857036415e6e 100644 --- a/airflow/operators/presto_check_operator.py +++ b/airflow/operators/presto_check_operator.py @@ -1,6 +1,6 @@ from airflow.hooks import PrestoHook from airflow.operators import CheckOperator, ValueCheckOperator, IntervalCheckOperator -from airflow.utils import apply_defaults +from airflow.utils.decorators import apply_defaults class PrestoCheckOperator(CheckOperator): diff --git a/airflow/operators/presto_to_mysql.py b/airflow/operators/presto_to_mysql.py index 37c3caadcb8c5..29de0c7d86655 100644 --- a/airflow/operators/presto_to_mysql.py +++ b/airflow/operators/presto_to_mysql.py @@ -2,7 +2,7 @@ from airflow.hooks import PrestoHook, MySqlHook from airflow.models import BaseOperator -from airflow.utils import apply_defaults +from airflow.utils.decorators import apply_defaults class PrestoToMySqlTransfer(BaseOperator): diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py index 1cf7eed8e5a56..290cc65d139e9 100644 --- a/airflow/operators/python_operator.py +++ b/airflow/operators/python_operator.py @@ -3,7 +3,8 @@ import logging from airflow.models import BaseOperator, TaskInstance -from airflow.utils import apply_defaults, State +from airflow.utils.state import State +from airflow.utils.decorators import apply_defaults from airflow import settings diff --git a/airflow/operators/s3_file_transform_operator.py b/airflow/operators/s3_file_transform_operator.py index 837c2f902434c..ce36b00efc57d 100644 --- a/airflow/operators/s3_file_transform_operator.py +++ b/airflow/operators/s3_file_transform_operator.py @@ -2,10 +2,10 @@ from tempfile import NamedTemporaryFile import subprocess -from airflow.utils import AirflowException +from airflow.exceptions import AirflowException from airflow.hooks import S3Hook from airflow.models import BaseOperator -from airflow.utils import apply_defaults +from airflow.utils.decorators import apply_defaults class S3FileTransformOperator(BaseOperator): diff --git a/airflow/operators/s3_to_hive_operator.py b/airflow/operators/s3_to_hive_operator.py index 20d23e781e3d8..3fc5327a40743 100644 --- a/airflow/operators/s3_to_hive_operator.py +++ b/airflow/operators/s3_to_hive_operator.py @@ -3,10 +3,10 @@ import logging from tempfile import NamedTemporaryFile -from airflow.utils import AirflowException +from airflow.exceptions import AirflowException from airflow.hooks import HiveCliHook, S3Hook from airflow.models import BaseOperator -from airflow.utils import apply_defaults +from airflow.utils.decorators import apply_defaults class S3ToHiveTransfer(BaseOperator): diff --git a/airflow/operators/sensors.py b/airflow/operators/sensors.py index 4c45962db0208..7d7a7c689ce6d 100644 --- a/airflow/operators/sensors.py +++ b/airflow/operators/sensors.py @@ -8,11 +8,11 @@ from time import sleep from airflow import hooks, settings +from airflow.exceptions import AirflowException, AirflowSensorTimeout from airflow.models import BaseOperator, TaskInstance, Connection as DB from airflow.hooks import BaseHook -from airflow.utils import State -from airflow.utils import ( - apply_defaults, AirflowException, AirflowSensorTimeout) +from airflow.utils.state import State +from airflow.utils.decorators import apply_defaults class BaseSensorOperator(BaseOperator): diff --git a/airflow/operators/slack_operator.py b/airflow/operators/slack_operator.py index c8734fb954431..2f173d7edf34b 100644 --- a/airflow/operators/slack_operator.py +++ b/airflow/operators/slack_operator.py @@ -1,6 +1,7 @@ from slackclient import 
SlackClient from airflow.models import BaseOperator -from airflow.utils import apply_defaults, AirflowException +from airflow.utils.decorators import apply_defaults +from airflow.exceptions import AirflowException import json import logging diff --git a/airflow/operators/sqlite_operator.py b/airflow/operators/sqlite_operator.py index ebdba2f5ce725..700019d9ead8b 100644 --- a/airflow/operators/sqlite_operator.py +++ b/airflow/operators/sqlite_operator.py @@ -2,7 +2,7 @@ from airflow.hooks import SqliteHook from airflow.models import BaseOperator -from airflow.utils import apply_defaults +from airflow.utils.decorators import apply_defaults class SqliteOperator(BaseOperator): diff --git a/airflow/operators/subdag_operator.py b/airflow/operators/subdag_operator.py index 54c2409d7e73b..c56e7afc54066 100644 --- a/airflow/operators/subdag_operator.py +++ b/airflow/operators/subdag_operator.py @@ -1,6 +1,6 @@ -from airflow.utils import AirflowException +from airflow.exceptions import AirflowException from airflow.models import BaseOperator -from airflow.utils import apply_defaults +from airflow.utils.decorators import apply_defaults from airflow.executors import DEFAULT_EXECUTOR diff --git a/airflow/settings.py b/airflow/settings.py index 51dfe4d153717..ae56455649fe5 100644 --- a/airflow/settings.py +++ b/airflow/settings.py @@ -109,6 +109,7 @@ def policy(task_instance): """ pass + def configure_logging(): logging.root.handlers = [] logging.basicConfig( diff --git a/airflow/utils.py b/airflow/utils.py deleted file mode 100644 index 228602e0d8888..0000000000000 --- a/airflow/utils.py +++ /dev/null @@ -1,978 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals - -import sys -from builtins import str, input, object -from past.builtins import basestring -from copy import copy -from datetime import datetime, date, timedelta -from dateutil.relativedelta import relativedelta # for doctest -from email.mime.text import MIMEText -from email.mime.multipart import MIMEMultipart -from email.mime.application import MIMEApplication -from email.utils import formatdate -import errno -from functools import wraps -import imp -import importlib -import inspect -import json -import logging -import os -import re -import shutil -import signal -import six -import smtplib -from tempfile import mkdtemp - -from alembic.config import Config -from alembic import command -from alembic.migration import MigrationContext - -from contextlib import contextmanager - -from sqlalchemy import event, exc -from sqlalchemy.pool import Pool - -import numpy as np -from croniter import croniter - -from airflow import settings -from airflow import configuration - - -class AirflowException(Exception): - pass - - -class AirflowSensorTimeout(Exception): - pass - - -class TriggerRule(object): - ALL_SUCCESS = 'all_success' - ALL_FAILED = 'all_failed' - ALL_DONE = 'all_done' - ONE_SUCCESS = 'one_success' - ONE_FAILED = 'one_failed' - DUMMY = 'dummy' - - @classmethod - def is_valid(cls, trigger_rule): - return trigger_rule in cls.all_triggers() - - @classmethod - def all_triggers(cls): - return [getattr(cls, attr) - for attr in dir(cls) - if not attr.startswith("__") and not callable(getattr(cls, attr))] - - -class State(object): - """ - Static class with task instance states constants and color method to - avoid hardcoding. 
- """ - QUEUED = "queued" - RUNNING = "running" - SUCCESS = "success" - SHUTDOWN = "shutdown" # External request to shut down - FAILED = "failed" - UP_FOR_RETRY = "up_for_retry" - UPSTREAM_FAILED = "upstream_failed" - SKIPPED = "skipped" - - state_color = { - QUEUED: 'gray', - RUNNING: 'lime', - SUCCESS: 'green', - SHUTDOWN: 'blue', - FAILED: 'red', - UP_FOR_RETRY: 'gold', - UPSTREAM_FAILED: 'orange', - SKIPPED: 'pink', - } - - @classmethod - def color(cls, state): - if state in cls.state_color: - return cls.state_color[state] - else: - return 'white' - - @classmethod - def color_fg(cls, state): - color = cls.color(state) - if color in ['green', 'red']: - return 'white' - else: - return 'black' - - @classmethod - def runnable(cls): - return [ - None, cls.FAILED, cls.UP_FOR_RETRY, cls.UPSTREAM_FAILED, - cls.SKIPPED, cls.QUEUED] - - -cron_presets = { - '@hourly': '0 * * * *', - '@daily': '0 0 * * *', - '@weekly': '0 0 * * 0', - '@monthly': '0 0 1 * *', - '@yearly': '0 0 1 1 *', -} - -def provide_session(func): - """ - Function decorator that provides a session if it isn't provided. - If you want to reuse a session or run the function as part of a - database transaction, you pass it to the function, if not this wrapper - will create one and close it for you. - """ - @wraps(func) - def wrapper(*args, **kwargs): - needs_session = False - if 'session' not in kwargs: - needs_session = True - session = settings.Session() - kwargs['session'] = session - result = func(*args, **kwargs) - if needs_session: - session.expunge_all() - session.commit() - session.close() - return result - return wrapper - - -def pessimistic_connection_handling(): - @event.listens_for(Pool, "checkout") - def ping_connection(dbapi_connection, connection_record, connection_proxy): - ''' - Disconnect Handling - Pessimistic, taken from: - http://docs.sqlalchemy.org/en/rel_0_9/core/pooling.html - ''' - cursor = dbapi_connection.cursor() - try: - cursor.execute("SELECT 1") - except: - raise exc.DisconnectionError() - cursor.close() - -@provide_session -def merge_conn(conn, session=None): - from airflow import models - C = models.Connection - if not session.query(C).filter(C.conn_id == conn.conn_id).first(): - session.add(conn) - session.commit() - - -def initdb(): - session = settings.Session() - - from airflow import models - upgradedb() - - merge_conn( - models.Connection( - conn_id='airflow_db', conn_type='mysql', - host='localhost', login='root', - schema='airflow')) - merge_conn( - models.Connection( - conn_id='airflow_ci', conn_type='mysql', - host='localhost', login='root', - schema='airflow_ci')) - merge_conn( - models.Connection( - conn_id='beeline_default', conn_type='beeline', port="10000", - host='localhost', extra="{\"use_beeline\": true, \"auth\": \"\"}", - schema='default')) - merge_conn( - models.Connection( - conn_id='bigquery_default', conn_type='bigquery')) - merge_conn( - models.Connection( - conn_id='local_mysql', conn_type='mysql', - host='localhost', login='airflow', password='airflow', - schema='airflow')) - merge_conn( - models.Connection( - conn_id='presto_default', conn_type='presto', - host='localhost', - schema='hive', port=3400)) - merge_conn( - models.Connection( - conn_id='hive_cli_default', conn_type='hive_cli', - schema='default',)) - merge_conn( - models.Connection( - conn_id='hiveserver2_default', conn_type='hiveserver2', - host='localhost', - schema='default', port=10000)) - merge_conn( - models.Connection( - conn_id='metastore_default', conn_type='hive_metastore', - host='localhost', 
extra="{\"authMechanism\": \"PLAIN\"}", - port=9083)) - merge_conn( - models.Connection( - conn_id='mysql_default', conn_type='mysql', - login='root', - host='localhost')) - merge_conn( - models.Connection( - conn_id='postgres_default', conn_type='postgres', - login='postgres', - schema='airflow', - host='localhost')) - merge_conn( - models.Connection( - conn_id='sqlite_default', conn_type='sqlite', - host='/tmp/sqlite_default.db')) - merge_conn( - models.Connection( - conn_id='http_default', conn_type='http', - host='https://www.google.com/')) - merge_conn( - models.Connection( - conn_id='mssql_default', conn_type='mssql', - host='localhost', port=1433)) - merge_conn( - models.Connection( - conn_id='vertica_default', conn_type='vertica', - host='localhost', port=5433)) - merge_conn( - models.Connection( - conn_id='webhdfs_default', conn_type='hdfs', - host='localhost', port=50070)) - merge_conn( - models.Connection( - conn_id='ssh_default', conn_type='ssh', - host='localhost')) - - # Known event types - KET = models.KnownEventType - if not session.query(KET).filter(KET.know_event_type == 'Holiday').first(): - session.add(KET(know_event_type='Holiday')) - if not session.query(KET).filter(KET.know_event_type == 'Outage').first(): - session.add(KET(know_event_type='Outage')) - if not session.query(KET).filter( - KET.know_event_type == 'Natural Disaster').first(): - session.add(KET(know_event_type='Natural Disaster')) - if not session.query(KET).filter( - KET.know_event_type == 'Marketing Campaign').first(): - session.add(KET(know_event_type='Marketing Campaign')) - session.commit() - - models.DagBag(sync_to_db=True) - - Chart = models.Chart - chart_label = "Airflow task instance by type" - chart = session.query(Chart).filter(Chart.label == chart_label).first() - if not chart: - chart = Chart( - label=chart_label, - conn_id='airflow_db', - chart_type='bar', - x_is_date=False, - sql=( - "SELECT state, COUNT(1) as number " - "FROM task_instance " - "WHERE dag_id LIKE 'example%' " - "GROUP BY state"), - ) - session.add(chart) - - -def upgradedb(): - logging.info("Creating tables") - package_dir = os.path.abspath(os.path.dirname(__file__)) - directory = os.path.join(package_dir, 'migrations') - config = Config(os.path.join(package_dir, 'alembic.ini')) - config.set_main_option('script_location', directory) - config.set_main_option('sqlalchemy.url', - configuration.get('core', 'SQL_ALCHEMY_CONN')) - command.upgrade(config, 'heads') - - -def resetdb(): - ''' - Clear out the database - ''' - from airflow import models - - logging.info("Dropping tables that exist") - models.Base.metadata.drop_all(settings.engine) - mc = MigrationContext.configure(settings.engine) - if mc._version.exists(settings.engine): - mc._version.drop(settings.engine) - initdb() - - -def validate_key(k, max_length=250): - if not isinstance(k, basestring): - raise TypeError("The key has to be a string") - elif len(k) > max_length: - raise AirflowException( - "The key has to be less than {0} characters".format(max_length)) - elif not re.match(r'^[A-Za-z0-9_\-\.]+$', k): - raise AirflowException( - "The key ({k}) has to be made of alphanumeric characters, dashes, " - "dots and underscores exclusively".format(**locals())) - else: - return True - - -def date_range( - start_date, - end_date=None, - num=None, - delta=None): - """ - Get a set of dates as a list based on a start, end and delta, delta - can be something that can be added to ``datetime.datetime`` - or a cron expression as a ``str`` - - :param start_date: anchor date to 
start the series from - :type start_date: datetime.datetime - :param end_date: right boundary for the date range - :type end_date: datetime.datetime - :param num: alternatively to end_date, you can specify the number of - number of entries you want in the range. This number can be negative, - output will always be sorted regardless - :type num: int - - >>> date_range(datetime(2016, 1, 1), datetime(2016, 1, 3), delta=timedelta(1)) - [datetime.datetime(2016, 1, 1, 0, 0), datetime.datetime(2016, 1, 2, 0, 0), datetime.datetime(2016, 1, 3, 0, 0)] - >>> date_range(datetime(2016, 1, 1), datetime(2016, 1, 3), delta='0 0 * * *') - [datetime.datetime(2016, 1, 1, 0, 0), datetime.datetime(2016, 1, 2, 0, 0), datetime.datetime(2016, 1, 3, 0, 0)] - >>> date_range(datetime(2016, 1, 1), datetime(2016, 3, 3), delta="0 0 0 * *") - [datetime.datetime(2016, 1, 1, 0, 0), datetime.datetime(2016, 2, 1, 0, 0), datetime.datetime(2016, 3, 1, 0, 0)] - """ - if not delta: - return [] - if end_date and start_date > end_date: - raise Exception("Wait. start_date needs to be before end_date") - if end_date and num: - raise Exception("Wait. Either specify end_date OR num") - if not end_date and not num: - end_date = datetime.now() - - delta_iscron = False - if isinstance(delta, six.string_types): - delta_iscron = True - cron = croniter(delta, start_date) - elif isinstance(delta, timedelta): - delta = abs(delta) - l = [] - if end_date: - while start_date <= end_date: - l.append(start_date) - if delta_iscron: - start_date = cron.get_next(datetime) - else: - start_date += delta - else: - for i in range(abs(num)): - l.append(start_date) - if delta_iscron: - if num > 0: - start_date = cron.get_next(datetime) - else: - start_date = cron.get_prev(datetime) - else: - if num > 0: - start_date += delta - else: - start_date -= delta - return sorted(l) - - -def json_ser(obj): - """ - json serializer that deals with dates - usage: json.dumps(object, default=utils.json_ser) - """ - if isinstance(obj, (datetime, date)): - return obj.isoformat() - - -def alchemy_to_dict(obj): - """ - Transforms a SQLAlchemy model instance into a dictionary - """ - if not obj: - return None - d = {} - for c in obj.__table__.columns: - value = getattr(obj, c.name) - if type(value) == datetime: - value = value.isoformat() - d[c.name] = value - return d - - -def readfile(filepath): - f = open(filepath) - content = f.read() - f.close() - return content - - -def apply_defaults(func): - """ - Function decorator that Looks for an argument named "default_args", and - fills the unspecified arguments from it. - - Since python2.* isn't clear about which arguments are missing when - calling a function, and that this can be quite confusing with multi-level - inheritance and argument defaults, this decorator also alerts with - specific information about the missing arguments. 
- """ - @wraps(func) - def wrapper(*args, **kwargs): - if len(args) > 1: - raise AirflowException( - "Use keyword arguments when initializing operators") - dag_args = {} - dag_params = {} - if 'dag' in kwargs and kwargs['dag']: - dag = kwargs['dag'] - dag_args = copy(dag.default_args) or {} - dag_params = copy(dag.params) or {} - - params = {} - if 'params' in kwargs: - params = kwargs['params'] - dag_params.update(params) - - default_args = {} - if 'default_args' in kwargs: - default_args = kwargs['default_args'] - if 'params' in default_args: - dag_params.update(default_args['params']) - del default_args['params'] - - dag_args.update(default_args) - default_args = dag_args - arg_spec = inspect.getargspec(func) - num_defaults = len(arg_spec.defaults) if arg_spec.defaults else 0 - non_optional_args = arg_spec.args[:-num_defaults] - if 'self' in non_optional_args: - non_optional_args.remove('self') - for arg in func.__code__.co_varnames: - if arg in default_args and arg not in kwargs: - kwargs[arg] = default_args[arg] - missing_args = list(set(non_optional_args) - set(kwargs)) - if missing_args: - msg = "Argument {0} is required".format(missing_args) - raise AirflowException(msg) - - kwargs['params'] = dag_params - - result = func(*args, **kwargs) - return result - return wrapper - -if 'BUILDING_AIRFLOW_DOCS' in os.environ: - # Monkey patch hook to get good function headers while building docs - apply_defaults = lambda x: x - -def ask_yesno(question): - yes = set(['yes', 'y']) - no = set(['no', 'n']) - - done = False - print(question) - while not done: - choice = input().lower() - if choice in yes: - return True - elif choice in no: - return False - else: - print("Please respond by yes or no.") - - -def send_email(to, subject, html_content, files=None, dryrun=False): - """ - Send email using backend specified in EMAIL_BACKEND. 
- """ - path, attr = configuration.get('email', 'EMAIL_BACKEND').rsplit('.', 1) - module = importlib.import_module(path) - backend = getattr(module, attr) - return backend(to, subject, html_content, files=files, dryrun=dryrun) - - -def send_email_smtp(to, subject, html_content, files=None, dryrun=False): - """ - Send an email with html content - - >>> send_email('test@example.com', 'foo', 'Foo bar', ['/dev/null'], dryrun=True) - """ - SMTP_MAIL_FROM = configuration.get('smtp', 'SMTP_MAIL_FROM') - - if isinstance(to, basestring): - if ',' in to: - to = to.split(',') - elif ';' in to: - to = to.split(';') - else: - to = [to] - - msg = MIMEMultipart('alternative') - msg['Subject'] = subject - msg['From'] = SMTP_MAIL_FROM - msg['To'] = ", ".join(to) - msg["Date"] = formatdate(localtime=True) - mime_text = MIMEText(html_content, 'html') - msg.attach(mime_text) - - for fname in files or []: - basename = os.path.basename(fname) - with open(fname, "rb") as f: - msg.attach(MIMEApplication( - f.read(), - Content_Disposition='attachment; filename="%s"' % basename, - Name=basename - )) - - send_MIME_email(SMTP_MAIL_FROM, to, msg, dryrun) - - -def send_MIME_email(e_from, e_to, mime_msg, dryrun=False): - SMTP_HOST = configuration.get('smtp', 'SMTP_HOST') - SMTP_PORT = configuration.getint('smtp', 'SMTP_PORT') - SMTP_USER = configuration.get('smtp', 'SMTP_USER') - SMTP_PASSWORD = configuration.get('smtp', 'SMTP_PASSWORD') - SMTP_STARTTLS = configuration.getboolean('smtp', 'SMTP_STARTTLS') - SMTP_SSL = configuration.getboolean('smtp', 'SMTP_SSL') - - if not dryrun: - s = smtplib.SMTP_SSL(SMTP_HOST, SMTP_PORT) if SMTP_SSL else smtplib.SMTP(SMTP_HOST, SMTP_PORT) - if SMTP_STARTTLS: - s.starttls() - if SMTP_USER and SMTP_PASSWORD: - s.login(SMTP_USER, SMTP_PASSWORD) - logging.info("Sent an alert email to " + str(e_to)) - s.sendmail(e_from, e_to, mime_msg.as_string()) - s.quit() - - -def import_module_attrs(parent_module_globals, module_attrs_dict): - ''' - Attempts to import a set of modules and specified attributes in the - form of a dictionary. The attributes are copied in the parent module's - namespace. The function returns a list of attributes names that can be - affected to __all__. - - This is used in the context of ``operators`` and ``hooks`` and - silence the import errors for when libraries are missing. It makes - for a clean package abstracting the underlying modules and only - brings functional operators to those namespaces. - ''' - imported_attrs = [] - for mod, attrs in list(module_attrs_dict.items()): - try: - path = os.path.realpath(parent_module_globals['__file__']) - folder = os.path.dirname(path) - f, filename, description = imp.find_module(mod, [folder]) - module = imp.load_module(mod, f, filename, description) - for attr in attrs: - parent_module_globals[attr] = getattr(module, attr) - imported_attrs += [attr] - except Exception as err: - logging.debug("Error importing module {mod}: {err}".format( - mod=mod, err=err)) - return imported_attrs - - -def is_in(obj, l): - """ - Checks whether an object is one of the item in the list. - This is different from ``in`` because ``in`` uses __cmp__ when - present. 
Here we change based on the object itself - """ - for item in l: - if item is obj: - return True - return False - - -@contextmanager -def TemporaryDirectory(suffix='', prefix=None, dir=None): - name = mkdtemp(suffix=suffix, prefix=prefix, dir=dir) - try: - yield name - finally: - try: - shutil.rmtree(name) - except OSError as e: - # ENOENT - no such file or directory - if e.errno != errno.ENOENT: - raise e - - -class AirflowTaskTimeout(Exception): - pass - - -class timeout(object): - """ - To be used in a ``with`` block and timeout its content. - """ - def __init__(self, seconds=1, error_message='Timeout'): - self.seconds = seconds - self.error_message = error_message - - def handle_timeout(self, signum, frame): - logging.error("Process timed out") - raise AirflowTaskTimeout(self.error_message) - - def __enter__(self): - try: - signal.signal(signal.SIGALRM, self.handle_timeout) - signal.alarm(self.seconds) - except ValueError as e: - logging.warning("timeout can't be used in the current context") - logging.exception(e) - - def __exit__(self, type, value, traceback): - try: - signal.alarm(0) - except ValueError as e: - logging.warning("timeout can't be used in the current context") - logging.exception(e) - - -def is_container(obj): - """ - Test if an object is a container (iterable) but not a string - """ - return hasattr(obj, '__iter__') and not isinstance(obj, basestring) - - -def as_tuple(obj): - """ - If obj is a container, returns obj as a tuple. - Otherwise, returns a tuple containing obj. - """ - if is_container(obj): - return tuple(obj) - else: - return tuple([obj]) - - -def round_time(dt, delta, start_date=datetime.min): - """ - Returns the datetime of the form start_date + i * delta - which is closest to dt for any non-negative integer i. - - Note that delta may be a datetime.timedelta or a dateutil.relativedelta - - >>> round_time(datetime(2015, 1, 1, 6), timedelta(days=1)) - datetime.datetime(2015, 1, 1, 0, 0) - >>> round_time(datetime(2015, 1, 2), relativedelta(months=1)) - datetime.datetime(2015, 1, 1, 0, 0) - >>> round_time(datetime(2015, 9, 16, 0, 0), timedelta(1), datetime(2015, 9, 14, 0, 0)) - datetime.datetime(2015, 9, 16, 0, 0) - >>> round_time(datetime(2015, 9, 15, 0, 0), timedelta(1), datetime(2015, 9, 14, 0, 0)) - datetime.datetime(2015, 9, 15, 0, 0) - >>> round_time(datetime(2015, 9, 14, 0, 0), timedelta(1), datetime(2015, 9, 14, 0, 0)) - datetime.datetime(2015, 9, 14, 0, 0) - >>> round_time(datetime(2015, 9, 13, 0, 0), timedelta(1), datetime(2015, 9, 14, 0, 0)) - datetime.datetime(2015, 9, 14, 0, 0) - """ - - if isinstance(delta, six.string_types): - # It's cron based, so it's easy - cron = croniter(delta, start_date) - prev = cron.get_prev(datetime) - if prev == start_date: - return start_date - else: - return prev - - # Ignore the microseconds of dt - dt -= timedelta(microseconds = dt.microsecond) - - # We are looking for a datetime in the form start_date + i * delta - # which is as close as possible to dt. Since delta could be a relative - # delta we don't know it's exact length in seconds so we cannot rely on - # division to find i. Instead we employ a binary search algorithm, first - # finding an upper and lower limit and then disecting the interval until - # we have found the closest match. - - # We first search an upper limit for i for which start_date + upper * delta - # exceeds dt. 
- upper = 1 - while start_date + upper*delta < dt: - # To speed up finding an upper limit we grow this exponentially by a - # factor of 2 - upper *= 2 - - # Since upper is the first value for which start_date + upper * delta - # exceeds dt, upper // 2 is below dt and therefore forms a lower limited - # for the i we are looking for - lower = upper // 2 - - # We now continue to intersect the interval between - # start_date + lower * delta and start_date + upper * delta - # until we find the closest value - while True: - # Invariant: start + lower * delta < dt <= start + upper * delta - # If start_date + (lower + 1)*delta exceeds dt, then either lower or - # lower+1 has to be the solution we are searching for - if start_date + (lower + 1)*delta >= dt: - # Check if start_date + (lower + 1)*delta or - # start_date + lower*delta is closer to dt and return the solution - if ( - (start_date + (lower + 1) * delta) - dt <= - dt - (start_date + lower * delta)): - return start_date + (lower + 1)*delta - else: - return start_date + lower * delta - - # We intersect the interval and either replace the lower or upper - # limit with the candidate - candidate = lower + (upper - lower) // 2 - if start_date + candidate*delta >= dt: - upper = candidate - else: - lower = candidate - - # in the special case when start_date > dt the search for upper will - # immediately stop for upper == 1 which results in lower = upper // 2 = 0 - # and this function returns start_date. - - -def chain(*tasks): - """ - Given a number of tasks, builds a dependency chain. - - chain(task_1, task_2, task_3, task_4) - - is equivalent to - - task_1.set_downstream(task_2) - task_2.set_downstream(task_3) - task_3.set_downstream(task_4) - """ - for up_task, down_task in zip(tasks[:-1], tasks[1:]): - up_task.set_downstream(down_task) - - -class AirflowJsonEncoder(json.JSONEncoder): - def default(self, obj): - # convert dates and numpy objects in a json serializable format - if isinstance(obj, datetime): - return obj.strftime('%Y-%m-%dT%H:%M:%SZ') - elif isinstance(obj, date): - return obj.strftime('%Y-%m-%d') - elif type(obj) in [np.int_, np.intc, np.intp, np.int8, np.int16, - np.int32, np.int64, np.uint8, np.uint16, - np.uint32, np.uint64]: - return int(obj) - elif type(obj) in [np.bool_]: - return bool(obj) - elif type(obj) in [np.float_, np.float16, np.float32, np.float64, - np.complex_, np.complex64, np.complex128]: - return float(obj) - - # Let the base class default method raise the TypeError - return json.JSONEncoder.default(self, obj) - - -class LoggingMixin(object): - """ - Convenience super-class to have a logger configured with the class name - """ - - @property - def logger(self): - try: - return self._logger - except AttributeError: - self._logger = logging.root.getChild(self.__class__.__module__ + '.' +self.__class__.__name__) - return self._logger - - -class S3Log(object): - """ - Utility class for reading and writing logs in S3. - Requires airflow[s3] and setting the REMOTE_BASE_LOG_FOLDER and - REMOTE_LOG_CONN_ID configuration options in airflow.cfg. - """ - def __init__(self): - remote_conn_id = configuration.get('core', 'REMOTE_LOG_CONN_ID') - try: - from airflow.hooks import S3Hook - self.hook = S3Hook(remote_conn_id) - except: - self.hook = None - logging.error( - 'Could not create an S3Hook with connection id "{}". 
' - 'Please make sure that airflow[s3] is installed and ' - 'the S3 connection exists.'.format(remote_conn_id)) - - def read(self, remote_log_location, return_error=False): - """ - Returns the log found at the remote_log_location. Returns '' if no - logs are found or there is an error. - - :param remote_log_location: the log's location in remote storage - :type remote_log_location: string (path) - :param return_error: if True, returns a string error message if an - error occurs. Otherwise returns '' when an error occurs. - :type return_error: bool - """ - if self.hook: - try: - s3_key = self.hook.get_key(remote_log_location) - if s3_key: - return s3_key.get_contents_as_string().decode() - except: - pass - - # raise/return error if we get here - err = 'Could not read logs from {}'.format(remote_log_location) - logging.error(err) - return err if return_error else '' - - - def write(self, log, remote_log_location, append=False): - """ - Writes the log to the remote_log_location. Fails silently if no hook - was created. - - :param log: the log to write to the remote_log_location - :type log: string - :param remote_log_location: the log's location in remote storage - :type remote_log_location: string (path) - :param append: if False, any existing log file is overwritten. If True, - the new log is appended to any existing logs. - :type append: bool - - """ - if self.hook: - - if append: - old_log = self.read(remote_log_location) - log = old_log + '\n' + log - try: - self.hook.load_string( - log, - key=remote_log_location, - replace=True, - encrypt=configuration.get('core', 'ENCRYPT_S3_LOGS')) - return - except: - pass - - # raise/return error if we get here - logging.error('Could not write logs to {}'.format(remote_log_location)) - - -class GCSLog(object): - """ - Utility class for reading and writing logs in GCS. - Requires either airflow[gcloud] or airflow[gcp_api] and - setting the REMOTE_BASE_LOG_FOLDER and REMOTE_LOG_CONN_ID configuration - options in airflow.cfg. - """ - def __init__(self): - """ - Attempt to create hook with airflow[gcloud] (and set - use_gcloud = True), otherwise uses airflow[gcp_api] - """ - remote_conn_id = configuration.get('core', 'REMOTE_LOG_CONN_ID') - self.use_gcloud = False - - try: - from airflow.contrib.hooks import GCSHook - self.hook = GCSHook(remote_conn_id) - self.use_gcloud = True - except: - try: - from airflow.contrib.hooks import GoogleCloudStorageHook - self.hook = GoogleCloudStorageHook(remote_conn_id) - except: - self.hook = None - logging.error( - 'Could not create a GCSHook with connection id "{}". ' - 'Please make sure that either airflow[gcloud] or ' - 'airflow[gcp_api] is installed and the GCS connection ' - 'exists.'.format(remote_conn_id)) - - def read(self, remote_log_location, return_error=True): - """ - Returns the log found at the remote_log_location. - - :param remote_log_location: the log's location in remote storage - :type remote_log_location: string (path) - :param return_error: if True, returns a string error message if an - error occurs. Otherwise returns '' when an error occurs. 
- :type return_error: bool - """ - if self.hook: - try: - if self.use_gcloud: - gcs_blob = self.hook.get_blob(remote_log_location) - if gcs_blob: - return gcs_blob.download_as_string().decode() - else: - bkt, blob = remote_log_location.lstrip('gs:/').split('/', 1) - return self.hook.download(bkt, blob).decode() - except: - pass - - # raise/return error if we get here - err = 'Could not read logs from {}'.format(remote_log_location) - logging.error(err) - return err if return_error else '' - - def write(self, log, remote_log_location, append=False): - """ - Writes the log to the remote_log_location. Fails silently if no hook - was created. - - :param log: the log to write to the remote_log_location - :type log: string - :param remote_log_location: the log's location in remote storage - :type remote_log_location: string (path) - :param append: if False, any existing log file is overwritten. If True, - the new log is appended to any existing logs. - :type append: bool - - """ - if self.hook: - - if append: - old_log = self.read(remote_log_location) - log = old_log + '\n' + log - - try: - if self.use_gcloud: - self.hook.upload_from_string( - log, - blob=remote_log_location, - replace=True) - return - else: - bkt, blob = remote_log_location.lstrip('gs:/').split('/', 1) - from tempfile import NamedTemporaryFile - with NamedTemporaryFile(mode='w+') as tmpfile: - tmpfile.write(log) - self.hook.upload(bkt, blob, tmpfile.name) - return - except: - pass - - # raise/return error if we get here - logging.error('Could not write logs to {}'.format(remote_log_location)) diff --git a/airflow/utils/__init__.py b/airflow/utils/__init__.py new file mode 100644 index 0000000000000..248dba00847de --- /dev/null +++ b/airflow/utils/__init__.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import absolute_import + +import warnings + +from .decorators import apply_defaults as _apply_defaults + + +def apply_defaults(func): + warnings.warn_explicit( + """ + You are importing apply_defaults from airflow.utils which + will be deprecated in a future version. + Please use : + + from airflow.utils.decorators import apply_defaults + """, + category=PendingDeprecationWarning, + filename=func.func_code.co_filename, + lineno=func.func_code.co_firstlineno + 1 + ) + return _apply_defaults(func) diff --git a/airflow/ascii.py b/airflow/utils/asciiart.py similarity index 75% rename from airflow/ascii.py rename to airflow/utils/asciiart.py index 60e02482bae96..9bebb5d1bc7e8 100644 --- a/airflow/ascii.py +++ b/airflow/utils/asciiart.py @@ -1,3 +1,17 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# bug = r"""\ =, .= =.| ,---. |.= @@ -39,4 +53,3 @@ (/ / // /|//||||\\ \ \ \ _) ------------------------------------------------------------------------------- """ - diff --git a/airflow/utils/dates.py b/airflow/utils/dates.py new file mode 100644 index 0000000000000..cd9aab576a146 --- /dev/null +++ b/airflow/utils/dates.py @@ -0,0 +1,181 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from datetime import datetime, date, timedelta +from dateutil.relativedelta import relativedelta # for doctest +import six + +from croniter import croniter + + +cron_presets = { + '@hourly': '0 * * * *', + '@daily': '0 0 * * *', + '@weekly': '0 0 * * 0', + '@monthly': '0 0 1 * *', + '@yearly': '0 0 1 1 *', +} + + +def date_range( + start_date, + end_date=None, + num=None, + delta=None): + """ + Get a set of dates as a list based on a start, end and delta, delta + can be something that can be added to ``datetime.datetime`` + or a cron expression as a ``str`` + + :param start_date: anchor date to start the series from + :type start_date: datetime.datetime + :param end_date: right boundary for the date range + :type end_date: datetime.datetime + :param num: alternatively to end_date, you can specify the number of + number of entries you want in the range. This number can be negative, + output will always be sorted regardless + :type num: int + + >>> date_range(datetime(2016, 1, 1), datetime(2016, 1, 3), delta=timedelta(1)) + [datetime.datetime(2016, 1, 1, 0, 0), datetime.datetime(2016, 1, 2, 0, 0), datetime.datetime(2016, 1, 3, 0, 0)] + >>> date_range(datetime(2016, 1, 1), datetime(2016, 1, 3), delta='0 0 * * *') + [datetime.datetime(2016, 1, 1, 0, 0), datetime.datetime(2016, 1, 2, 0, 0), datetime.datetime(2016, 1, 3, 0, 0)] + >>> date_range(datetime(2016, 1, 1), datetime(2016, 3, 3), delta="0 0 0 * *") + [datetime.datetime(2016, 1, 1, 0, 0), datetime.datetime(2016, 2, 1, 0, 0), datetime.datetime(2016, 3, 1, 0, 0)] + """ + if not delta: + return [] + if end_date and start_date > end_date: + raise Exception("Wait. start_date needs to be before end_date") + if end_date and num: + raise Exception("Wait. 
Either specify end_date OR num") + if not end_date and not num: + end_date = datetime.now() + + delta_iscron = False + if isinstance(delta, six.string_types): + delta_iscron = True + cron = croniter(delta, start_date) + elif isinstance(delta, timedelta): + delta = abs(delta) + l = [] + if end_date: + while start_date <= end_date: + l.append(start_date) + if delta_iscron: + start_date = cron.get_next(datetime) + else: + start_date += delta + else: + for i in range(abs(num)): + l.append(start_date) + if delta_iscron: + if num > 0: + start_date = cron.get_next(datetime) + else: + start_date = cron.get_prev(datetime) + else: + if num > 0: + start_date += delta + else: + start_date -= delta + return sorted(l) + + +def round_time(dt, delta, start_date=datetime.min): + """ + Returns the datetime of the form start_date + i * delta + which is closest to dt for any non-negative integer i. + + Note that delta may be a datetime.timedelta or a dateutil.relativedelta + + >>> round_time(datetime(2015, 1, 1, 6), timedelta(days=1)) + datetime.datetime(2015, 1, 1, 0, 0) + >>> round_time(datetime(2015, 1, 2), relativedelta(months=1)) + datetime.datetime(2015, 1, 1, 0, 0) + >>> round_time(datetime(2015, 9, 16, 0, 0), timedelta(1), datetime(2015, 9, 14, 0, 0)) + datetime.datetime(2015, 9, 16, 0, 0) + >>> round_time(datetime(2015, 9, 15, 0, 0), timedelta(1), datetime(2015, 9, 14, 0, 0)) + datetime.datetime(2015, 9, 15, 0, 0) + >>> round_time(datetime(2015, 9, 14, 0, 0), timedelta(1), datetime(2015, 9, 14, 0, 0)) + datetime.datetime(2015, 9, 14, 0, 0) + >>> round_time(datetime(2015, 9, 13, 0, 0), timedelta(1), datetime(2015, 9, 14, 0, 0)) + datetime.datetime(2015, 9, 14, 0, 0) + """ + + if isinstance(delta, six.string_types): + # It's cron based, so it's easy + cron = croniter(delta, start_date) + prev = cron.get_prev(datetime) + if prev == start_date: + return start_date + else: + return prev + + # Ignore the microseconds of dt + dt -= timedelta(microseconds=dt.microsecond) + + # We are looking for a datetime in the form start_date + i * delta + # which is as close as possible to dt. Since delta could be a relative + # delta we don't know it's exact length in seconds so we cannot rely on + # division to find i. Instead we employ a binary search algorithm, first + # finding an upper and lower limit and then disecting the interval until + # we have found the closest match. + + # We first search an upper limit for i for which start_date + upper * delta + # exceeds dt. 
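    # A concrete walk-through (illustration, assuming start_date = Jan 1,
    # delta = timedelta(days=1) and dt = Jan 5 06:00): the doubling below
    # visits upper = 1, 2, 4, 8 (Jan 9 is the first value past dt), so
    # lower = 4; the check at lower = 4 then finds Jan 5 00:00 closer to dt
    # than Jan 6 00:00 and returns start_date + 4 * delta = Jan 5 00:00.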
+ upper = 1 + while start_date + upper*delta < dt: + # To speed up finding an upper limit we grow this exponentially by a + # factor of 2 + upper *= 2 + + # Since upper is the first value for which start_date + upper * delta + # exceeds dt, upper // 2 is below dt and therefore forms a lower limited + # for the i we are looking for + lower = upper // 2 + + # We now continue to intersect the interval between + # start_date + lower * delta and start_date + upper * delta + # until we find the closest value + while True: + # Invariant: start + lower * delta < dt <= start + upper * delta + # If start_date + (lower + 1)*delta exceeds dt, then either lower or + # lower+1 has to be the solution we are searching for + if start_date + (lower + 1)*delta >= dt: + # Check if start_date + (lower + 1)*delta or + # start_date + lower*delta is closer to dt and return the solution + if ( + (start_date + (lower + 1) * delta) - dt <= + dt - (start_date + lower * delta)): + return start_date + (lower + 1)*delta + else: + return start_date + lower * delta + + # We intersect the interval and either replace the lower or upper + # limit with the candidate + candidate = lower + (upper - lower) // 2 + if start_date + candidate*delta >= dt: + upper = candidate + else: + lower = candidate + + # in the special case when start_date > dt the search for upper will + # immediately stop for upper == 1 which results in lower = upper // 2 = 0 + # and this function returns start_date. diff --git a/airflow/utils/db.py b/airflow/utils/db.py new file mode 100644 index 0000000000000..03c9823e788ce --- /dev/null +++ b/airflow/utils/db.py @@ -0,0 +1,223 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from functools import wraps +import logging +import os + +from alembic.config import Config +from alembic import command +from alembic.migration import MigrationContext + +from sqlalchemy import event, exc +from sqlalchemy.pool import Pool + +from airflow import settings +from airflow import configuration + + +def provide_session(func): + """ + Function decorator that provides a session if it isn't provided. + If you want to reuse a session or run the function as part of a + database transaction, you pass it to the function, if not this wrapper + will create one and close it for you. 
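As a quick sketch of the contract this decorator gives callers (the helper name and query below are illustrative, assuming the TaskInstance model's dag_id column):

from airflow.utils.db import provide_session


@provide_session
def count_task_instances(dag_id, session=None):
    # when no session is passed in, the decorator creates one and
    # commits/expunges/closes it after this function returns
    from airflow.models import TaskInstance
    return session.query(TaskInstance).filter(
        TaskInstance.dag_id == dag_id).count()

# to reuse an existing session (e.g. inside a larger transaction), pass it in:
# count_task_instances('example_dag', session=my_session)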
+ """ + @wraps(func) + def wrapper(*args, **kwargs): + needs_session = False + if 'session' not in kwargs: + needs_session = True + session = settings.Session() + kwargs['session'] = session + result = func(*args, **kwargs) + if needs_session: + session.expunge_all() + session.commit() + session.close() + return result + return wrapper + + +def pessimistic_connection_handling(): + @event.listens_for(Pool, "checkout") + def ping_connection(dbapi_connection, connection_record, connection_proxy): + ''' + Disconnect Handling - Pessimistic, taken from: + http://docs.sqlalchemy.org/en/rel_0_9/core/pooling.html + ''' + cursor = dbapi_connection.cursor() + try: + cursor.execute("SELECT 1") + except: + raise exc.DisconnectionError() + cursor.close() + + +@provide_session +def merge_conn(conn, session=None): + from airflow import models + C = models.Connection + if not session.query(C).filter(C.conn_id == conn.conn_id).first(): + session.add(conn) + session.commit() + + +def initdb(): + session = settings.Session() + + from airflow import models + upgradedb() + + merge_conn( + models.Connection( + conn_id='airflow_db', conn_type='mysql', + host='localhost', login='root', password='', + schema='airflow')) + merge_conn( + models.Connection( + conn_id='airflow_ci', conn_type='mysql', + host='localhost', login='root', + schema='airflow_ci')) + merge_conn( + models.Connection( + conn_id='beeline_default', conn_type='beeline', port="10000", + host='localhost', extra="{\"use_beeline\": true, \"auth\": \"\"}", + schema='default')) + merge_conn( + models.Connection( + conn_id='bigquery_default', conn_type='bigquery')) + merge_conn( + models.Connection( + conn_id='local_mysql', conn_type='mysql', + host='localhost', login='airflow', password='airflow', + schema='airflow')) + merge_conn( + models.Connection( + conn_id='presto_default', conn_type='presto', + host='localhost', + schema='hive', port=3400)) + merge_conn( + models.Connection( + conn_id='hive_cli_default', conn_type='hive_cli', + schema='default',)) + merge_conn( + models.Connection( + conn_id='hiveserver2_default', conn_type='hiveserver2', + host='localhost', + schema='default', port=10000)) + merge_conn( + models.Connection( + conn_id='metastore_default', conn_type='hive_metastore', + host='localhost', extra="{\"authMechanism\": \"PLAIN\"}", + port=9083)) + merge_conn( + models.Connection( + conn_id='mysql_default', conn_type='mysql', + login='root', + host='localhost')) + merge_conn( + models.Connection( + conn_id='postgres_default', conn_type='postgres', + login='postgres', + schema='airflow', + host='localhost')) + merge_conn( + models.Connection( + conn_id='sqlite_default', conn_type='sqlite', + host='/tmp/sqlite_default.db')) + merge_conn( + models.Connection( + conn_id='http_default', conn_type='http', + host='https://www.google.com/')) + merge_conn( + models.Connection( + conn_id='mssql_default', conn_type='mssql', + host='localhost', port=1433)) + merge_conn( + models.Connection( + conn_id='vertica_default', conn_type='vertica', + host='localhost', port=5433)) + merge_conn( + models.Connection( + conn_id='webhdfs_default', conn_type='hdfs', + host='localhost', port=50070)) + merge_conn( + models.Connection( + conn_id='ssh_default', conn_type='ssh', + host='localhost')) + + # Known event types + KET = models.KnownEventType + if not session.query(KET).filter(KET.know_event_type == 'Holiday').first(): + session.add(KET(know_event_type='Holiday')) + if not session.query(KET).filter(KET.know_event_type == 'Outage').first(): + 
session.add(KET(know_event_type='Outage')) + if not session.query(KET).filter( + KET.know_event_type == 'Natural Disaster').first(): + session.add(KET(know_event_type='Natural Disaster')) + if not session.query(KET).filter( + KET.know_event_type == 'Marketing Campaign').first(): + session.add(KET(know_event_type='Marketing Campaign')) + session.commit() + + models.DagBag(sync_to_db=True) + + Chart = models.Chart + chart_label = "Airflow task instance by type" + chart = session.query(Chart).filter(Chart.label == chart_label).first() + if not chart: + chart = Chart( + label=chart_label, + conn_id='airflow_db', + chart_type='bar', + x_is_date=False, + sql=( + "SELECT state, COUNT(1) as number " + "FROM task_instance " + "WHERE dag_id LIKE 'example%' " + "GROUP BY state"), + ) + session.add(chart) + + +def upgradedb(): + logging.info("Creating tables") + current_dir = os.path.dirname(os.path.abspath(__file__)) + package_dir = os.path.normpath(os.path.join(current_dir, '..')) + directory = os.path.join(package_dir, 'migrations') + config = Config(os.path.join(package_dir, 'alembic.ini')) + config.set_main_option('script_location', directory) + config.set_main_option('sqlalchemy.url', + configuration.get('core', 'SQL_ALCHEMY_CONN')) + command.upgrade(config, 'heads') + + +def resetdb(): + ''' + Clear out the database + ''' + from airflow import models + + logging.info("Dropping tables that exist") + models.Base.metadata.drop_all(settings.engine) + mc = MigrationContext.configure(settings.engine) + if mc._version.exists(settings.engine): + mc._version.drop(settings.engine) + initdb() diff --git a/airflow/utils/decorators.py b/airflow/utils/decorators.py new file mode 100644 index 0000000000000..5568559699bb9 --- /dev/null +++ b/airflow/utils/decorators.py @@ -0,0 +1,81 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import inspect +import os + +from copy import copy +from functools import wraps + +from airflow.exceptions import AirflowException + + +def apply_defaults(func): + """ + Function decorator that Looks for an argument named "default_args", and + fills the unspecified arguments from it. + + Since python2.* isn't clear about which arguments are missing when + calling a function, and that this can be quite confusing with multi-level + inheritance and argument defaults, this decorator also alerts with + specific information about the missing arguments. 
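A toy sketch of the behavior described above; the class is a stand-in for an operator, which is where this decorator is applied in practice:

from airflow.exceptions import AirflowException
from airflow.utils.decorators import apply_defaults


class ToyOperator(object):
    # 'task_id' has no default and is therefore required; 'owner' is optional
    @apply_defaults
    def __init__(self, task_id, owner=None, *args, **kwargs):
        self.task_id = task_id
        self.owner = owner


t = ToyOperator(task_id='t1', default_args={'owner': 'airflow'})
assert t.owner == 'airflow'          # filled in from default_args

try:
    ToyOperator(default_args={'owner': 'airflow'})
except AirflowException as e:
    print(e)                         # Argument ['task_id'] is required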
+ """ + @wraps(func) + def wrapper(*args, **kwargs): + if len(args) > 1: + raise AirflowException( + "Use keyword arguments when initializing operators") + dag_args = {} + dag_params = {} + if 'dag' in kwargs and kwargs['dag']: + dag = kwargs['dag'] + dag_args = copy(dag.default_args) or {} + dag_params = copy(dag.params) or {} + + params = {} + if 'params' in kwargs: + params = kwargs['params'] + dag_params.update(params) + + default_args = {} + if 'default_args' in kwargs: + default_args = kwargs['default_args'] + if 'params' in default_args: + dag_params.update(default_args['params']) + del default_args['params'] + + dag_args.update(default_args) + default_args = dag_args + arg_spec = inspect.getargspec(func) + num_defaults = len(arg_spec.defaults) if arg_spec.defaults else 0 + non_optional_args = arg_spec.args[:-num_defaults] + if 'self' in non_optional_args: + non_optional_args.remove('self') + for arg in func.__code__.co_varnames: + if arg in default_args and arg not in kwargs: + kwargs[arg] = default_args[arg] + missing_args = list(set(non_optional_args) - set(kwargs)) + if missing_args: + msg = "Argument {0} is required".format(missing_args) + raise AirflowException(msg) + + kwargs['params'] = dag_params + + result = func(*args, **kwargs) + return result + return wrapper + +if 'BUILDING_AIRFLOW_DOCS' in os.environ: + # Monkey patch hook to get good function headers while building docs + apply_defaults = lambda x: x diff --git a/airflow/utils/email.py b/airflow/utils/email.py new file mode 100644 index 0000000000000..6877e4721cc1b --- /dev/null +++ b/airflow/utils/email.py @@ -0,0 +1,98 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from builtins import str +from past.builtins import basestring + +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import importlib +import logging +import os +import smtplib + +from email.mime.text import MIMEText +from email.mime.multipart import MIMEMultipart +from email.mime.application import MIMEApplication +from email.utils import formatdate + +from airflow import configuration + + +def send_email(to, subject, html_content, files=None, dryrun=False): + """ + Send email using backend specified in EMAIL_BACKEND. 
+ """ + path, attr = configuration.get('email', 'EMAIL_BACKEND').rsplit('.', 1) + module = importlib.import_module(path) + backend = getattr(module, attr) + return backend(to, subject, html_content, files=files, dryrun=dryrun) + + +def send_email_smtp(to, subject, html_content, files=None, dryrun=False): + """ + Send an email with html content + + >>> send_email('test@example.com', 'foo', 'Foo bar', ['/dev/null'], dryrun=True) + """ + SMTP_MAIL_FROM = configuration.get('smtp', 'SMTP_MAIL_FROM') + + if isinstance(to, basestring): + if ',' in to: + to = to.split(',') + elif ';' in to: + to = to.split(';') + else: + to = [to] + + msg = MIMEMultipart('alternative') + msg['Subject'] = subject + msg['From'] = SMTP_MAIL_FROM + msg['To'] = ", ".join(to) + msg['Date'] = formatdate(localtime=True) + mime_text = MIMEText(html_content, 'html') + msg.attach(mime_text) + + for fname in files or []: + basename = os.path.basename(fname) + with open(fname, "rb") as f: + msg.attach(MIMEApplication( + f.read(), + Content_Disposition='attachment; filename="%s"' % basename, + Name=basename + )) + + send_MIME_email(SMTP_MAIL_FROM, to, msg, dryrun) + + +def send_MIME_email(e_from, e_to, mime_msg, dryrun=False): + SMTP_HOST = configuration.get('smtp', 'SMTP_HOST') + SMTP_PORT = configuration.getint('smtp', 'SMTP_PORT') + SMTP_USER = configuration.get('smtp', 'SMTP_USER') + SMTP_PASSWORD = configuration.get('smtp', 'SMTP_PASSWORD') + SMTP_STARTTLS = configuration.getboolean('smtp', 'SMTP_STARTTLS') + SMTP_SSL = configuration.getboolean('smtp', 'SMTP_SSL') + + if not dryrun: + s = smtplib.SMTP_SSL(SMTP_HOST, SMTP_PORT) if SMTP_SSL else smtplib.SMTP(SMTP_HOST, SMTP_PORT) + if SMTP_STARTTLS: + s.starttls() + if SMTP_USER and SMTP_PASSWORD: + s.login(SMTP_USER, SMTP_PASSWORD) + logging.info("Sent an alert email to " + str(e_to)) + s.sendmail(e_from, e_to, mime_msg.as_string()) + s.quit() diff --git a/airflow/utils/file.py b/airflow/utils/file.py new file mode 100644 index 0000000000000..183c83433b84c --- /dev/null +++ b/airflow/utils/file.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import absolute_import +from __future__ import unicode_literals + +import errno +import shutil +from tempfile import mkdtemp + +from contextlib import contextmanager + + +@contextmanager +def TemporaryDirectory(suffix='', prefix=None, dir=None): + name = mkdtemp(suffix=suffix, prefix=prefix, dir=dir) + try: + yield name + finally: + try: + shutil.rmtree(name) + except OSError as e: + # ENOENT - no such file or directory + if e.errno != errno.ENOENT: + raise e diff --git a/airflow/utils/helpers.py b/airflow/utils/helpers.py new file mode 100644 index 0000000000000..beccbe69b2d5c --- /dev/null +++ b/airflow/utils/helpers.py @@ -0,0 +1,147 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from builtins import input +from past.builtins import basestring +from datetime import datetime +import imp +import logging +import os +import re + +from airflow.exceptions import AirflowException + + +def validate_key(k, max_length=250): + if not isinstance(k, basestring): + raise TypeError("The key has to be a string") + elif len(k) > max_length: + raise AirflowException( + "The key has to be less than {0} characters".format(max_length)) + elif not re.match(r'^[A-Za-z0-9_\-\.]+$', k): + raise AirflowException( + "The key ({k}) has to be made of alphanumeric characters, dashes, " + "dots and underscores exclusively".format(**locals())) + else: + return True + + +def alchemy_to_dict(obj): + """ + Transforms a SQLAlchemy model instance into a dictionary + """ + if not obj: + return None + d = {} + for c in obj.__table__.columns: + value = getattr(obj, c.name) + if type(value) == datetime: + value = value.isoformat() + d[c.name] = value + return d + + +def ask_yesno(question): + yes = set(['yes', 'y']) + no = set(['no', 'n']) + + done = False + print(question) + while not done: + choice = input().lower() + if choice in yes: + return True + elif choice in no: + return False + else: + print("Please respond by yes or no.") + + +def import_module_attrs(parent_module_globals, module_attrs_dict): + ''' + Attempts to import a set of modules and specified attributes in the + form of a dictionary. The attributes are copied in the parent module's + namespace. The function returns a list of attributes names that can be + affected to __all__. + + This is used in the context of ``operators`` and ``hooks`` and + silence the import errors for when libraries are missing. It makes + for a clean package abstracting the underlying modules and only + brings functional operators to those namespaces. + ''' + imported_attrs = [] + for mod, attrs in list(module_attrs_dict.items()): + try: + path = os.path.realpath(parent_module_globals['__file__']) + folder = os.path.dirname(path) + f, filename, description = imp.find_module(mod, [folder]) + module = imp.load_module(mod, f, filename, description) + for attr in attrs: + parent_module_globals[attr] = getattr(module, attr) + imported_attrs += [attr] + except Exception as err: + logging.debug("Error importing module {mod}: {err}".format( + mod=mod, err=err)) + return imported_attrs + + +def is_in(obj, l): + """ + Checks whether an object is one of the item in the list. + This is different from ``in`` because ``in`` uses __cmp__ when + present. Here we change based on the object itself + """ + for item in l: + if item is obj: + return True + return False + + +def is_container(obj): + """ + Test if an object is a container (iterable) but not a string + """ + return hasattr(obj, '__iter__') and not isinstance(obj, basestring) + + +def as_tuple(obj): + """ + If obj is a container, returns obj as a tuple. + Otherwise, returns a tuple containing obj. 
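A few usage lines for the two small helpers above (is_container and as_tuple), with arbitrary values:

from airflow.utils.helpers import as_tuple, is_container

print(is_container(['t1', 't2']))   # True
print(is_container('task_id'))      # False -- strings are deliberately excluded
print(as_tuple('task_id'))          # ('task_id',)
print(as_tuple(['t1', 't2']))       # ('t1', 't2')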
+ """ + if is_container(obj): + return tuple(obj) + else: + return tuple([obj]) + + +def chain(*tasks): + """ + Given a number of tasks, builds a dependency chain. + + chain(task_1, task_2, task_3, task_4) + + is equivalent to + + task_1.set_downstream(task_2) + task_2.set_downstream(task_3) + task_3.set_downstream(task_4) + """ + for up_task, down_task in zip(tasks[:-1], tasks[1:]): + up_task.set_downstream(down_task) diff --git a/airflow/utils/json.py b/airflow/utils/json.py new file mode 100644 index 0000000000000..b94033507eae4 --- /dev/null +++ b/airflow/utils/json.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from datetime import datetime, date +import json +import numpy as np + + +# Dates and JSON encoding/deconding + +def json_ser(obj): + """ + json serializer that deals with dates + usage: json.dumps(object, default=utils.json.json_ser) + """ + if isinstance(obj, (datetime, date)): + return obj.isoformat() + + +class AirflowJsonEncoder(json.JSONEncoder): + def default(self, obj): + # convert dates and numpy objects in a json serializable format + if isinstance(obj, datetime): + return obj.strftime('%Y-%m-%dT%H:%M:%SZ') + elif isinstance(obj, date): + return obj.strftime('%Y-%m-%d') + elif type(obj) in [np.int_, np.intc, np.intp, np.int8, np.int16, + np.int32, np.int64, np.uint8, np.uint16, + np.uint32, np.uint64]: + return int(obj) + elif type(obj) in [np.bool_]: + return bool(obj) + elif type(obj) in [np.float_, np.float16, np.float32, np.float64, + np.complex_, np.complex64, np.complex128]: + return float(obj) + + # Let the base class default method raise the TypeError + return json.JSONEncoder.default(self, obj) diff --git a/airflow/utils/logging.py b/airflow/utils/logging.py new file mode 100644 index 0000000000000..3d5d48ecdfbf5 --- /dev/null +++ b/airflow/utils/logging.py @@ -0,0 +1,137 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from builtins import object + +import logging + +from airflow import configuration + + +class LoggingMixin(object): + """ + Convenience super-class to have a logger configured with the class name + """ + + @property + def logger(self): + try: + return self._logger + except AttributeError: + self._logger = logging.root.getChild(self.__class__.__module__ + '.' + self.__class__.__name__) + return self._logger + + +class GCSLog(object): + """ + Utility class for reading and writing logs in GCS. + Requires either airflow[gcloud] or airflow[gcp_api] and + setting the REMOTE_BASE_LOG_FOLDER and REMOTE_LOG_CONN_ID configuration + options in airflow.cfg. + """ + def __init__(self): + """ + Attempt to create hook with airflow[gcloud] (and set + use_gcloud = True), otherwise uses airflow[gcp_api] + """ + remote_conn_id = configuration.get('core', 'REMOTE_LOG_CONN_ID') + self.use_gcloud = False + + try: + from airflow.contrib.hooks import GCSHook + self.hook = GCSHook(remote_conn_id) + self.use_gcloud = True + except: + try: + from airflow.contrib.hooks import GoogleCloudStorageHook + self.hook = GoogleCloudStorageHook(remote_conn_id) + except: + self.hook = None + logging.error( + 'Could not create a GCSHook with connection id "{}". ' + 'Please make sure that either airflow[gcloud] or ' + 'airflow[gcp_api] is installed and the GCS connection ' + 'exists.'.format(remote_conn_id)) + + def read(self, remote_log_location, return_error=True): + """ + Returns the log found at the remote_log_location. + + :param remote_log_location: the log's location in remote storage + :type remote_log_location: string (path) + :param return_error: if True, returns a string error message if an + error occurs. Otherwise returns '' when an error occurs. + :type return_error: bool + """ + if self.hook: + try: + if self.use_gcloud: + gcs_blob = self.hook.get_blob(remote_log_location) + if gcs_blob: + return gcs_blob.download_as_string().decode() + else: + bkt, blob = remote_log_location.lstrip('gs:/').split('/', 1) + return self.hook.download(bkt, blob).decode() + except: + pass + + # raise/return error if we get here + err = 'Could not read logs from {}'.format(remote_log_location) + logging.error(err) + return err if return_error else '' + + def write(self, log, remote_log_location, append=False): + """ + Writes the log to the remote_log_location. Fails silently if no hook + was created. + + :param log: the log to write to the remote_log_location + :type log: string + :param remote_log_location: the log's location in remote storage + :type remote_log_location: string (path) + :param append: if False, any existing log file is overwritten. If True, + the new log is appended to any existing logs. 
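For context, a sketch mirroring how these remote log classes are consumed elsewhere in this patch (airflow/bin/cli.py and airflow/www/views.py pick a class from the URL scheme); the log path below is made up:

from airflow.utils import logging as log_utils

remote_log = 'gs://my-bucket/logs/example_dag/task_1/2016-01-01T00:00:00'  # hypothetical
if remote_log.startswith('s3:/'):
    log = log_utils.S3Log().read(remote_log, return_error=True)
elif remote_log.startswith('gs:/'):
    log = log_utils.GCSLog().read(remote_log, return_error=True)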
+ :type append: bool + + """ + if self.hook: + + if append: + old_log = self.read(remote_log_location) + log = old_log + '\n' + log + + try: + if self.use_gcloud: + self.hook.upload_from_string( + log, + blob=remote_log_location, + replace=True) + return + else: + bkt, blob = remote_log_location.lstrip('gs:/').split('/', 1) + from tempfile import NamedTemporaryFile + with NamedTemporaryFile(mode='w+') as tmpfile: + tmpfile.write(log) + self.hook.upload(bkt, blob, tmpfile.name) + return + except: + pass + + # raise/return error if we get here + logging.error('Could not write logs to {}'.format(remote_log_location)) diff --git a/airflow/utils/state.py b/airflow/utils/state.py new file mode 100644 index 0000000000000..e13d41095f710 --- /dev/null +++ b/airflow/utils/state.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import unicode_literals + +from builtins import object + + +class State(object): + """ + Static class with task instance states constants and color method to + avoid hardcoding. + """ + QUEUED = "queued" + RUNNING = "running" + SUCCESS = "success" + SHUTDOWN = "shutdown" # External request to shut down + FAILED = "failed" + UP_FOR_RETRY = "up_for_retry" + UPSTREAM_FAILED = "upstream_failed" + SKIPPED = "skipped" + + state_color = { + QUEUED: 'gray', + RUNNING: 'lime', + SUCCESS: 'green', + SHUTDOWN: 'blue', + FAILED: 'red', + UP_FOR_RETRY: 'gold', + UPSTREAM_FAILED: 'orange', + SKIPPED: 'pink', + } + + @classmethod + def color(cls, state): + if state in cls.state_color: + return cls.state_color[state] + else: + return 'white' + + @classmethod + def color_fg(cls, state): + color = cls.color(state) + if color in ['green', 'red']: + return 'white' + else: + return 'black' + + @classmethod + def runnable(cls): + return [ + None, cls.FAILED, cls.UP_FOR_RETRY, cls.UPSTREAM_FAILED, + cls.SKIPPED, cls.QUEUED] diff --git a/airflow/utils/timeout.py b/airflow/utils/timeout.py new file mode 100644 index 0000000000000..62af9db9b02eb --- /dev/null +++ b/airflow/utils/timeout.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import logging +import signal + +from builtins import object + +from airflow.exceptions import AirflowTaskTimeout + + +class timeout(object): + """ + To be used in a ``with`` block and timeout its content. 
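A minimal usage sketch; it assumes the block runs in the main thread, since signal.SIGALRM handlers can only be installed there (which is what the ValueError handling below guards against):

import time

from airflow.exceptions import AirflowTaskTimeout
from airflow.utils.timeout import timeout

try:
    with timeout(seconds=2, error_message='query took too long'):
        time.sleep(10)               # stand-in for a slow call
except AirflowTaskTimeout as e:
    print(e)                         # query took too long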
+ """ + def __init__(self, seconds=1, error_message='Timeout'): + self.seconds = seconds + self.error_message = error_message + + def handle_timeout(self, signum, frame): + logging.error("Process timed out") + raise AirflowTaskTimeout(self.error_message) + + def __enter__(self): + try: + signal.signal(signal.SIGALRM, self.handle_timeout) + signal.alarm(self.seconds) + except ValueError as e: + logging.warning("timeout can't be used in the current context") + logging.exception(e) + + def __exit__(self, type, value, traceback): + try: + signal.alarm(0) + except ValueError as e: + logging.warning("timeout can't be used in the current context") + logging.exception(e) diff --git a/airflow/utils/trigger_rule.py b/airflow/utils/trigger_rule.py new file mode 100644 index 0000000000000..f8309090f75ba --- /dev/null +++ b/airflow/utils/trigger_rule.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import unicode_literals + +from builtins import object + + +class TriggerRule(object): + ALL_SUCCESS = 'all_success' + ALL_FAILED = 'all_failed' + ALL_DONE = 'all_done' + ONE_SUCCESS = 'one_success' + ONE_FAILED = 'one_failed' + DUMMY = 'dummy' + + @classmethod + def is_valid(cls, trigger_rule): + return trigger_rule in cls.all_triggers() + + @classmethod + def all_triggers(cls): + return [getattr(cls, attr) + for attr in dir(cls) + if not attr.startswith("__") and not callable(getattr(cls, attr))] diff --git a/airflow/www/utils.py b/airflow/www/utils.py index 03abf83a6ef3c..69dddf1faa5fa 100644 --- a/airflow/www/utils.py +++ b/airflow/www/utils.py @@ -8,14 +8,15 @@ import gzip import dateutil.parser as dateparser import json -import os from flask import after_this_request, request, Response from flask_login import current_user from jinja2 import Template import wtforms from wtforms.compat import text_type -from airflow import configuration, models, settings, utils +from airflow import configuration, models, settings +from airflow.utils.json import AirflowJsonEncoder +from airflow.utils.email import send_email AUTHENTICATE = configuration.getboolean('webserver', 'AUTHENTICATE') @@ -147,7 +148,7 @@ def wrapper(*args, **kwargs): ''').render(**locals()) if task.email: - utils.send_email(task.email, subject, content) + send_email(task.email, subject, content) """ return f(*args, **kwargs) return wrapper @@ -159,7 +160,7 @@ def json_response(obj): """ return Response( response=json.dumps( - obj, indent=4, cls=utils.AirflowJsonEncoder), + obj, indent=4, cls=AirflowJsonEncoder), status=200, mimetype="application/json") diff --git a/airflow/www/views.py b/airflow/www/views.py index cbb2db39ee42f..d7864ec746f3f 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -37,15 +37,18 @@ from pygments.formatters import HtmlFormatter import airflow -from airflow import models -from airflow.settings import Session from airflow import configuration as conf -from airflow import utils -from airflow.utils import AirflowException -from airflow.www import utils 
as wwwutils
+from airflow import models
 from airflow import settings
-from airflow.models import State
+from airflow.exceptions import AirflowException
+from airflow.settings import Session
+from airflow.utils.json import json_ser
+from airflow.utils.state import State
+from airflow.utils.db import provide_session
+from airflow.utils.helpers import alchemy_to_dict
+from airflow.utils import logging as log_utils
+from airflow.www import utils as wwwutils
 from airflow.www.forms import DateTimeForm, DateTimeWithNumRunsForm
 QUERY_LIMIT = 100000
@@ -637,7 +640,7 @@ def dag_details(self):
 )
 return self.render(
 'airflow/dag_details.html',
- dag=dag, title=title, states=states, State=utils.State)
+ dag=dag, title=title, states=states, State=State)
 @current_app.errorhandler(404)
 def circles(self):
@@ -646,7 +649,7 @@ def circles(self):
 @current_app.errorhandler(500)
 def show_traceback(self):
- from airflow import ascii as ascii_
+ from airflow.utils import asciiart as ascii_
 return render_template(
 'airflow/traceback.html',
 hostname=socket.gethostname(),
@@ -803,11 +806,11 @@ def log(self):
 # S3
 if remote_log.startswith('s3:/'):
- log += utils.S3Log().read(remote_log, return_error=True)
+ log += log_utils.S3Log().read(remote_log, return_error=True)
 # GCS
 elif remote_log.startswith('gs:/'):
- log += utils.GCSLog().read(remote_log, return_error=True)
+ log += log_utils.GCSLog().read(remote_log, return_error=True)
 # unsupported
 elif remote_log:
@@ -1138,7 +1141,7 @@ def tree(self):
 .all()
 )
 dag_runs = {
- dr.execution_date: utils.alchemy_to_dict(dr) for dr in dag_runs}
+ dr.execution_date: alchemy_to_dict(dr) for dr in dag_runs}
 tis = dag.get_task_instances(
 session, start_date=min_date, end_date=base_date)
@@ -1146,7 +1149,7 @@
 max_date = max([ti.execution_date for ti in tis]) if dates else None
 task_instances = {}
 for ti in tis:
- tid = utils.alchemy_to_dict(ti)
+ tid = alchemy_to_dict(ti)
 dr = dag_runs.get(ti.execution_date)
 tid['external_trigger'] = dr['external_trigger'] if dr else False
 task_instances[(ti.task_id, ti.execution_date)] = tid
@@ -1201,7 +1204,7 @@ def recurse_nodes(task, visited):
 for d in dates],
 }
- data = json.dumps(data, indent=4, default=utils.json_ser)
+ data = json.dumps(data, indent=4, default=json_ser)
 session.commit()
 session.close()
@@ -1294,7 +1297,7 @@ class GraphForm(Form):
 data={'execution_date': dttm.isoformat(), 'arrange': arrange})
 task_instances = {
- ti.task_id: utils.alchemy_to_dict(ti)
+ ti.task_id: alchemy_to_dict(ti)
 for ti in dag.get_task_instances(session, dttm, dttm)}
 tasks = {
 t.task_id: {
@@ -1593,7 +1596,7 @@ def task_instances(self):
 return ("Error: Invalid execution_date")
 task_instances = {
- ti.task_id: utils.alchemy_to_dict(ti)
+ ti.task_id: alchemy_to_dict(ti)
 for ti in dag.get_task_instances(session, dttm, dttm)}
 return json.dumps(task_instances)
@@ -1979,7 +1982,7 @@ def action_set_failed(self, ids):
 def action_set_success(self, ids):
 self.set_dagrun_state(ids, State.SUCCESS)
- @utils.provide_session
+ @provide_session
 def set_dagrun_state(self, ids, target_state, session=None):
 try:
 DR = models.DagRun
@@ -2058,7 +2061,7 @@ def action_set_success(self, ids):
 def action_set_retry(self, ids):
 self.set_task_instance_state(ids, State.UP_FOR_RETRY)
- @utils.provide_session
+ @provide_session
 def set_task_instance_state(self, ids, target_state, session=None):
 try:
 TI = models.TaskInstance
diff --git a/tests/core.py b/tests/core.py
index 5cc737dcf51fe..9a50d199c2488 100644
--- a/tests/core.py
+++ b/tests/core.py
@@ -2,7 +2,6 @@
 import doctest
 import json
-import logging
 import os
 import re
 import unittest
@@ -22,14 +21,16 @@ from airflow.models import Variable
 configuration.test_mode()
-from airflow import jobs, models, DAG, utils, operators, hooks, macros, settings
+from airflow import jobs, models, DAG, operators, hooks, utils, macros, settings, exceptions
 from airflow.hooks import BaseHook
 from airflow.bin import cli
 from airflow.www import app as application
 from airflow.settings import Session
-from airflow.utils import LoggingMixin, round_time
+from airflow.utils.state import State
+from airflow.utils.dates import round_time
+from airflow.utils.logging import LoggingMixin
 from lxml import html
-from airflow.utils import AirflowException
+from airflow.exceptions import AirflowException
 from airflow.configuration import AirflowConfigException
 from airflow.minihivecluster import MiniHiveCluster
@@ -121,7 +122,7 @@ def test_schedule_dag_no_previous_runs(self):
 assert dag_run.execution_date == datetime(2015, 1, 2, 0, 0), (
 'dag_run.execution_date did not match expectation: {0}'
 .format(dag_run.execution_date))
- assert dag_run.state == models.State.RUNNING
+ assert dag_run.state == State.RUNNING
 assert dag_run.external_trigger == False
 def test_schedule_dag_fake_scheduled_previous(self):
@@ -141,7 +142,7 @@ def test_schedule_dag_fake_scheduled_previous(self):
 dag_id=dag.dag_id,
 run_id=models.DagRun.id_for_date(DEFAULT_DATE),
 execution_date=DEFAULT_DATE,
- state=utils.State.SUCCESS,
+ state=State.SUCCESS,
 external_trigger=True)
 settings.Session().add(trigger)
 settings.Session().commit()
@@ -153,7 +154,7 @@ def test_schedule_dag_fake_scheduled_previous(self):
 assert dag_run.execution_date == DEFAULT_DATE + delta, (
 'dag_run.execution_date did not match expectation: {0}'
 .format(dag_run.execution_date))
- assert dag_run.state == models.State.RUNNING
+ assert dag_run.state == State.RUNNING
 assert dag_run.external_trigger == False
 def test_schedule_dag_once(self):
@@ -239,7 +240,7 @@ def test_schedule_dag_no_end_date_up_to_today_only(self):
 dag_runs.append(dag_run)
 # Mark the DagRun as complete
- dag_run.state = utils.State.SUCCESS
+ dag_run.state = State.SUCCESS
 session.merge(dag_run)
 session.commit()
@@ -461,7 +462,7 @@ def test_timeout(self):
 python_callable=lambda: sleep(5),
 dag=self.dag)
 self.assertRaises(
- utils.AirflowTaskTimeout,
+ exceptions.AirflowTaskTimeout,
 t.run,
 start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, force=True)
@@ -1255,7 +1256,7 @@ def test_poke(self):
 class ConnectionTest(unittest.TestCase):
 def setUp(self):
 configuration.test_mode()
- utils.initdb()
+ utils.db.initdb()
 os.environ['AIRFLOW_CONN_TEST_URI'] = (
 'postgres://username:password@ec2.compute.com:5432/the_database')
 os.environ['AIRFLOW_CONN_TEST_URI_NO_CREDS'] = (
@@ -1396,16 +1397,16 @@ class EmailTest(unittest.TestCase):
 def setUp(self):
 configuration.remove_option('email', 'EMAIL_BACKEND')
- @mock.patch('airflow.utils.send_email_smtp')
+ @mock.patch('airflow.utils.email.send_email')
 def test_default_backend(self, mock_send_email):
- res = utils.send_email('to', 'subject', 'content')
- mock_send_email.assert_called_with('to', 'subject', 'content', files=None, dryrun=False)
+ res = utils.email.send_email('to', 'subject', 'content')
+ mock_send_email.assert_called_with('to', 'subject', 'content')
 assert res == mock_send_email.return_value
- @mock.patch('airflow.utils.send_email_smtp')
+ @mock.patch('airflow.utils.email.send_email_smtp')
 def test_custom_backend(self, mock_send_email):
 configuration.set('email', 'EMAIL_BACKEND',
 'tests.core.send_email_test')
- utils.send_email('to', 'subject', 'content')
+ utils.email.send_email('to', 'subject', 'content')
 send_email_test.assert_called_with('to', 'subject', 'content', files=None, dryrun=False)
 assert not mock_send_email.called
@@ -1414,12 +1415,12 @@ class EmailSmtpTest(unittest.TestCase):
 def setUp(self):
 configuration.set('smtp', 'SMTP_SSL', 'False')
- @mock.patch('airflow.utils.send_MIME_email')
+ @mock.patch('airflow.utils.email.send_MIME_email')
 def test_send_smtp(self, mock_send_mime):
 attachment = tempfile.NamedTemporaryFile()
 attachment.write(b'attachment')
 attachment.seek(0)
- utils.send_email_smtp('to', 'subject', 'content', files=[attachment.name])
+ utils.email.send_email_smtp('to', 'subject', 'content', files=[attachment.name])
 assert mock_send_mime.called
 call_args = mock_send_mime.call_args[0]
 assert call_args[0] == configuration.get('smtp', 'SMTP_MAIL_FROM')
@@ -1437,7 +1438,7 @@ def test_send_mime(self, mock_smtp, mock_smtp_ssl):
 mock_smtp.return_value = mock.Mock()
 mock_smtp_ssl.return_value = mock.Mock()
 msg = MIMEMultipart()
- utils.send_MIME_email('from', 'to', msg, dryrun=False)
+ utils.email.send_MIME_email('from', 'to', msg, dryrun=False)
 mock_smtp.assert_called_with(
 configuration.get('smtp', 'SMTP_HOST'),
 configuration.getint('smtp', 'SMTP_PORT'),
@@ -1456,7 +1457,7 @@ def test_send_mime_ssl(self, mock_smtp, mock_smtp_ssl):
 configuration.set('smtp', 'SMTP_SSL', 'True')
 mock_smtp.return_value = mock.Mock()
 mock_smtp_ssl.return_value = mock.Mock()
- utils.send_MIME_email('from', 'to', MIMEMultipart(), dryrun=False)
+ utils.email.send_MIME_email('from', 'to', MIMEMultipart(), dryrun=False)
 assert not mock_smtp.called
 mock_smtp_ssl.assert_called_with(
 configuration.get('smtp', 'SMTP_HOST'),
@@ -1466,7 +1467,7 @@ def test_send_mime_dryrun(self, mock_smtp, mock_smtp_ssl):
 @mock.patch('smtplib.SMTP_SSL')
 @mock.patch('smtplib.SMTP')
 def test_send_mime_dryrun(self, mock_smtp, mock_smtp_ssl):
- utils.send_MIME_email('from', 'to', MIMEMultipart(), dryrun=True)
+ utils.email.send_MIME_email('from', 'to', MIMEMultipart(), dryrun=True)
 assert not mock_smtp.called
 assert not mock_smtp_ssl.called
diff --git a/tests/operators/docker_operator.py b/tests/operators/docker_operator.py
index 15dca9d7b98cf..4f9004c6292ea 100644
--- a/tests/operators/docker_operator.py
+++ b/tests/operators/docker_operator.py
@@ -3,7 +3,7 @@
 from airflow.operators.docker_operator import DockerOperator
 from docker.client import Client
-from airflow.utils import AirflowException
+from airflow.exceptions import AirflowException
 try:
 from unittest import mock
@@ -16,7 +16,7 @@ class DockerOperatorTestCase(unittest.TestCase):
 @unittest.skipIf(mock is None, 'mock package not present')
- @mock.patch('airflow.utils.mkdtemp')
+ @mock.patch('airflow.utils.file.mkdtemp')
 @mock.patch('airflow.operators.docker_operator.Client')
 def test_execute(self, client_class_mock, mkdtemp_mock):
 host_config = mock.Mock()
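
Note (not part of the patch): for DAG, plugin, or test code that still imports from the flat airflow.utils namespace, the migration shown in the hunks above amounts to switching to the new submodules. Below is a minimal, illustrative sketch using only import paths that appear in this diff; the helper function and its arguments are hypothetical.

# Illustrative sketch only -- demonstrates the post-refactor import locations.
from airflow import models
from airflow.exceptions import AirflowException   # was: from airflow.utils import AirflowException
from airflow.utils.state import State             # was: from airflow.utils import State
from airflow.utils.db import provide_session      # was: from airflow.utils import provide_session

@provide_session
def running_task_count(dag_id, session=None):
    # Hypothetical helper: count RUNNING task instances for a DAG using the new paths.
    if not dag_id:
        raise AirflowException("dag_id is required")
    TI = models.TaskInstance
    return (
        session.query(TI)
        .filter(TI.dag_id == dag_id, TI.state == State.RUNNING)
        .count()
    )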