Skip to content

Commit

Permalink
Merge pull request #92 from oracle/dev/async_python_models
Browse files Browse the repository at this point in the history
Optimized Python Models
  • Loading branch information
aosingh authored Jul 11, 2023
2 parents 2c38043 + 6c6cd59 commit 45c7a27
Show file tree
Hide file tree
Showing 14 changed files with 265 additions and 57 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/oracle-xe-adapter-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ jobs:
- name: Install dbt-oracle with core dependencies
run: |
python -m pip install --upgrade pip
pip install pytest dbt-tests-adapter==1.5.1
pip install pytest dbt-tests-adapter==1.5.2
pip install -r requirements.txt
pip install -e .
Expand Down
2 changes: 1 addition & 1 deletion dbt/adapters/oracle/__version__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
version = "1.5.1"
version = "1.5.2"
6 changes: 3 additions & 3 deletions dbt/adapters/oracle/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ class OracleAdapterCredentials(Credentials):
retry_count: Optional[int] = 1
retry_delay: Optional[int] = 3

# Fetch an auth token to run Python UDF
oml_auth_token_uri: Optional[str] = None
# Base URL for ADB-S OML REST API
oml_cloud_service_url: Optional[str] = None


_ALIASES = {
Expand All @@ -136,7 +136,7 @@ def _connection_keys(self) -> Tuple[str]:
'service', 'connection_string',
'shardingkey', 'supershardingkey',
'cclass', 'purity', 'retry_count',
'retry_delay', 'oml_auth_token_uri'
'retry_delay', 'oml_cloud_service_url'
)

@classmethod
Expand Down
56 changes: 18 additions & 38 deletions dbt/adapters/oracle/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
from dbt.utils import filter_null_values

from dbt.adapters.oracle.keyword_catalog import KEYWORDS
from dbt.adapters.oracle.python_submissions import OracleADBSPythonJob
from dbt.adapters.oracle.connections import AdapterResponse

logger = AdapterLogger("oracle")

Expand Down Expand Up @@ -367,49 +369,27 @@ def get_oml_auth_token(self) -> str:

def submit_python_job(self, parsed_model: dict, compiled_code: str):
"""Submit user defined Python function
The function pyqEval when used in Oracle Autonomous Database,
calls a user-defined Python function.
pyqEval(PAR_LST, OUT_FMT, SRC_NAME, SRC_OWNER, ENV_NAME)
- PAR_LST -> Parameter List
- OUT_FMT -> JSON clob of the columns
- ENV_NAME -> Name of conda environment
https://docs.oracle.com/en/database/oracle/machine-learning/oml4py/1/mlepe/op-py-scripts-v1-do-eval-scriptname-post.html
"""
identifier = parsed_model["alias"]
oml_oauth_access_token = self.get_oml_auth_token()
py_q_script_name = f"{identifier}_dbt_py_script"
py_q_eval_output_fmt = '{"result":"number"}'
py_q_eval_result_table = f"o$pt_dbt_pyqeval_{identifier}_tmp_{datetime.datetime.utcnow().strftime('%H%M%S')}"

conda_env_name = parsed_model["config"].get("conda_env_name")
if conda_env_name:
logger.info("Custom python environment is %s", conda_env_name)
py_q_eval_sql = f"""CREATE GLOBAL TEMPORARY TABLE {py_q_eval_result_table}
AS SELECT * FROM TABLE(pyqEval(par_lst => NULL,
out_fmt => ''{py_q_eval_output_fmt}'',
scr_name => ''{py_q_script_name}'',
scr_owner => NULL,
env_name => ''{conda_env_name}''))"""
else:
py_q_eval_sql = f"""CREATE GLOBAL TEMPORARY TABLE {py_q_eval_result_table}
AS SELECT * FROM TABLE(pyqEval(par_lst => NULL,
out_fmt => ''{py_q_eval_output_fmt}'',
scr_name => ''{py_q_script_name}'',
scr_owner => NULL))"""

py_exec_main_sql = f"""
BEGIN
sys.pyqSetAuthToken('{oml_oauth_access_token}');
sys.pyqScriptCreate('{py_q_script_name}', '{compiled_code.strip()}', FALSE, TRUE);
EXECUTE IMMEDIATE '{py_q_eval_sql}';
EXECUTE IMMEDIATE 'DROP TABLE {py_q_eval_result_table}';
sys.pyqScriptDrop('{py_q_script_name}');
END;
py_q_create_script = f"""
BEGIN
sys.pyqScriptCreate('{py_q_script_name}', '{compiled_code.strip()}', FALSE, TRUE);
END;
"""
response, _ = self.execute(sql=py_exec_main_sql)
response, _ = self.execute(sql=py_q_create_script)
python_job = OracleADBSPythonJob(parsed_model=parsed_model,
credential=self.config.credentials)
python_job()
py_q_drop_script = f"""
BEGIN
sys.pyqScriptDrop('{py_q_script_name}');
END;
"""

response, _ = self.execute(sql=py_q_drop_script)
logger.info(response)
return response
225 changes: 225 additions & 0 deletions dbt/adapters/oracle/python_submissions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
"""
Copyright (c) 2023, Oracle and/or its affiliates.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import datetime
import http
import json
from typing import Dict

