Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
One hot encoder enhancement: Conoromara/one hot encoder (#4)
Significant changes to the OneHotEncoder macro, hence new major version

* Output columns now follow Gitlab SQL naming conventions
* Fix case where category values contain whitespace
* Provide flexibility with excluding source table columns
* Support scikit-learn's handle_unknown strategy of 'error', this is now the default
* Update doco, bump version

Also added a couple of Redshift fixes to get all tests to pass.

Co-authored-by: James Weakley <jameswillisweakley@gmail.com>
1 parent 32bc366, commit d267f23
Showing 12 changed files with 124 additions and 76 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
name: 'dbt_ml_preprocessing' | ||
version: '0.7.0' | ||
version: '1.0.0' | ||
|
||
require-dbt-version: ">=0.15.1" | ||
|
||
|
2 changes: 1 addition & 1 deletion
integration_tests/data/sql/data_one_hot_encoder_category_selected_expected.csv
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,56 +1,102 @@ | ||
{% macro one_hot_encoder(source_table,source_column,categories='auto',handle_unknown='ignore',include_columns='*') %} | ||
{%- if categories=='auto' -%} | ||
{% set category_values_query %} | ||
select distinct {{ source_column }} from {{ source_table }} | ||
order by 1 | ||
{% endset %} | ||
{% set results = run_query(category_values_query) %} | ||
{% if execute %} | ||
{# Return the first column #} | ||
{% set category_values = results.columns[0].values() %} | ||
{% else %} | ||
{% set category_values = [] %} | ||
{% endif %} | ||
{% elif categories is not iterable or categories is string or categories is mapping %} | ||
{% set error_message %} | ||
The `categories` parameter must contain a list of category values. | ||
{% endset %} | ||
{% macro one_hot_encoder(source_table, source_column, categories='auto', handle_unknown='error',include_columns='*', exclude_columns=none) %} | ||
|
||
{%- if categories=='auto' -%} | ||
{% set category_values_query %} | ||
select distinct | ||
{{ source_column }} | ||
from | ||
{{ source_table }} | ||
order by 1 | ||
{% endset %} | ||
{% set results = run_query(category_values_query) %} | ||
{% if execute %} | ||
{# Return the first column #} | ||
{% set category_values = results.columns[0].values() %} | ||
{% else %} | ||
{% set category_values = [] %} | ||
{% endif %} | ||
{% elif categories is not iterable or categories is string or categories is mapping %} | ||
{% set error_message %} | ||
The `categories` parameter must contain a list of category values. | ||
{% endset %} | ||
{%- do exceptions.raise_compiler_error(error_message) -%} | ||
{%- else -%} | ||
{% set category_values = categories %} | ||
{%- endif -%} | ||
|
||
{%- if handle_unknown!='ignore' and handle_unknown!='error' -%} | ||
{% set error_message %} | ||
The 'handle_unknown' parameter requires a value of either 'ignore' (when unknown value occurs, all output columns are false) or 'error' (when unknown value occurs, raise an error). | ||
{% endset %} | ||
{%- do exceptions.raise_compiler_error(error_message) -%} | ||
{%- endif -%} | ||
|
||
{%- if include_columns!='*' and exclude_columns is not none -%} | ||
{% set error_message %} | ||
If the 'exclude_columns' parameter is set, providing 'include_columns' is invalid and must be left at its default value. | ||
{% endset %} | ||
{%- do exceptions.raise_compiler_error(error_message) -%} | ||
{%- endif -%} | ||
|
||
{%- if exclude_columns is not none and (exclude_columns is not iterable or exclude_columns is string or exclude_columns is mapping) -%} | ||
{% set error_message %} | ||
The 'exclude_columns' parameter value contain a list of column names. | ||
{% endset %} | ||
{%- do exceptions.raise_compiler_error(error_message) -%} | ||
{%- else -%} | ||
{% set category_values = categories %} | ||
{%- endif -%} | ||
{%- if handle_unknown!='ignore' -%} | ||
{% set error_message %} | ||
The `one_hot_encoder` macro only supports an 'handle_unknown' value of 'ignore' at this time. | ||
{% endset %} | ||
{%- endif -%} | ||
|
||
{%- if include_columns!='*' and (include_columns is not iterable or include_columns is string or include_columns is mapping) -%} | ||
{% set error_message %} | ||
The 'include_columns' parameter value must contain either the string '*' (for all columns in source), or a list of column names. | ||
{% endset %} | ||
{%- do exceptions.raise_compiler_error(error_message) -%} | ||
{%- endif -%} | ||
{{ adapter.dispatch('one_hot_encoder',packages=['dbt_ml_preprocessing'])(source_table,source_column,category_values,handle_unknown,include_columns) }} | ||
{%- endmacro %} | ||
{%- endif -%} | ||
|
||
{% macro snowflake__one_hot_encoder(source_table,source_column,category_values,handle_unknown,include_columns) %} | ||
select | ||
{% for column in include_columns %} | ||
{{ source_table }}.{{ column }}, | ||
{% endfor %} | ||
{% for category in category_values %} | ||
iff({{source_column}}='{{category}}',true,false) as {{source_column}}_{{category}} | ||
{% if not loop.last %}, {% endif %} | ||
{% endfor %} | ||
from {{ source_table }} | ||
{{ adapter.dispatch('one_hot_encoder',packages=['dbt_ml_preprocessing'])(source_table, source_column, category_values, handle_unknown, include_columns, exclude_columns) }} | ||
{%- endmacro %} | ||
|
||
{% macro default__one_hot_encoder(source_table,source_column,category_values,handle_unknown,include_columns) %} | ||
select | ||
{% for column in include_columns %} | ||
{{ column }}, | ||
{% endfor %} | ||
{% for category in category_values %} | ||
case {{source_column}} | ||
when '{{category}}' then true | ||
else false | ||
end as {{source_column}}_{{category}} | ||
{% if not loop.last %}, {% endif %} | ||
{% endfor %} | ||
from {{ source_table }} | ||
{% macro default__one_hot_encoder(source_table, source_column, category_values, handle_unknown, include_columns, exclude_columns) %} | ||
{% set columns = adapter.get_columns_in_relation( source_table ) %} | ||
|
||
|
||
|
||
|
||
with binary_output as ( | ||
select | ||
{%- if include_columns=='*' and exclude_columns is none -%} | ||
{% for column in columns %} | ||
{{ column.name }}, | ||
{%- endfor -%} | ||
{%- elif include_columns !='*'-%} | ||
{% for column in include_columns %} | ||
{{ source_table }}.{{ column }}, | ||
{%- endfor -%} | ||
{%- else -%} | ||
{% for column in columns %} | ||
{%- if column.name | lower not in exclude_columns | lower %} | ||
{{ column.name }}, | ||
{%- endif -%} | ||
{%- endfor -%} | ||
{%- endif -%} | ||
{% for category in category_values %} | ||
{% set no_whitespace_column_name = category | replace( " ", "_") -%} | ||
{%- if handle_unknown=='ignore' %} | ||
case | ||
when {{ source_column }} = '{{ category }}' then true | ||
else false | ||
end as is_{{ source_column }}_{{ no_whitespace_column_name }} | ||
{% endif %} | ||
{%- if handle_unknown=='error' %} | ||
case | ||
when {{ source_column }} = '{{ category }}' then true | ||
when {{ source_column }} in ('{{ category_values | join("','") }}') then false | ||
else cast('Error: unknown value found and handle_unknown parameter was "error"' as boolean) | ||
end as is_{{ source_column }}_{{ no_whitespace_column_name }} | ||
{% endif %} | ||
{%- if not loop.last %},{% endif -%} | ||
{% endfor %} | ||
from {{ source_table }} | ||
) | ||
|
||
select * from binary_output | ||
{%- endmacro %} |
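
The new `default__one_hot_encoder` body does two things: it chooses which source columns to pass through, and it emits one `case` expression per category. The Python sketch below mirrors that logic outside of dbt so the generated SQL is easier to inspect. It is not part of the commit; the function names `select_columns` and `one_hot_case_expressions` are hypothetical, and the exclusion check assumes exact, case-insensitive name matching. In `'error'` mode the template appears to rely on the warehouse rejecting the cast of a message string to boolean when an unknown value is encountered.

```python
def select_columns(all_columns, include_columns="*", exclude_columns=None):
    """Pick pass-through columns the way the three template branches do."""
    if include_columns != "*" and exclude_columns is not None:
        raise ValueError("include_columns must stay '*' when exclude_columns is set")
    if include_columns != "*":
        return list(include_columns)
    if exclude_columns is None:
        return list(all_columns)
    # Assumption: exclusion compares whole names, ignoring case.
    excluded = {name.lower() for name in exclude_columns}
    return [name for name in all_columns if name.lower() not in excluded]


def one_hot_case_expressions(source_column, category_values, handle_unknown="error"):
    """Build one SQL case expression per category value."""
    if handle_unknown not in ("ignore", "error"):
        raise ValueError("handle_unknown must be 'ignore' or 'error'")
    known = "','".join(category_values)
    expressions = []
    for category in category_values:
        # Spaces in the category are replaced so the alias is a valid identifier.
        alias = f"is_{source_column}_{category.replace(' ', '_')}"
        lines = ["case", f"  when {source_column} = '{category}' then true"]
        if handle_unknown == "error":
            # Known-but-different categories are false; anything else hits the cast.
            lines.append(f"  when {source_column} in ('{known}') then false")
            lines.append(
                "  else cast('Error: unknown value found and "
                "handle_unknown parameter was \"error\"' as boolean)"
            )
        else:
            lines.append("  else false")
        lines.append(f"end as {alias}")
        expressions.append("\n".join(lines))
    return expressions


if __name__ == "__main__":
    print(select_columns(["ID", "City", "Age"], exclude_columns=["city"]))
    for expression in one_hot_case_expressions("city", ["New York", "Paris"]):
        print(expression)
```

Running the module prints the retained columns followed by the two `case` expressions, which makes the `is_` prefix and the whitespace replacement in the aliases visible.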