Alternate approach to get location for Bigquery tables (#1449)
An alternate approach to #1416.

Also closes #1460.

The `region` parameter passed to `Table` wasn't used anywhere except
internally, so it is removed without a deprecation period.
kaxil authored Dec 19, 2022
1 parent 5c69b4d commit 91e4bef
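
In outline: instead of storing a region on Metadata, the BigQuery database now
looks up the source dataset's location on demand and, when a temp table must
live somewhere other than the default location, encodes that location into the
temp schema name. A minimal sketch of the naming rule, with a placeholder
default schema (the real logic is _populate_temp_table_metadata_from_input_table
in the bigquery.py diff below):

    # Sketch only: DEFAULT_SCHEMA stands in for the SDK's configured default schema.
    DEFAULT_SCHEMA = "tmp_astro"

    def temp_schema_name(source_location: str, default_location: str) -> str:
        # Same location: reuse the default schema. Otherwise derive a
        # per-location schema, replacing "-" with "_" because BigQuery
        # dataset names cannot contain "-".
        if source_location == default_location:
            return DEFAULT_SCHEMA
        return f"{DEFAULT_SCHEMA}__{source_location.replace('-', '_')}"

    assert temp_schema_name("europe-west2", "US") == "tmp_astro__europe_west2"
    assert temp_schema_name("US", "US") == "tmp_astro"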
Showing 9 changed files with 157 additions and 103 deletions.
20 changes: 2 additions & 18 deletions python-sdk/src/astro/databases/base.py
@@ -362,7 +362,7 @@ def create_schema_and_table_if_needed(
is_file_pattern_based_schema_autodetection_supported = (
self.check_file_pattern_based_schema_autodetection_is_supported(source_file=file)
)
- if self.schema_exists(table.metadata.schema) and if_exists == "replace":
+ if if_exists == "replace":
self.drop_table(table)
if use_native_support and is_schema_autodetection_supported and not file.is_pattern():
return
@@ -698,9 +698,7 @@ def export_table_to_file(
# Schema Management
# ---------------------------------------------------------

- def create_schema_if_needed(
-     self, schema: str | None, location: str | None = None  # skipcq: PYL-W0613
- ) -> None:
+ def create_schema_if_needed(self, schema: str | None) -> None:
"""
This function checks if the expected schema exists in the database. If the schema does not exist,
it will attempt to create it.
@@ -721,20 +719,6 @@ def schema_exists(self, schema: str) -> bool:
"""
raise NotImplementedError

- def get_schema_region(self, schema: str | None = None) -> str:  # skipcq: PYL-W0613, PYL-R0201
-     """
-     Get region where the schema is created
-     :param schema: namespace
-     :return:
-     """
-     return ""
-
- def check_same_region(self, table: BaseTable, other_table: BaseTable):  # skipcq: PYL-W0613, PYL-R0201
-     """
-     Check if two tables are from the same database region
-     """
-     return True

# ---------------------------------------------------------
# Context & Template Rendering methods (Transformations)
# ---------------------------------------------------------
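With get_schema_region and check_same_region gone, the base-class contract is
schema-only; a backend that needs a location (BigQuery) now derives it
internally. A hedged sketch of what a concrete subclass implements, where the
class name and statement template are illustrative rather than the SDK's exact
internals:

    class ExampleDatabase:  # stands in for a BaseDatabase subclass
        _create_schema_statement = "CREATE SCHEMA IF NOT EXISTS {}"

        def schema_exists(self, schema: str) -> bool:
            raise NotImplementedError  # each backend supplies its own check

        def create_schema_if_needed(self, schema: str | None) -> None:
            # No location parameter any more; backends resolve one themselves.
            if schema and not self.schema_exists(schema):
                self.run_sql(self._create_schema_statement.format(schema))

        def run_sql(self, statement: str) -> None:
            ...  # the real base class runs this through SQLAlchemy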
57 changes: 44 additions & 13 deletions python-sdk/src/astro/databases/google/bigquery.py
@@ -134,9 +134,43 @@ def default_metadata(self) -> Metadata:
return Metadata(
schema=self.DEFAULT_SCHEMA,
database=self.hook.project_id,
- region=BIGQUERY_SCHEMA_LOCATION,
) # type: ignore

+ def populate_table_metadata(self, table: BaseTable) -> BaseTable:
+     """
+     Populate the metadata of the passed Table object from the Table used in instantiation of
+     the BigqueryDatabase or from the Default Metadata (passed in configs).
+     :param table: Table for which the metadata needs to be populated
+     :return: Modified Table
+     """
+     if (
+         table.temp
+         and (self.table and not self.table.metadata.is_empty())
+         and (table.metadata and table.metadata.is_empty())
+     ):
+         return self._populate_temp_table_metadata_from_input_table(table)
+     if table.metadata and table.metadata.is_empty() and self.default_metadata:
+         table.metadata = self.default_metadata
+     if not table.metadata.schema:
+         table.metadata.schema = self.DEFAULT_SCHEMA
+     return table
+
+ def _populate_temp_table_metadata_from_input_table(self, temp_table: BaseTable) -> BaseTable:
+     if not self.table:
+         return temp_table
+
+     source_location = self._get_schema_location(self.table.metadata.schema)
+     default_schema_location = self._get_schema_location(self.DEFAULT_SCHEMA)
+
+     if source_location == default_schema_location:
+         schema = self.DEFAULT_SCHEMA
+     else:
+         schema = f"{self.DEFAULT_SCHEMA}__{source_location.replace('-', '_')}"
+     source_db = self.table.metadata.database or self.hook.project_id
+     temp_table.metadata = Metadata(schema=schema, database=source_db)
+     return temp_table

def schema_exists(self, schema: str) -> bool:
"""
Checks if a dataset exists in the BigQuery
@@ -149,11 +183,11 @@ def schema_exists(self, schema: str) -> bool:
return False
return True

- def get_schema_region(self, schema: str | None = None) -> str:
+ def _get_schema_location(self, schema: str | None = None) -> str:
"""
Get region where the schema is created
:param schema: Bigquery namespace
:return:
"""
if schema is None:
return ""
@@ -163,14 +197,6 @@
except GoogleNotFound:
return ""

- def check_same_region(self, table: BaseTable, other_table: BaseTable):
-     """
-     Check if two tables are from the same database region
-     """
-     table_location = self.get_schema_region(schema=table.metadata.schema)
-     other_table_location = self.get_schema_region(schema=other_table.metadata.schema)
-     return table_location == other_table_location

@staticmethod
def get_merge_initialization_query(parameters: tuple) -> str:
"""
@@ -211,7 +237,7 @@ def load_pandas_dataframe_to_table(
credentials=creds,
)

- def create_schema_if_needed(self, schema: str | None, location: str | None = None) -> None:
+ def create_schema_if_needed(self, schema: str | None) -> None:
"""
This function checks if the expected schema exists in the database. If the schema does not exist,
it will attempt to create it.
@@ -220,8 +246,12 @@
"""
# We check if the schema exists first because BigQuery will fail on a create schema query even if it
# doesn't actually create a schema.
- location = location or BIGQUERY_SCHEMA_LOCATION
  if schema and not self.schema_exists(schema):
+     input_table_schema = self.table.metadata.schema if self.table and self.table.metadata else None
+     input_table_location = self._get_schema_location(input_table_schema)
+
+     location = input_table_location or BIGQUERY_SCHEMA_LOCATION
statement = self._create_schema_statement.format(schema, location)
self.run_sql(statement)

@@ -376,6 +406,7 @@ def load_gs_file_to_table(
if self.is_native_autodetect_schema_available(file=source_file):
load_job_config["autodetect"] = True # type: ignore

+ # TODO: Fix this -- it should be load_job_config.update(native_support_kwargs)
native_support_kwargs.update(native_support_kwargs)

# Since bigquery has other options besides used here, we need to expose them to end user.
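The body of _get_schema_location is collapsed in this diff; judging from the
surrounding context lines and the get_dataset mock in the tests below, it
resolves the dataset's location through the Airflow BigQuery hook. A hedged
reconstruction as a standalone function, with the CREATE SCHEMA template shown
as an assumed shape for _create_schema_statement:

    from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
    from google.api_core.exceptions import NotFound as GoogleNotFound

    # Assumed shape of BigqueryDatabase._create_schema_statement.
    CREATE_SCHEMA_STATEMENT = "CREATE SCHEMA IF NOT EXISTS {} OPTIONS (location='{}')"

    def get_schema_location(hook: BigQueryHook, schema: str | None = None) -> str:
        """Return the location of the BigQuery dataset backing ``schema``, or ""."""
        if schema is None:
            return ""
        try:
            return hook.get_dataset(dataset_id=schema).location
        except GoogleNotFound:
            return ""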
2 changes: 1 addition & 1 deletion python-sdk/src/astro/databases/sqlite.py
@@ -66,7 +66,7 @@ def populate_table_metadata(self, table: BaseTable) -> BaseTable:
table.conn_id = table.conn_id or self.conn_id
return table

- def create_schema_if_needed(self, schema: str | None, location: str | None = None) -> None:
+ def create_schema_if_needed(self, schema: str | None) -> None:
"""
Since SQLite does not have schemas, we do not need to set a schema here.
"""
4 changes: 1 addition & 3 deletions python-sdk/src/astro/databricks/delta.py
@@ -99,9 +99,7 @@ def schema_exists(self, schema: str) -> bool:
# Schemas do not need to be created for delta, so we can assume this is true
return True

- def create_schema_if_needed(
-     self, schema: str | None, location: str | None = None  # skipcq: PYL-W0613
- ) -> None:
+ def create_schema_if_needed(self, schema: str | None) -> None:  # skipcq: PYL-W0613
# Schemas do not need to be created for delta, so we don't need to do anything here
return None

23 changes: 1 addition & 22 deletions python-sdk/src/astro/sql/operators/transform.py
@@ -14,7 +14,6 @@
from sqlalchemy.sql.functions import Function

from astro.sql.operators.base_decorator import BaseSQLDecoratedOperator
- from astro.utils.table import find_first_table
from astro.utils.typing_compat import Context


@@ -52,28 +51,8 @@ def __init__(
)

def execute(self, context: Context):
- first_table = find_first_table(
-     op_args=self.op_args,  # type: ignore
-     op_kwargs=self.op_kwargs,
-     python_callable=self.python_callable,
-     parameters=self.parameters,  # type: ignore
-     context=context,
- )
-
  super().execute(context)
-
- if (
-     first_table
-     and self.output_table.temp
-     and (not self.database_impl.check_same_region(table=first_table, other_table=self.output_table))
- ):
-     self.output_table.metadata.region = self.database_impl.get_schema_region(
-         schema=first_table.metadata.schema
-     )
-
- self.database_impl.create_schema_if_needed(
-     self.output_table.metadata.schema, self.output_table.metadata.region
- )
+ self.database_impl.create_schema_if_needed(self.output_table.metadata.schema)
self.database_impl.drop_table(self.output_table)
self.database_impl.create_table_from_select_statement(
statement=self.sql,
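Read end to end, the operator's execute() collapses to the following
(consolidated from the hunk above; the remaining arguments of the final call
are cut off by the diff):

    def execute(self, context: Context):
        super().execute(context)

        self.database_impl.create_schema_if_needed(self.output_table.metadata.schema)
        self.database_impl.drop_table(self.output_table)
        self.database_impl.create_table_from_select_statement(
            statement=self.sql,
            ...
        )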
26 changes: 3 additions & 23 deletions python-sdk/src/astro/table.py
@@ -15,14 +15,6 @@
TEMP_PREFIX = "_tmp"


- def metadata_field_converter(val):
-     if isinstance(val, dict):
-         if "_schema" in val:
-             val["schema"] = val.pop("_schema")
-         return Metadata(**val)
-     return val
-

@define
class Metadata:
"""
@@ -34,25 +26,13 @@ class Metadata:
"""

# This property is used by several databases, including: Postgres, Snowflake and BigQuery ("namespace")
- _schema: str | None = None
+ schema: str | None = None
database: str | None = None
- region: str | None = None

def is_empty(self) -> bool:
"""Check if all the fields are None."""
return all(getattr(self, field_name) is None for field_name in fields_dict(self.__class__))

- @property
- def schema(self):
-     if self.region:
-         # We are replacing the `-` with `_` because for bigquery doesn't allow `-` in schema name
-         return f"{self._schema}__{self.region.replace('-', '_')}"
-     return self._schema
-
- @schema.setter
- def schema(self, value):
-     self._schema = value


@define(slots=False)
class BaseTable:
@@ -79,7 +59,7 @@ class BaseTable:
# Setting converter allows passing a dictionary to metadata arg
metadata: Metadata = field(
factory=Metadata,
- converter=metadata_field_converter,
+ converter=lambda val: Metadata(**val) if isinstance(val, dict) else val,
)
columns: list[Column] = field(factory=list)
temp: bool = field(default=False)
@@ -224,7 +204,7 @@ class Table(BaseTable, Dataset):
:param columns: columns which define the database table schema.
"""

- uri: str = field(init=False)
+ uri: str = field(init=False, eq=False)
extra: dict | None = field(init=False, factory=dict)

def __new__(cls, *args, **kwargs):
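After this change Metadata is a plain attrs class: the dict converter maps keys
directly onto the public fields, with no _schema/region translation step. A
short usage sketch, assuming the astro package is importable:

    from astro.table import Metadata, Table

    t = Table(name="orders", metadata={"schema": "sales", "database": "analytics"})
    assert t.metadata == Metadata(schema="sales", database="analytics")
    assert not t.metadata.is_empty()

The companion tweak, uri: str = field(init=False, eq=False), excludes the
derived URI from equality checks, so Table objects compare on their declared
fields rather than on the computed URI.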
6 changes: 4 additions & 2 deletions python-sdk/tests/custom_backend/test_serializer.py
@@ -36,8 +36,10 @@
([1, 2, "astro"], ["1", "2", {"class": "string", "value": "astro"}]),
({"software": "airflow"}, {"software": {"class": "string", "value": "airflow"}}),
(pd.DataFrame(data={"col": [1, 2]}), {}),
- (np.int(2022), "2022"),
- (np.float(3.14), "3.14"),
+ (int(2022), "2022"),
+ (float(3.14), "3.14"),
+ (np.int_(2022), 2022),
+ (np.float_(3.14), 3.14),
(np.int_([3, 1, 4]), [3, 1, 4]),
("astro", {"class": "string", "value": "astro"}),
(deque(), deque([])),
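Context for the test change: np.int and np.float were deprecated aliases of the
Python builtins since NumPy 1.20 and were removed in NumPy 1.24, so the cases
now use the builtins plus the genuine NumPy scalar types. A quick illustration
of the distinction:

    import numpy as np

    assert isinstance(np.int_(2022), np.integer)     # NumPy scalar, not builtin int
    assert isinstance(np.float_(3.14), np.floating)
    assert int(np.int_(2022)) == 2022                # converts back to a builtin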
100 changes: 99 additions & 1 deletion python-sdk/tests/databases/test_bigquery.py
@@ -1,14 +1,17 @@
"""Tests specific to the Sqlite Database implementation."""
import pathlib
+ from unittest import mock

import pytest
from google.cloud.bigquery_datatransfer_v1.types import (
StartManualTransferRunsResponse,
TransferConfig,
TransferRun,
)

+ from astro import settings
from astro.databases.google.bigquery import BigqueryDatabase, S3ToBigqueryDataTransfer
from astro.files import File
+ from astro.table import TEMP_PREFIX, Metadata, Table

DEFAULT_CONN_ID = "google_cloud_default"
CUSTOM_CONN_ID = "gcp_conn"
@@ -41,3 +44,98 @@ def test_get_run_id():
)
config.runs.append(run)
assert S3ToBigqueryDataTransfer.get_run_id(config) == "62d6a4df-0000-2fad-8752-d4f547e68ef4"


+ @pytest.mark.parametrize(
+     "source_table,input_table,returned_table,source_location",
+     [
+         # Table Metadata is empty and it is copied from Default Metadata
+         (
+             None,
+             Table(name="s1"),
+             Table(name="s1", metadata=Metadata(schema=settings.BIGQUERY_SCHEMA, database="test_project_id")),
+             settings.DEFAULT_BIGQUERY_SCHEMA_LOCATION,
+         ),
+         # Table Metadata just contains database/project_id and only schema is copied from Default Metadata
+         (
+             None,
+             Table(name="s1", metadata=Metadata(database="test_project_id_2")),
+             Table(
+                 name="s1", metadata=Metadata(schema=settings.BIGQUERY_SCHEMA, database="test_project_id_2")
+             ),
+             settings.DEFAULT_BIGQUERY_SCHEMA_LOCATION,
+         ),
+         # Table Metadata contains both schema & database/project_id and the table remains unchanged
+         (
+             None,
+             Table(name="s1", metadata=Metadata(schema="test_schema", database="test_project_id_2")),
+             Table(name="s1", metadata=Metadata(schema="test_schema", database="test_project_id_2")),
+             settings.DEFAULT_BIGQUERY_SCHEMA_LOCATION,
+         ),
+         # Table is temp and its Metadata is empty but the source_table is not None but its metadata is empty
+         # so the metadata is copied from the default metadata
+         (
+             Table(name="t1"),
+             Table(name=f"{TEMP_PREFIX}_xyz"),
+             Table(
+                 name=f"{TEMP_PREFIX}_xyz",
+                 metadata=Metadata(schema=settings.BIGQUERY_SCHEMA, database="test_project_id"),
+             ),
+             settings.DEFAULT_BIGQUERY_SCHEMA_LOCATION,
+         ),
+         (
+             Table(name="t1", metadata=Metadata(schema="test_schema")),
+             Table(name=f"{TEMP_PREFIX}_xyz"),
+             Table(
+                 name=f"{TEMP_PREFIX}_xyz",
+                 metadata=Metadata(schema=settings.BIGQUERY_SCHEMA, database="test_project_id"),
+             ),
+             settings.DEFAULT_BIGQUERY_SCHEMA_LOCATION,
+         ),
+         (
+             Table(name="t1", metadata=Metadata(schema="test_schema", database="test_project_id2")),
+             Table(name=f"{TEMP_PREFIX}_xyz"),
+             Table(
+                 name=f"{TEMP_PREFIX}_xyz",
+                 metadata=Metadata(schema=settings.BIGQUERY_SCHEMA, database="test_project_id2"),
+             ),
+             settings.DEFAULT_BIGQUERY_SCHEMA_LOCATION,
+         ),
+         (
+             Table(name="t1", metadata=Metadata(schema="schema_in_eu_west2", database="test_project_id2")),
+             Table(name=f"{TEMP_PREFIX}_xyz"),
+             Table(
+                 name=f"{TEMP_PREFIX}_xyz",
+                 metadata=Metadata(
+                     schema=f"{settings.BIGQUERY_SCHEMA}__europe_west2", database="test_project_id2"
+                 ),
+             ),
+             "europe-west2",
+         ),
+         (
+             Table(name="t1", metadata=Metadata(schema="schema_in_eu_west2")),
+             Table(name=f"{TEMP_PREFIX}_xyz"),
+             Table(
+                 name=f"{TEMP_PREFIX}_xyz",
+                 metadata=Metadata(
+                     schema=f"{settings.BIGQUERY_SCHEMA}__europe_west2", database="test_project_id"
+                 ),
+             ),
+             "europe-west2",
+         ),
+     ],
+ )
+ @mock.patch("astro.databases.google.bigquery.BigQueryHook")
+ def test_populate_table_metadata(mock_bq_hook, source_table, input_table, returned_table, source_location):
+     bq_hook = mock.MagicMock(project_id="test_project_id")
+     mock_bq_hook.return_value = bq_hook
+
+     def mock_get_dataset(dataset_id):
+         if dataset_id == settings.BIGQUERY_SCHEMA:
+             return mock.MagicMock(location=settings.DEFAULT_BIGQUERY_SCHEMA_LOCATION)
+         return mock.MagicMock(location=source_location)
+
+     mock_bq_hook.return_value.get_dataset.side_effect = mock_get_dataset
+
+     db = BigqueryDatabase(table=source_table, conn_id="test_conn")
+     assert db.populate_table_metadata(input_table) == returned_table