import requests
import time

import dbt.exceptions
from dbt.adapters.oracle import OracleAdapterCredentials
from dbt.events import AdapterLogger
from dbt.ui import red, green

# ADB-S OML Rest API minimum timeout is 1800 seconds
DEFAULT_TIMEOUT_IN_SECONDS = 1800
DEFAULT_DELAY_BETWEEN_POLL_IN_SECONDS = 2

OMLUSERS_OAUTH_API = "/omlusers/api/oauth2/v1/token"
OML_DO_EVAL_API = "/oml/api/py-scripts/v1/do-eval/{script_name}"

logger = AdapterLogger("oracle")


class OracleOML4PYClient:

def __init__(self, oml_cloud_service_url, username, password):
self.base_url = oml_cloud_service_url
self._username = username
self._password = password
self.token = None
self.token_expires_at = None
self.token_url = self.base_url + OMLUSERS_OAUTH_API
self._session = requests.Session()

@property
def session(self):
return self._session

def get_token(self):
"""Get access_token or refresh_token"""
# If access token is about to expire then refresh the token
if self.token_expires_at and self.token_expires_at - datetime.datetime.utcnow() < datetime.timedelta(minutes=1):
return self._get_token(grant_type="refresh_token")
elif self.token: # Token is valid
return self.token
else: # Generate a new token
return self._get_token(grant_type="password")

def _get_token(self, grant_type="password"):
"""Gets access_token or refresh_token using /broker/pdbcs/private/v1/token"""
data = {"grant_type": grant_type}
if grant_type == "password":
data["username"] = self._username
data["password"] = self._password
else:
data["token"] = self.token

r = self.session.post(
url=self.token_url,
json=data,
headers={
"Accept": "application/json",
"Content-type": "application/json",
},
)
r.raise_for_status()
response = r.json()
self.token = response["accessToken"]
self.token_expires_at = datetime.datetime.utcnow() + datetime.timedelta(seconds=response["expiresIn"])
return self.token

@property
def default_headers(self):
"""Default headers added to every request"""
return {
"Content-type": "application/json",
"Accept": "application/json",
"Authorization": f"Bearer {self.get_token()}",
}

def request(self, method: str, path: str,
raise_for_status: bool = False,
**kwargs) -> requests.Response:
"""
Description:
Perform a desired action (GET, PUT, POST) on a certain resource
Args:
method (str) -> HTTP verb like GET, PUT, POST, etc
path (str) -> path to the resource e.g. /job/{job_id}
raise_for_status (bool) -> True if HTTPError should be raised in case of an error else False
Returns:
object of type request.Response
Raises:
requests.HTTPError() in case of en error, if raise_for_status is True
"""
url = path if path.startswith(self.base_url) else self.base_url + path
self.session.headers.update(self.default_headers)
r = self.session.request(method=method, url=url, **kwargs)
try:
r.raise_for_status()
except requests.HTTPError:
if raise_for_status:
raise
return r


class OracleADBSPythonJob:
"""Callable to submit Python Script to ADB-S
"""

def __init__(self,
parsed_model: Dict,
credential: OracleAdapterCredentials) -> None:
self.identifier = parsed_model["alias"]
self.py_q_script_name = f"{self.identifier}_dbt_py_script"
self.conda_env_name = parsed_model["config"].get("conda_env_name")
self.timeout = parsed_model["config"].get("timeout", DEFAULT_TIMEOUT_IN_SECONDS)
self.async_flag = parsed_model["config"].get("async_flag", False)
self.service = parsed_model["config"].get("service", "HIGH")
self.oml4py_client = OracleOML4PYClient(oml_cloud_service_url=credential.oml_cloud_service_url,
username=credential.user,
password=credential.password)

def schedule_async_job_and_wait_for_completion(self, data):
logger.info(f"Running Python aysnc job using {data}")
try:
r = self.oml4py_client.request(method="POST",
path=OML_DO_EVAL_API.format(script_name=self.py_q_script_name),
data=json.dumps(data),
raise_for_status=False)
if r.status_code in (http.HTTPStatus.BAD_REQUEST, http.HTTPStatus.INTERNAL_SERVER_ERROR):
logger.error(red(r.json()))
r.raise_for_status()
except requests.exceptions.RequestException as e:
logger.error(red(f"Error {e} scheduling async Python job for model {self.identifier}"))
raise dbt.exceptions.DbtRuntimeError(f"Error scheduling Python model {self.identifier}")

job_location = r.headers["location"]
logger.info(f"Started async job {job_location}")
start_time = time.time()

