Skip to content

Commit

Permalink
D401 support in Teradata Provider
Browse files Browse the repository at this point in the history
  • Loading branch information
satish-chinthanippu committed Feb 9, 2024
1 parent f56bede commit ed75e3e
Showing 1 changed file with 118 additions and 3 deletions.
121 changes: 118 additions & 3 deletions airflow/providers/teradata/hooks/teradata.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,17 @@
"""An Airflow Hook for interacting with Teradata SQL Server."""
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import sqlalchemy
import teradatasql
from teradatasql import TeradataConnection

from airflow.providers.common.sql.hooks.sql import DbApiHook

if TYPE_CHECKING:
from airflow.models.connection import Connection


class TeradataHook(DbApiHook):
"""General hook for interacting with Teradata SQL Database.
Expand All @@ -43,7 +47,6 @@ class TeradataHook(DbApiHook):
:param args: passed to DbApiHook
:param database: The Teradata database to connect to.
:param kwargs: passed to DbApiHook
"""

# Override to provide the connection name.
Expand Down Expand Up @@ -76,7 +79,7 @@ def __init__(
super().__init__(*args, schema=database, **kwargs)

def get_conn(self) -> TeradataConnection:
"""Creates and returns a Teradata Connection object using teradatasql client.
"""Create and return a Teradata Connection object using teradatasql client.
Establishes connection to a Teradata SQL database using config corresponding to teradata_conn_id.
Expand All @@ -85,3 +88,115 @@ def get_conn(self) -> TeradataConnection:
teradata_conn_config: dict = self._get_conn_config_teradatasql()
teradata_conn = teradatasql.connect(**teradata_conn_config)
return teradata_conn

def bulk_insert_rows(
    self,
    table: str,
    rows: list[tuple],
    target_fields: list[str] | None = None,
    commit_every: int = 5000,
):
    """Insert bulk of records into Teradata SQL Database.

    This uses prepared statements via `executemany()`. ``rows`` may be a list or any
    other iterable (including a generator); it is consumed lazily, so at most
    ``commit_every`` rows are buffered at a time.

    :param table: target Teradata database table, use dot notation to target a
        specific database
    :param rows: the rows to insert into the table
    :param target_fields: the names of the columns to fill in the table, default None.
        If None, each row should have the same order as the table columns
    :param commit_every: the maximum number of rows to insert in one transaction.
        Default 5000. Must be greater than 0. Set 1 to insert each row in its own transaction
    :raises ValueError: if ``rows`` is None or yields no rows, or if ``commit_every``
        is smaller than 1
    """
    from itertools import chain

    empty_rows_msg = "parameter rows could not be None or empty iterable"
    if rows is None:
        raise ValueError(empty_rows_msg)
    if commit_every < 1:
        # Reject up front: a zero would otherwise surface as ZeroDivisionError
        # only after a connection has already been opened.
        raise ValueError("parameter commit_every must be greater than 0")

    # Peek at the first row instead of indexing, so generators and other
    # one-shot iterables work and an exhausted iterator is reported as empty.
    row_iter = iter(rows)
    try:
        first_row = next(row_iter)
    except StopIteration:
        raise ValueError(empty_rows_msg) from None

    # One "?" placeholder per target column, or per value of the first row
    # when no column list is given.
    values_base = target_fields if target_fields else first_row
    # NOTE: table and column names are interpolated verbatim into the statement;
    # only the row values are bound as parameters.
    prepared_stm = "INSERT INTO {tablename} {columns} VALUES ({values})".format(
        tablename=table,
        columns="({})".format(", ".join(target_fields)) if target_fields else "",
        values=", ".join(["?"] * len(values_base)),
    )

    conn = self.get_conn()
    try:
        if self.supports_autocommit:
            self.set_autocommit(conn, False)
        cursor = conn.cursor()
        try:
            cursor.fast_executemany = True
            row_count = 0
            row_chunk = []
            # Re-attach the peeked row in front of the remaining rows.
            for row in chain([first_row], row_iter):
                row_chunk.append(row)
                row_count += 1
                if row_count % commit_every == 0:
                    cursor.executemany(prepared_stm, row_chunk)
                    conn.commit()  # type: ignore[attr-defined]
                    # Start a fresh chunk
                    row_chunk = []
            # Commit the leftover chunk
            if row_chunk:
                cursor.executemany(prepared_stm, row_chunk)
                conn.commit()  # type: ignore[attr-defined]
            self.log.info("[%s] inserted %s rows", table, row_count)
        finally:
            # Release the cursor even when an insert or commit fails.
            cursor.close()
    finally:
        # Release the connection even when an insert or commit fails.
        conn.close()  # type: ignore[attr-defined]

def _get_conn_config_teradatasql(self) -> dict[str, Any]:
"""Return set of config params required for connecting to Teradata DB using teradatasql client.

The Airflow connection named by ``conn_name_attr`` supplies host, port, schema, login and
password. Each of them falls back to a default ("localhost", "1025", "", "dbc", "dbc")
when it is falsy. Truthy ``tmode`` and SSL-related keys found in the connection's
``extra_dejson`` are copied into the result under the same key names.

:return: dictionary of connection parameters keyed by parameter name
"""
conn: Connection = self.get_connection(getattr(self, self.conn_name_attr))
# Base parameters; `or` substitutes the default whenever the connection field is falsy.
conn_config = {
"host": conn.host or "localhost",
"dbs_port": conn.port or "1025",
"database": conn.schema or "",
"user": conn.login or "dbc",
"password": conn.password or "dbc",
}

# Optional "tmode" extra, copied only when present and truthy.
if conn.extra_dejson.get("tmode", False):
conn_config["tmode"] = conn.extra_dejson["tmode"]

# Handling SSL connection parameters
# Each SSL extra below is copied only when present and truthy.
# NOTE(review): indentation was lost in this view, so which of the checks after the
# "verify" test are nested under it (and under the sslmode check) cannot be
# determined here - confirm against the original file.

if conn.extra_dejson.get("sslmode", False):
conn_config["sslmode"] = conn.extra_dejson["sslmode"]
# Substring test on the sslmode value that was just copied.
if "verify" in conn_config["sslmode"]:
if conn.extra_dejson.get("sslca", False):
conn_config["sslca"] = conn.extra_dejson["sslca"]
if conn.extra_dejson.get("sslcapath", False):
conn_config["sslcapath"] = conn.extra_dejson["sslcapath"]
if conn.extra_dejson.get("sslcipher", False):
conn_config["sslcipher"] = conn.extra_dejson["sslcipher"]
if conn.extra_dejson.get("sslcrc", False):
conn_config["sslcrc"] = conn.extra_dejson["sslcrc"]
if conn.extra_dejson.get("sslprotocol", False):
conn_config["sslprotocol"] = conn.extra_dejson["sslprotocol"]

return conn_config

def get_sqlalchemy_engine(self, engine_kwargs=None):
    """Return a SQLAlchemy engine for the configured Teradata connection.

    The engine URL is built from the login, password and host of the Airflow
    connection named by ``conn_name_attr``. Login and password are
    percent-encoded so that characters such as ``@``, ``:`` or ``/`` in the
    credentials cannot corrupt the URL; a missing login or password is
    rendered as an empty string.

    :param engine_kwargs: optional keyword arguments forwarded to
        ``sqlalchemy.create_engine``
    :return: the engine created by ``sqlalchemy.create_engine``
    """
    from urllib.parse import quote

    conn: Connection = self.get_connection(getattr(self, self.conn_name_attr))
    # safe="" also escapes "/", and spaces become %20 rather than "+", so the
    # value survives plain percent-decoding on the SQLAlchemy side.
    login = quote(conn.login or "", safe="")
    password = quote(conn.password or "", safe="")
    link = f"teradatasql://{login}:{password}@{conn.host}"
    # Forward caller-supplied options instead of silently dropping them.
    return sqlalchemy.create_engine(link, **(engine_kwargs or {}))

@staticmethod
def get_ui_field_behaviour() -> dict:
    """Return custom field behaviour.

    The result hides the ``port`` field, relabels ``host``, ``schema`` and
    ``login``, and supplies placeholder values for ``extra``, ``login`` and
    ``password``.
    """
    import json

    # Sample extras shown as the placeholder of the "extra" field.
    sample_extra = {"tmode": "TERA", "sslmode": "verify-ca", "sslca": "/tmp/server-ca.pem"}

    field_labels = {
        "host": "Database Server URL",
        "schema": "Database Name",
        "login": "Username",
    }
    field_placeholders = {
        "extra": json.dumps(sample_extra, indent=4),
        "login": "dbc",
        "password": "dbc",
    }

    behaviour = {"hidden_fields": ["port"]}
    behaviour["relabeling"] = field_labels
    behaviour["placeholders"] = field_placeholders
    return behaviour

0 comments on commit ed75e3e

Please sign in to comment.