diff --git a/airflow/providers/teradata/hooks/teradata.py b/airflow/providers/teradata/hooks/teradata.py index 5f5d1ed8fbd5c..ff58945591480 100644 --- a/airflow/providers/teradata/hooks/teradata.py +++ b/airflow/providers/teradata/hooks/teradata.py @@ -15,19 +15,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""An Airflow Hook for interacting with Teradata SQL Server.""" from __future__ import annotations -from typing import TYPE_CHECKING, Any - -import sqlalchemy -import teradatasql -from teradatasql import TeradataConnection +from typing import TypeVar from airflow.providers.common.sql.hooks.sql import DbApiHook -if TYPE_CHECKING: - from airflow.models.connection import Connection +T = TypeVar("T") class TeradataHook(DbApiHook): @@ -38,165 +32,17 @@ class TeradataHook(DbApiHook): Teradata DB Server URL, username, password and database name are fetched from the predefined connection config connection_id. It raises an airflow error if the given connection id doesn't exist. - You can also specify ssl parameters in the extra field of your connection - as ``{"sslmode": "require", "sslcert": "/path/to/cert.pem", etc}``. - - .. seealso:: - - :ref:`Teradata API connection ` + See :doc:` docs/apache-airflow-providers-teradata/connections/teradata.rst` for full documentation. :param args: passed to DbApiHook - :param database: The Teradata database to connect to. :param kwargs: passed to DbApiHook - """ - - # Override to provide the connection name. - conn_name_attr = "teradata_conn_id" - - # Override to have a default connection id for a particular dbHook - default_conn_name = "teradata_default" - - # Override if this db supports autocommit. - supports_autocommit = True - - # Override this for hook to have a custom name in the UI selection - conn_type = "teradata" - - # Override hook name to give descriptive name for hook - hook_name = "Teradata" - - # Override with the Teradata specific placeholder parameter string used for insert queries - placeholder: str = "?" - - # Override SQL query to be used for testing database connection - _test_connection_sql = "select 1" - - def __init__( - self, - *args, - database: str | None = None, - **kwargs, - ) -> None: - super().__init__(*args, schema=database, **kwargs) - def get_conn(self) -> TeradataConnection: - """Creates and returns a Teradata Connection object using teradatasql client. - Establishes connection to a Teradata SQL database using config corresponding to teradata_conn_id. + Usage Help: - :return: a Teradata connection object - """ - teradata_conn_config: dict = self._get_conn_config_teradatasql() - teradata_conn = teradatasql.connect(**teradata_conn_config) - return teradata_conn + >>> tdh = TeradataHook() + >>> sql = "SELECT top 1 _airbyte_ab_id from airbyte_td._airbyte_raw_Sales;" + >>> tdh.get_records(sql) + [[61ad1d63-3efd-4da4-9904-a4489cc3a520]] - def bulk_insert_rows( - self, - table: str, - rows: list[tuple], - target_fields: list[str] | None = None, - commit_every: int = 5000, - ): - """A bulk insert of records for Teradata SQL Database. - - This uses prepared statements via `executemany()`. For best performance, - pass in `rows` as an iterator. - - :param table: target Teradata database table, use dot notation to target a - specific database - :param rows: the rows to insert into the table - :param target_fields: the names of the columns to fill in the table, default None. - If None, each rows should have some order as table columns name - :param commit_every: the maximum number of rows to insert in one transaction - Default 5000. Set greater than 0. Set 1 to insert each row in each transaction - """ - if not rows: - raise ValueError("parameter rows could not be None or empty iterable") - conn = self.get_conn() - if self.supports_autocommit: - self.set_autocommit(conn, False) - cursor = conn.cursor() - cursor.fast_executemany = True - values_base = target_fields if target_fields else rows[0] - prepared_stm = "INSERT INTO {tablename} {columns} VALUES ({values})".format( - tablename=table, - columns="({})".format(", ".join(target_fields)) if target_fields else "", - values=", ".join("?" for i in range(1, len(values_base) + 1)), - ) - row_count = 0 - # Chunk the rows - row_chunk = [] - for row in rows: - row_chunk.append(row) - row_count += 1 - if row_count % commit_every == 0: - cursor.executemany(prepared_stm, row_chunk) - conn.commit() # type: ignore[attr-defined] - # Empty chunk - row_chunk = [] - # Commit the leftover chunk - if len(row_chunk) > 0: - cursor.executemany(prepared_stm, row_chunk) - conn.commit() # type: ignore[attr-defined] - self.log.info("[%s] inserted %s rows", table, row_count) - cursor.close() - conn.close() # type: ignore[attr-defined] - - def _get_conn_config_teradatasql(self) -> dict[str, Any]: - """Returns set of config params required for connecting to Teradata DB using teradatasql client.""" - conn: Connection = self.get_connection(getattr(self, self.conn_name_attr)) - conn_config = { - "host": conn.host or "localhost", - "dbs_port": conn.port or "1025", - "database": conn.schema or "", - "user": conn.login or "dbc", - "password": conn.password or "dbc", - } - - if conn.extra_dejson.get("tmode", False): - conn_config["tmode"] = conn.extra_dejson["tmode"] - - # Handling SSL connection parameters - - if conn.extra_dejson.get("sslmode", False): - conn_config["sslmode"] = conn.extra_dejson["sslmode"] - if "verify" in conn_config["sslmode"]: - if conn.extra_dejson.get("sslca", False): - conn_config["sslca"] = conn.extra_dejson["sslca"] - if conn.extra_dejson.get("sslcapath", False): - conn_config["sslcapath"] = conn.extra_dejson["sslcapath"] - if conn.extra_dejson.get("sslcipher", False): - conn_config["sslcipher"] = conn.extra_dejson["sslcipher"] - if conn.extra_dejson.get("sslcrc", False): - conn_config["sslcrc"] = conn.extra_dejson["sslcrc"] - if conn.extra_dejson.get("sslprotocol", False): - conn_config["sslprotocol"] = conn.extra_dejson["sslprotocol"] - - return conn_config - - def get_sqlalchemy_engine(self, engine_kwargs=None): - """Returns a connection object using sqlalchemy.""" - conn: Connection = self.get_connection(getattr(self, self.conn_name_attr)) - link = f"teradatasql://{conn.login}:{conn.password}@{conn.host}" - connection = sqlalchemy.create_engine(link) - return connection - - @staticmethod - def get_ui_field_behaviour() -> dict: - """Returns custom field behaviour.""" - import json - - return { - "hidden_fields": ["port"], - "relabeling": { - "host": "Database Server URL", - "schema": "Database Name", - "login": "Username", - }, - "placeholders": { - "extra": json.dumps( - {"tmode": "TERA", "sslmode": "verify-ca", "sslca": "/tmp/server-ca.pem"}, indent=4 - ), - "login": "dbc", - "password": "dbc", - }, - } + """