Skip to content

Commit

Permalink
Override GoogleBaseHook with BigqueryHook (#1442)
Browse files Browse the repository at this point in the history
# Description
## What is the current behavior?
Because of below issue:
```
airflow.exceptions.AirflowException: You are trying to use `common-sql` with GoogleBaseHook, but its provider does not support it. Please upgrade the provider to a version that supports `common-sql`. The hook class should be a subclass of `airflow.providers.common.sql.hooks.sql.DbApiHook`. Got GoogleBaseHook Hook with class hierarchy: [<class 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook'>, <class 'airflow.hooks.base.BaseHook'>, <class 'airflow.utils.log.logging_mixin.LoggingMixin'>, <class 'object'>]
```
We are using a workaround and using the BigQuery Hook
  • Loading branch information
utkarsharma2 authored Dec 20, 2022
1 parent b3c361e commit c8ee4b4
Show file tree
Hide file tree
Showing 10 changed files with 47 additions and 11 deletions.
4 changes: 4 additions & 0 deletions python-sdk/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,11 @@ def database_table_fixture(request):
params = deepcopy(request.param)

database_name = params["database"]
user_table = params.get("table", None)
conn_id = DATABASE_NAME_TO_CONN_ID[database_name]
if user_table and user_table.conn_id:
conn_id = user_table.conn_id

database = create_database(conn_id)
table = params.get("table", Table(conn_id=database.conn_id, metadata=database.default_metadata))
if not isinstance(table, TempTable):
Expand Down
7 changes: 2 additions & 5 deletions python-sdk/src/astro/databases/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,7 @@
SUPPORTED_DATABASES = set(DEFAULT_CONN_TYPE_TO_MODULE_PATH.keys())


def create_database(
conn_id: str,
table: BaseTable | None = None,
) -> BaseDatabase:
def create_database(conn_id: str, table: BaseTable | None = None, region: str | None = None) -> BaseDatabase:
"""
Given a conn_id, return the associated Database class.
Expand All @@ -40,5 +37,5 @@ def create_database(
module_path = CONN_TYPE_TO_MODULE_PATH[conn_type]
module = importlib.import_module(module_path)
class_name = get_class_name(module_ref=module, suffix="Database")
database: BaseDatabase = getattr(module, class_name)(conn_id, table)
database: BaseDatabase = getattr(module, class_name)(conn_id, table, region)
return database
2 changes: 1 addition & 1 deletion python-sdk/src/astro/databases/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class BaseDatabase(ABC):
NATIVE_AUTODETECT_SCHEMA_CONFIG: Mapping[FileLocation, Mapping[str, list[FileType] | Callable]] = {}
FILE_PATTERN_BASED_AUTODETECT_SCHEMA_SUPPORTED: set[FileLocation] = set()

def __init__(self, conn_id: str):
def __init__(self, conn_id: str, table: BaseTable | None = None, region: str | None = None):
self.conn_id = conn_id
self.sql: str | ClauseElement = ""

Expand Down
7 changes: 5 additions & 2 deletions python-sdk/src/astro/databases/google/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,12 @@ class BigqueryDatabase(BaseDatabase):

_create_schema_statement: str = "CREATE SCHEMA IF NOT EXISTS {} OPTIONS (location='{}')"

def __init__(self, conn_id: str = DEFAULT_CONN_ID, table: BaseTable | None = None):
def __init__(
    self, conn_id: str = DEFAULT_CONN_ID, table: BaseTable | None = None, region: str | None = None
):
    """
    Initialize the BigQuery database wrapper.

    :param conn_id: Airflow connection ID used to reach BigQuery.
    :param table: optional table this database instance is bound to.
    :param region: optional BigQuery location; forwarded to ``BigQueryHook``
        via the ``hook`` property so operations run in the dataset's region.
    """
    # NOTE(review): table/region are deliberately NOT passed to super() —
    # they are stored on this subclass instead. Confirm the base __init__
    # (not fully visible in this diff) does not also expect them.
    super().__init__(conn_id)
    self.table = table
    # Kept for the hook property, which passes it as BigQueryHook(location=...).
    self.region = region

@property
def sql_type(self) -> str:
Expand All @@ -115,7 +118,7 @@ def sql_type(self) -> str:
@property
def hook(self) -> BigQueryHook:
    """Retrieve the Airflow hook used to interface with the BigQuery database.

    ``location`` is forwarded from ``self.region`` so queries and schema
    operations run in the dataset's region; a ``None`` region falls back to
    the hook's own default behavior.

    :return: a configured ``BigQueryHook`` (legacy SQL disabled).
    """
    # The pre-change return (without `location=`) that appeared above this
    # line in the diff render was stale/unreachable and has been removed.
    return BigQueryHook(gcp_conn_id=self.conn_id, use_legacy_sql=False, location=self.region)

@property
def sqlalchemy_engine(self) -> Engine:
Expand Down
4 changes: 3 additions & 1 deletion python-sdk/src/astro/databases/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ class PostgresDatabase(BaseDatabase):
illegal_column_name_chars: list[str] = ["."]
illegal_column_name_chars_replacement: list[str] = ["_"]

def __init__(self, conn_id: str = DEFAULT_CONN_ID, table: BaseTable | None = None):
def __init__(
    self, conn_id: str = DEFAULT_CONN_ID, table: BaseTable | None = None, region: str | None = None
):
    """
    Initialize the Postgres database wrapper.

    :param conn_id: Airflow connection ID for the Postgres connection.
    :param table: optional table this database instance is bound to.
    :param region: accepted only for signature parity with
        ``BigqueryDatabase.__init__``; it is not used by Postgres.
    """
    super().__init__(conn_id)
    self.table = table

Expand Down
4 changes: 3 additions & 1 deletion python-sdk/src/astro/databases/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,9 @@ class SnowflakeDatabase(BaseDatabase):
)
DEFAULT_SCHEMA = SNOWFLAKE_SCHEMA

def __init__(self, conn_id: str = DEFAULT_CONN_ID, table: BaseTable | None = None):
def __init__(
    self, conn_id: str = DEFAULT_CONN_ID, table: BaseTable | None = None, region: str | None = None
):
    """
    Initialize the Snowflake database wrapper.

    :param conn_id: Airflow connection ID for the Snowflake connection.
    :param table: optional table this database instance is bound to.
    :param region: accepted only for signature parity with
        ``BigqueryDatabase.__init__``; it is not used by Snowflake.
    """
    super().__init__(conn_id)
    self.table = table

Expand Down
4 changes: 3 additions & 1 deletion python-sdk/src/astro/databases/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ class SqliteDatabase(BaseDatabase):
logic in other parts of our code-base.
"""

def __init__(self, conn_id: str = DEFAULT_CONN_ID, table: BaseTable | None = None):
def __init__(
    self, conn_id: str = DEFAULT_CONN_ID, table: BaseTable | None = None, region: str | None = None
):
    """
    Initialize the SQLite database wrapper.

    :param conn_id: Airflow connection ID for the SQLite connection.
    :param table: optional table this database instance is bound to.
    :param region: accepted only for signature parity with
        ``BigqueryDatabase.__init__``; it is not used by SQLite.
    """
    super().__init__(conn_id)
    self.table = table

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import pandas
from airflow import AirflowException
from airflow.decorators.base import get_unique_task_id
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.providers.common.sql.operators.sql import SQLColumnCheckOperator

from astro.databases import create_database
from astro.settings import BIGQUERY_SCHEMA_LOCATION
from astro.table import BaseTable
from astro.utils.typing_compat import Context

Expand Down Expand Up @@ -77,8 +79,27 @@ def __init__(
task_id=task_id if task_id is not None else get_unique_task_id("column_check"),
)

def get_db_hook(self) -> DbApiHook:
    """Return the hook used to talk to the underlying database.

    BigQuery connections receive the provider-specific hook directly — a
    work-around for GoogleBaseHook not being a ``DbApiHook`` subclass —
    while every other database falls back to the parent implementation.

    :return: the database hook object.
    """
    target_region = self.dataset.metadata.region or BIGQUERY_SCHEMA_LOCATION
    database = create_database(conn_id=self.conn_id, region=target_region)
    return database.hook if database.sql_type == "bigquery" else super().get_db_hook()

def execute(self, context: "Context"):
if isinstance(self.dataset, BaseTable):
# Work around for GoogleBaseHook not inheriting from DBApi
# db = create_database(
# conn_id=self.conn_id, region=self.dataset.metadata.region or BIGQUERY_SCHEMA_LOCATION
# )
# if db.sql_type == "bigquery":
# self._hook = db.hook
return super().execute(context=context)
elif type(self.dataset) == pandas.DataFrame:
self.df = self.dataset
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from astro import sql as aql
from astro.constants import Database
from astro.files import File
from astro.table import Table
from tests.sql.operators import utils as test_utils

CWD = pathlib.Path(__file__).parent
Expand All @@ -20,6 +21,7 @@
{
"database": Database.BIGQUERY,
"file": File(path=str(CWD) + "/../../../data/data_validation.csv"),
"table": Table(conn_id="bigquery"),
},
{
"database": Database.POSTGRES,
Expand All @@ -43,6 +45,7 @@ def test_column_check_operator_with_table_dataset(sample_dag, database_table_fix
all the database we support.
"""
db, test_table = database_table_fixture
test_table.conn_id = "gcp_conn_project"
with sample_dag:
aql.ColumnCheckOperator(
dataset=test_table,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from astro import sql as aql
from astro.constants import Database
from astro.files import File
from astro.table import Table
from tests.sql.operators import utils as test_utils

CWD = pathlib.Path(__file__).parent
Expand All @@ -20,6 +21,7 @@
{
"database": Database.BIGQUERY,
"file": File(path=str(CWD) + "/../../../data/homes_main.csv"),
"table": Table(conn_id="bigquery"),
},
{
"database": Database.POSTGRES,
Expand Down

0 comments on commit c8ee4b4

Please sign in to comment.