Skip to content

Commit

Permalink
Make fetch configurable and fix hook tests (#204)
Browse files Browse the repository at this point in the history
### Summary

Before this change, execution time was idling even after the job was
finished in Dremio. This was happening because the adapter would fetch
unnecessary data from the materialized model.

### Description

This change makes it so the adapter only fetches data from the
materialized model if fetch is set to true.

### Test Results

All tests pass

### Changelog

-   [x] Added a summary of what this PR accomplishes to CHANGELOG.md

### Related Issue

#176
  • Loading branch information
ravjotbrar authored Nov 14, 2023
1 parent a7dd4fa commit e6ffb93
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 14 deletions.
5 changes: 3 additions & 2 deletions dbt/adapters/dremio/api/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def close(self):
self._initialize()
self.closed = True

def execute(self, sql, bindings=None):
def execute(self, sql, bindings=None, fetch=True):
if self.closed:
raise Exception("CursorClosed")
if bindings is None:
Expand All @@ -88,7 +88,8 @@ def execute(self, sql, bindings=None):
self._job_id = json_payload["id"]

self._populate_rowcount()
self._populate_job_results()
if fetch:
self._populate_job_results()
self._populate_results_table()

else:
Expand Down
16 changes: 8 additions & 8 deletions dbt/adapters/dremio/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def exception_handler(self, sql):

@classmethod
def open(cls, connection):

if connection.state == "open":
logger.debug("Connection is already open, skipping open.")
return connection
Expand Down Expand Up @@ -130,13 +129,14 @@ def add_commit_query(self):
pass

# Auto_begin may not be relevant with the rest_api
def add_query(self, sql, auto_begin=True, bindings=None, abridge_sql_log=False):

def add_query(
self, sql, auto_begin=True, bindings=None, abridge_sql_log=False, fetch=True
):
connection = self.get_thread_connection()
if auto_begin and connection.transaction_open is False:
self.begin()

logger.debug(f'Using {self.TYPE} connection "{connection.name}"')
logger.debug(f'Using {self.TYPE} connection "{connection.name}". fetch={fetch}')

with self.exception_handler(sql):
if abridge_sql_log:
Expand All @@ -148,10 +148,10 @@ def add_query(self, sql, auto_begin=True, bindings=None, abridge_sql_log=False):
cursor = connection.handle.cursor()

if bindings is None:
cursor.execute(sql)
cursor.execute(sql, fetch=fetch)
else:
logger.debug(f"Bindings: {bindings}")
cursor.execute(sql, bindings)
cursor.execute(sql, bindings, fetch=fetch)

logger.debug(
"SQL status: {} in {:0.2f} seconds".format(
Expand All @@ -174,9 +174,9 @@ def execute(
self, sql: str, auto_begin: bool = False, fetch: bool = False
) -> Tuple[AdapterResponse, agate.Table]:
sql = self._add_query_comment(sql)
_, cursor = self.add_query(sql, auto_begin)
_, cursor = self.add_query(sql, auto_begin, fetch=fetch)
response = self.get_response(cursor)
fetch = True
# fetch = True
if fetch:
table = cursor.table
else:
Expand Down
28 changes: 24 additions & 4 deletions tests/hooks/test_model_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@

from dbt.tests.adapter.hooks.test_model_hooks import (
TestHooksRefsOnSeeds,
TestPrePostModelHooksOnSeeds,
TestPrePostModelHooksOnSnapshots,
TestDuplicateHooksInConfigs,
)

Expand Down Expand Up @@ -229,7 +227,15 @@ def test_pre_post_model_hooks_refed(self, project, dbt_profile_target):
self.check_hooks("end", project, dbt_profile_target.get("host", None))


class TestPrePostModelHooksOnSeedsDremio(TestPrePostModelHooksOnSeeds):
class TestPrePostModelHooksOnSeedsDremio(object):
@pytest.fixture(scope="class")
def seeds(self):
return {"example_seed.csv": seeds__example_seed_csv}

@pytest.fixture(scope="class")
def models(self):
return {"schema.yml": properties__seed_models}

@pytest.fixture(scope="class")
def project_config_update(self):
return {
Expand All @@ -246,6 +252,12 @@ def project_config_update(self):
},
}

def test_hooks_on_seeds(self, project):
res = run_dbt(["seed"])
assert len(res) == 1, "Expected exactly one item"
res = run_dbt(["test"])
assert len(res) == 1, "Expected exactly one item"


class TestHooksRefsOnSeedsDremio(TestHooksRefsOnSeeds):
@pytest.fixture(scope="class")
Expand Down Expand Up @@ -289,7 +301,7 @@ def project_config_update(self):
}


class TestPrePostModelHooksOnSnapshotsDremio(TestPrePostModelHooksOnSnapshots):
class TestPrePostModelHooksOnSnapshotsDremio(object):
@pytest.fixture(scope="class")
def unique_schema(self, request, prefix) -> str:
test_file = request.module.__name__
Expand Down Expand Up @@ -327,6 +339,14 @@ def setUp(self, project):
Path.mkdir(path)
write_file(snapshots__test_snapshot, path, "snapshot.sql")

@pytest.fixture(scope="class")
def models(self):
return {"schema.yml": properties__test_snapshot_models}

@pytest.fixture(scope="class")
def seeds(self):
return {"example_seed.csv": seeds__example_seed_csv}

@pytest.fixture(scope="class")
def project_config_update(self):
return {
Expand Down

0 comments on commit e6ffb93

Please sign in to comment.