Skip to content

Commit

Permalink
Inline limit in SQL sent from dbt show (#8641)
Browse files Browse the repository at this point in the history
  • Loading branch information
MichelleArk authored Sep 27, 2023
1 parent 997f839 commit a2d4424
Show file tree
Hide file tree
Showing 7 changed files with 141 additions and 32 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20230913-153924.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: 'update dbt show to include limit in DWH query '
time: 2023-09-13T15:39:24.591805+01:00
custom:
Author: michelleark
Issue: 8496, 8417
24 changes: 24 additions & 0 deletions core/dbt/include/global_project/macros/adapters/show.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
{% macro get_show_sql(compiled_code, sql_header, limit) -%}
{%- set sql_header = sql_header -%}
{{ sql_header if sql_header is not none }}
{%- if sql_header -%}
{{ sql_header }}
{%- endif -%}
{%- if limit is not none -%}
{{ get_limit_subquery_sql(compiled_code, limit) }}
{%- else -%}
{{ compiled_code }}
{%- endif -%}
{% endmacro %}

{% macro get_limit_subquery_sql(sql, limit) %}
{{ adapter.dispatch('get_limit_subquery_sql', 'dbt')(sql, limit) }}
{% endmacro %}

{% macro default__get_limit_subquery_sql(sql, limit) %}
select *
from (
{{ sql }}
) as model_limit_subq
limit {{ limit }}
{% endmacro %}
20 changes: 14 additions & 6 deletions core/dbt/task/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import threading
import time

from dbt.context.providers import generate_runtime_model_context
from dbt.contracts.graph.nodes import SeedNode
from dbt.contracts.results import RunResult, RunStatus
from dbt.events.base_types import EventLevel
Expand All @@ -23,14 +24,21 @@ def execute(self, compiled_node, manifest):
# Allow passing in -1 (or any negative number) to get all rows
limit = None if self.config.args.limit < 0 else self.config.args.limit

if "sql_header" in compiled_node.unrendered_config:
compiled_node.compiled_code = (
compiled_node.unrendered_config["sql_header"] + compiled_node.compiled_code
)

model_context = generate_runtime_model_context(compiled_node, self.config, manifest)
compiled_node.compiled_code = self.adapter.execute_macro(
macro_name="get_show_sql",
manifest=manifest,
context_override=model_context,
kwargs={
"compiled_code": model_context["compiled_code"],
"sql_header": model_context["config"].get("sql_header"),
"limit": limit,
},
)
adapter_response, execute_result = self.adapter.execute(
compiled_node.compiled_code, fetch=True, limit=limit
compiled_node.compiled_code, fetch=True
)

end_time = time.time()

return RunResult(
Expand Down
34 changes: 34 additions & 0 deletions tests/adapter/dbt/tests/adapter/dbt_show/fixtures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
models__sql_header = """
{% call set_sql_header(config) %}
set session time zone '{{ var("timezone", "Europe/Paris") }}';
{%- endcall %}
select current_setting('timezone') as timezone
"""

models__ephemeral_model = """
{{ config(materialized = 'ephemeral') }}
select
coalesce(sample_num, 0) + 10 as col_deci
from {{ ref('sample_model') }}
"""

models__second_ephemeral_model = """
{{ config(materialized = 'ephemeral') }}
select
col_deci + 100 as col_hundo
from {{ ref('ephemeral_model') }}
"""

models__sample_model = """
select * from {{ ref('sample_seed') }}
"""

seeds__sample_seed = """sample_num,sample_bool
1,true
2,false
3,true
4,false
5,true
6,false
7,true
"""
62 changes: 62 additions & 0 deletions tests/adapter/dbt/tests/adapter/dbt_show/test_dbt_show.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import pytest
from dbt.tests.util import run_dbt

from dbt.tests.adapter.dbt_show.fixtures import (
models__sql_header,
models__ephemeral_model,
models__second_ephemeral_model,
models__sample_model,
seeds__sample_seed,
)


# -- Below we define base classes for tests you import based on if your adapter supports dbt show or not --
class BaseShowLimit:
@pytest.fixture(scope="class")
def models(self):
return {
"sample_model.sql": models__sample_model,
"ephemeral_model.sql": models__ephemeral_model,
}

@pytest.fixture(scope="class")
def seeds(self):
return {"sample_seed.csv": seeds__sample_seed}

@pytest.mark.parametrize(
"args,expected",
[
([], 5), # default limit
(["--limit", 3], 3), # fetch 3 rows
(["--limit", -1], 7), # fetch all rows
],
)
def test_limit(self, project, args, expected):
run_dbt(["build"])
dbt_args = ["show", "--inline", models__second_ephemeral_model, *args]
results = run_dbt(dbt_args)
assert len(results.results[0].agate_table) == expected
# ensure limit was injected in compiled_code when limit specified in command args
limit = results.args.get("limit")
if limit > 0:
assert f"limit {limit}" in results.results[0].node.compiled_code


class BaseShowSqlHeader:
@pytest.fixture(scope="class")
def models(self):
return {
"sql_header.sql": models__sql_header,
}

def test_sql_header(self, project):
run_dbt(["build", "--vars", "timezone: Asia/Kolkata"])
run_dbt(["show", "--select", "sql_header", "--vars", "timezone: Asia/Kolkata"])


class TestPostgresShowSqlHeader(BaseShowSqlHeader):
pass


class TestPostgresShowLimit(BaseShowLimit):
pass
2 changes: 1 addition & 1 deletion tests/functional/show/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

models__sql_header = """
{% call set_sql_header(config) %}
set session time zone 'Asia/Kolkata';
set session time zone '{{ var("timezone", "Europe/Paris") }}';
{%- endcall %}
select current_setting('timezone') as timezone
"""
Expand Down
25 changes: 0 additions & 25 deletions tests/functional/show/test_show.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
models__second_model,
models__ephemeral_model,
schema_yml,
models__sql_header,
private_model_yml,
)

Expand All @@ -25,7 +24,6 @@ def models(self):
"sample_number_model_with_nulls.sql": models__sample_number_model_with_nulls,
"second_model.sql": models__second_model,
"ephemeral_model.sql": models__ephemeral_model,
"sql_header.sql": models__sql_header,
}

@pytest.fixture(scope="class")
Expand Down Expand Up @@ -147,35 +145,12 @@ def test_second_ephemeral_model(self, project):
assert "col_hundo" in log_output


class TestShowLimit(ShowBase):
@pytest.mark.parametrize(
"args,expected",
[
([], 5), # default limit
(["--limit", 3], 3), # fetch 3 rows
(["--limit", -1], 7), # fetch all rows
],
)
def test_limit(self, project, args, expected):
run_dbt(["build"])
dbt_args = ["show", "--inline", models__second_ephemeral_model, *args]
results = run_dbt(dbt_args)
assert len(results.results[0].agate_table) == expected


class TestShowSeed(ShowBase):
def test_seed(self, project):
(_, log_output) = run_dbt_and_capture(["show", "--select", "sample_seed"])
assert "Previewing node 'sample_seed'" in log_output


class TestShowSqlHeader(ShowBase):
def test_sql_header(self, project):
run_dbt(["build"])
(_, log_output) = run_dbt_and_capture(["show", "--select", "sql_header"])
assert "Asia/Kolkata" in log_output


class TestShowModelVersions:
@pytest.fixture(scope="class")
def models(self):
Expand Down

0 comments on commit a2d4424

Please sign in to comment.