From 61472744d5a1f40c385e3023640ca9aa1b810b4c Mon Sep 17 00:00:00 2001 From: orellabac Date: Wed, 10 Jul 2024 23:50:51 -0600 Subject: [PATCH] reader helper --- extras/jdbc_read/jdbc_reader.py | 124 ++++++++++++++++++++++++++++++++ 1 file changed, 124 insertions(+) create mode 100644 extras/jdbc_read/jdbc_reader.py diff --git a/extras/jdbc_read/jdbc_reader.py b/extras/jdbc_read/jdbc_reader.py new file mode 100644 index 0000000..8b2042e --- /dev/null +++ b/extras/jdbc_read/jdbc_reader.py @@ -0,0 +1,124 @@ +from snowflake.snowpark import Session, DataFrameReader +from snowflake.snowpark.context import get_active_session +from snowflake.snowpark.functions import col, lit, object_construct + +if not hasattr(DataFrameReader, "__jdbc_reader__"): + setattr(DataFrameReader, "__jdbc_reader__", True) + class JdbcDataFrameReader: + def __init__(self): + self.options = {} + def option(self,key:str,value:str): + self.options[lit(key)] = lit(value) + return self + def query(self,sql:str): + self.query_stmt = lit(sql) + return self + def load(self): + session = get_active_session() + jdbc_options = object_construct(*[item for pair in self.options.items() for item in pair]) + return session.table_function("READ_JDBC",jdbc_options,self.query_stmt) + def format(self,format_name): + return JdbcDataFrameReader() if format_name == "jdbc" else Exception("not supported") + def register_jdbc_reader(jdbc_drivers_stage:str, integration_name:str,secrets:str=None): + from snowflake.snowpark import Session + from snowflake.snowpark.functions import col + jdbc_reader_template = """ +CREATE OR REPLACE FUNCTION READ_JDBC(OPTION OBJECT, query STRING) + RETURNS TABLE(data OBJECT) + LANGUAGE JAVA + RUNTIME_VERSION = '11' + IMPORTS = (@@IMPORTS@@) + EXTERNAL_ACCESS_INTEGRATIONS = (@@EXTERNAL_ACCESS_INTEGRATIONS@@) + @@SECRETS@@ + HANDLER = 'JdbcDataReader' +AS $$ +import java.sql.*; +import java.util.*; +import java.util.stream.Stream; +import com.snowflake.snowpark_java.types.SnowflakeSecrets; +public class JdbcDataReader { + public static class OutputRow { + public Map data; + public OutputRow(Map data) { + this.data = data; + } + } + public static Class getOutputClass() { + return OutputRow.class; + } + public Stream process(Map jdbcConfig, String query) { + String jdbcUrl = jdbcConfig.get("url"); + String username; + String password; + + if ("true".equals(jdbcConfig.get("use_secrets"))) + { + SnowflakeSecrets sfSecrets = SnowflakeSecrets.newInstance(); + var secret = sfSecrets.getUsernamePassword("cred"); + username = secret.getUsername(); + password = secret.getPassword(); + } + else + { + username = jdbcConfig.get("username"); + password = jdbcConfig.get("password"); + } + try { + // Load the JDBC driver + Class.forName(jdbcConfig.get("driver")); + // Create a connection to the database + Connection connection = DriverManager.getConnection(jdbcUrl, username, password); + // Create a statement for executing SQL queries + Statement statement = connection.createStatement(); + // Execute the query + ResultSet resultSet = statement.executeQuery(query); + // Get metadata about the result set + ResultSetMetaData metaData = resultSet.getMetaData(); + // Create a list of column names + List columnNames = new ArrayList<>(); + int columnCount = metaData.getColumnCount(); + for (int i = 1; i <= columnCount; i++) { + columnNames.add(metaData.getColumnName(i)); + } + // Convert the ResultSet to a Stream of OutputRow objects + Stream resultStream = Stream.generate(() -> { + try { + if (resultSet.next()) { + Map rowMap = new HashMap<>(); + for (String columnName : columnNames) { + String columnValue = resultSet.getString(columnName); + rowMap.put(columnName, columnValue); + } + return new OutputRow(rowMap); + } else { + // Close resources + resultSet.close(); + statement.close(); + connection.close(); + return null; + } + } catch (SQLException e) { + e.printStackTrace(); + return null; + } + }).takeWhile(Objects::nonNull); + return resultStream; + } catch (Exception e) { + e.printStackTrace(); + Map rowMap = new HashMap<>(); + rowMap.put("ERROR",e.toString()); + return Stream.of(new OutputRow(rowMap)); + } + } +} +$$; +""" + session = Session.builder.getOrCreate() + jars = [f"'@{x[0]}'" for x in session.sql("list @%s" % jdbc_drivers_stage).select(col('"name"')).collect() if x[0].endswith(".jar")] + secrets_parts = f"SECRETS = ('cred' = {secrets} )" if secrets else "" + imports = ",".join(jars) + jdbc_reader_template=jdbc_reader_template.replace("@@IMPORTS@@", imports).replace("@@SECRETS@@",secrets_parts).replace("@@EXTERNAL_ACCESS_INTEGRATIONS@@",integration_name) + session.sql(jdbc_reader_template).show() + return "jdbc reader registered" + + DataFrameReader.format = format \ No newline at end of file