while time.time() - start_time < self.timeout:
logger.debug(f"Checking Job status for : {job_location}")
try:
job_status = self.oml4py_client.request(method="GET",
path=job_location,
raise_for_status=False)
job_status_code = job_status.status_code
logger.debug(f"Job status code is: {job_status_code}")
if job_status_code == http.HTTPStatus.FOUND:
logger.info(green(f"Job {job_location} completed"))
job_result = self.oml4py_client.request(method="GET",
path=f"{job_location}/result",
raise_for_status=False)
job_result_json = job_result.json()
if 'errorMessage' in job_result_json:
logger.error(red(f"FAILURE - Python model {self.identifier} Job failure is: {job_result_json}"))
raise dbt.exceptions.DbtRuntimeError(f"Error running Python model {self.identifier}")
job_result.raise_for_status()
logger.info(green(f"SUCCESS - Python model {self.identifier} Job result is: {job_result_json}"))
return
elif job_status_code == http.HTTPStatus.INTERNAL_SERVER_ERROR:
logger.error(red(f"FAILURE - Job status is: {job_status.json()}"))
raise dbt.exceptions.DbtRuntimeError(f"Error running Python model {self.identifier}")
else:
logger.debug(f"Python model {self.identifier} job status is: {job_status.json()}")
job_status.raise_for_status()

except requests.exceptions.RequestException as e:
logger.error(red(f"Error {e} checking status of Python job {job_location} for model {self.identifier}"))
raise dbt.exceptions.DbtRuntimeError(f"Error checking status for job {job_location}")

time.sleep(DEFAULT_DELAY_BETWEEN_POLL_IN_SECONDS)
logger.error(red(f"Timeout error for Python model {self.identifier}"))
raise dbt.exceptions.DbtRuntimeError(f"Timeout error for Python model {self.identifier}")

def __call__(self, *args, **kwargs):
data = {
"service": self.service
}
if self.async_flag:
data["asyncFlag"] = self.async_flag
data["timeout"] = self.timeout
if self.conda_env_name:
data["envName"] = self.conda_env_name

if self.async_flag:
self.schedule_async_job_and_wait_for_completion(data=data)
else: # Run in blocking mode
logger.info(f"Running Python model {self.identifier} with args {data}")
try:
r = self.oml4py_client.request(method="POST",
path=OML_DO_EVAL_API.format(script_name=self.py_q_script_name),
data=json.dumps(data),
raise_for_status=False)
job_result = r.json()
if 'errorMessage' in job_result:
logger.error(red(f"FAILURE - Python model {self.identifier} Job failure is: {job_result}"))
raise dbt.exceptions.DbtRuntimeError(f"Error running Python model {self.identifier}")
r.raise_for_status()
logger.info(green(f"SUCCESS - Python model {self.identifier} Job result is: {job_result}"))
except requests.exceptions.RequestException as e:
logger.error(red(f"Error {e} running Python model {self.identifier}"))
raise dbt.exceptions.DbtRuntimeError(f"Error running Python model {self.identifier}")

2 changes: 1 addition & 1 deletion dbt/include/oracle/macros/adapters.sql
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
{{ return(load_result('get_columns_in_query').table.columns | map(attribute='name') | list) }}
{% endmacro %}

{% macro oracle__get_empty_subquery_sql(select_sql) %}
{% macro oracle__get_empty_subquery_sql(select_sql, select_sql_header=none) %}
select * from (
{{ select_sql }}
) dbt_sbq_tmp
Expand Down
3 changes: 3 additions & 0 deletions dbt_adbs_test_project/models/test_py_ref.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
def model(dbt, session):
# Must be either table or incremental (view is not currently supported)
dbt.config(materialized="table")
dbt.config(async_flag=True)
dbt.config(timeout=900) # In seconds
dbt.config(service="HIGH") # LOW, MEDIUM, HIGH
# oml.core.DataFrame representing a datasource
s_df = dbt.ref("sales_cost")
return s_df
2 changes: 1 addition & 1 deletion dbt_adbs_test_project/profiles.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ dbt_test:
service: "{{ env_var('DBT_ORACLE_SERVICE') }}"
#database: "{{ env_var('DBT_ORACLE_DATABASE') }}"
schema: "{{ env_var('DBT_ORACLE_SCHEMA') }}"
oml_auth_token_uri: "{{ env_var('DBT_ORACLE_OML_AUTH_TOKEN_API')}}"
oml_cloud_service_url: "{{ env_var('DBT_ORACLE_OML_CLOUD_SERVICE_URL')}}"
retry_count: 1
retry_delay: 5
shardingkey:
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
dbt-core==1.5.1
dbt-core==1.5.2
cx_Oracle==8.3.0
oracledb==1.3.1
oracledb==1.3.2

2 changes: 1 addition & 1 deletion requirements_dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ tox
coverage
twine
pytest
dbt-tests-adapter==1.5.1
dbt-tests-adapter==1.5.2
6 changes: 3 additions & 3 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ zip_safe = False
packages = find:
include_package_data = True
install_requires =
dbt-core==1.5.1
dbt-core==1.5.2
cx_Oracle==8.3.0
oracledb==1.3.1
oracledb==1.3.2
test_suite=tests
test_requires =
dbt-tests-adapter==1.5.1
dbt-tests-adapter==1.5.2
pytest
scripts =
bin/create-pem-from-p12
Expand Down
Loading

0 comments on commit 45c7a27

Please sign in to comment.