Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(Database): Add support for parsing booleans #508

Merged
merged 7 commits into from
Dec 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions parsons/databases/database/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@
VARCHAR = "varchar"
FLOAT = "float"

# Boolean detection is opt-in: while this flag is falsy, is_sql_bool() returns
# early without matching any value.
DO_PARSE_BOOLS = False
# Type name returned by detect_data_type() when a value is detected as a boolean.
BOOL = "bool"
# Values recognised as booleans. is_sql_bool() compares str(val).upper() against
# these tuples, so string matching is case-insensitive ("Yes" matches "YES").
# NOTE(review): because the compared value is always a str, the bare int
# entries (1 and 0) are only ever matched through their "1"/"0" string forms.
TRUE_VALS = ("TRUE", "T", "YES", "Y", "1", 1)
FALSE_VALS = ("FALSE", "F", "NO", "N", "0", 0)

# The following values are the minimum and maximum values for MySQL int
# types. https://dev.mysql.com/doc/refman/8.0/en/integer-types.html
SMALLINT = "smallint"
Expand Down
44 changes: 41 additions & 3 deletions parsons/databases/database/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ def __init__(self):
self.INT_MAX = consts.INT_MAX
self.BIGINT = consts.BIGINT
self.FLOAT = consts.FLOAT

# Added for backwards compatibility
self.DO_PARSE_BOOLS = consts.DO_PARSE_BOOLS
self.BOOL = consts.BOOL
self.TRUE_VALS = consts.TRUE_VALS
self.FALSE_VALS = consts.FALSE_VALS

self.VARCHAR = consts.VARCHAR
self.RESERVED_WORDS = consts.RESERVED_WORDS
self.COL_NAME_MAX_LEN = consts.COL_NAME_MAX_LEN
Expand Down Expand Up @@ -97,7 +104,8 @@ def is_valid_sql_num(self, val):
# then it's a valid sql number
# Also check the first character is not zero
try:
if (float(val) or int(val)) and "_" not in val and val[0] != "0":
if ((float(val) or 1) and "_" not in val and
(val in ("0", "0.0") or val[0] != "0")):
return True
else:
return False
Expand All @@ -107,6 +115,28 @@ def is_valid_sql_num(self, val):
except (TypeError, ValueError):
return False

def is_sql_bool(self, val):
    """Check whether val is a valid sql boolean.

    When inserting data into databases, different values can be accepted
    as boolean types. For example, ``False``, ``'FALSE'``, ``1``.

    Detection is opt-in: unless ``self.DO_PARSE_BOOLS`` is truthy this
    method always returns ``False``.

    `Args`:
        val: any
            The value to check.
    `Returns`:
        bool
            Whether or not the value is a valid sql boolean.
    """
    if not self.DO_PARSE_BOOLS:
        # Return an explicit False rather than a bare ``return`` (None) so
        # the documented ``bool`` return type always holds.
        return False

    # Real Python booleans always qualify.
    if isinstance(val, bool):
        return True

    # Only plain ints and strs are considered; the value is upper-cased so
    # that string matching against the configured values is case-insensitive.
    return (
        type(val) in (int, str)
        and str(val).upper() in self.TRUE_VALS + self.FALSE_VALS
    )

def detect_data_type(self, value, cmp_type=None):
Copy link
Contributor

@ChrisC ChrisC Jun 29, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be easier to toggle the Boolean parsing on or off (especially in the test suites) if we pass it as an optional argument in this function, instead of the config setting on the base class. Something like:
def detect_data_type(self, value, cmp_type=None, do_parse_bools=False):

"""Detect the higher of value's type cmp_type.
Expand Down Expand Up @@ -139,6 +169,8 @@ def detect_data_type(self, value, cmp_type=None):
try:
val_lit = ast.literal_eval(str(value))
except (SyntaxError, ValueError):
if self.is_sql_bool(value):
return self.BOOL
return self.VARCHAR

# Exit early if it's None
Expand All @@ -152,7 +184,13 @@ def detect_data_type(self, value, cmp_type=None):
# Python accepts 100_000 as a valid form of 100000,
# however a sql engine may throw an error
if not self.is_valid_sql_num(value):
return self.VARCHAR
if self.is_sql_bool(val_lit) and cmp_type != self.VARCHAR:
return self.BOOL
else:
return self.VARCHAR

if self.is_sql_bool(val_lit) and cmp_type not in self.INT_TYPES + [self.FLOAT]:
return self.BOOL

type_lit = type(val_lit)

Expand All @@ -164,7 +202,7 @@ def detect_data_type(self, value, cmp_type=None):
# The value is very likely an int
# let's get its size
# If the compare types are empty and use the types of the current value
if type_lit == int and cmp_type in (self.INT_TYPES + [None, ""]):
if type_lit == int and cmp_type in (self.INT_TYPES + [None, "", self.BOOL]):

# Use smallest possible int type above TINYINT
if (self.SMALLINT_MIN < val_lit < self.SMALLINT_MAX):
Expand Down
3 changes: 2 additions & 1 deletion parsons/databases/redshift/rs_create_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ def __init__(self):
self.COL_NAME_MAX_LEN = consts.COL_NAME_MAX_LEN
self.REPLACE_CHARS = consts.REPLACE_CHARS

# Redshift doesn't have a medium int
# Currently smallints are coded as ints
self.SMALLINT = self.INT
# Redshift doesn't have a medium int
self.MEDIUMINT = self.INT

# Currently py floats are coded as Redshift decimals
Expand Down
36 changes: 34 additions & 2 deletions test/test_databases/test_database.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from parsons.databases.database.constants import (
SMALLINT, MEDIUMINT, INT, BIGINT, FLOAT, VARCHAR)
SMALLINT, MEDIUMINT, INT, BIGINT, FLOAT, BOOL, VARCHAR)

from parsons.databases.database.database import DatabaseCreateStatement

Expand All @@ -8,7 +8,15 @@

@pytest.fixture
def dcs():
return DatabaseCreateStatement()
db = DatabaseCreateStatement()
return db


@pytest.fixture
def dcs_bool():
db = DatabaseCreateStatement()
db.DO_PARSE_BOOLS = True
ChrisC marked this conversation as resolved.
Show resolved Hide resolved
return db


@pytest.mark.parametrize(
Expand Down Expand Up @@ -49,6 +57,8 @@ def test_get_bigger_int(dcs, int1, int2, higher):
(+1.2, True),
(+1., True),
(+1.0_0, True),
(0, True),
(0.0, True),
("10", True),
("1_0", False),
("+10", True),
Expand All @@ -60,6 +70,8 @@ def test_get_bigger_int(dcs, int1, int2, higher):
("+1.2", True),
("+1.", True),
("+1.0_0", False),
("0", True),
("0.0", True),
(True, False),
("True", False),
("a string", False),
Expand All @@ -83,6 +95,7 @@ def test_is_valid_sql_num(dcs, val, is_valid):
(2147483648, FLOAT, FLOAT),
(5.001, None, FLOAT),
(5.001, "", FLOAT),
("FALSE", VARCHAR, VARCHAR),
("word", "", VARCHAR),
("word", INT, VARCHAR),
("1_2", BIGINT, VARCHAR),
Expand All @@ -91,11 +104,30 @@ def test_is_valid_sql_num(dcs, val, is_valid):
("word", None, VARCHAR),
("1_2", None, VARCHAR),
("01", None, VARCHAR),
("{}", None, VARCHAR),
))
def test_detect_data_type(dcs, val, cmp_type, detected_type):
assert dcs.detect_data_type(val, cmp_type) == detected_type


@pytest.mark.parametrize(
    ("val", "cmp_type", "detected_type"),
    [
        # Integers compared against no type, an int type, or BOOL.
        (2, None, SMALLINT),
        (2, "", SMALLINT),
        (1, MEDIUMINT, MEDIUMINT),
        (2, BOOL, SMALLINT),
        # Values expected to be detected as booleans.
        (True, None, BOOL),
        (0, None, BOOL),
        (1, None, BOOL),
        (1, BOOL, BOOL),
        ("F", None, BOOL),
        ("FALSE", None, BOOL),
        ("Yes", None, BOOL),
    ],
)
def test_detect_data_type_bools(dcs_bool, val, cmp_type, detected_type):
    """Check detect_data_type() results when boolean parsing is enabled."""
    result = dcs_bool.detect_data_type(val, cmp_type)
    assert result == detected_type


@pytest.mark.parametrize(
("col", "renamed"),
(("a", "a"),
Expand Down
15 changes: 12 additions & 3 deletions test/test_databases/test_mysql.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from parsons.databases.mysql.mysql import MySQL
from parsons.databases.table import BaseTable
from parsons.databases.mysql.create_table import MySQLCreateTable
from parsons.etl.table import Table
from test.utils import assert_matching_tables
import unittest
Expand Down Expand Up @@ -64,7 +64,7 @@ def setUp(self):
('you', 'hey', '3')
""")

self.tbl = BaseTable(self.mysql, 'test')
self.tbl = MySQLCreateTable(self.mysql, 'test')

def tearDown(self):

Expand All @@ -91,7 +91,7 @@ def test_exists(self):

self.assertTrue(self.tbl.exists)

tbl_bad = BaseTable(self.mysql, 'bad_test')
tbl_bad = MySQLCreateTable(self.mysql, 'bad_test')
self.assertFalse(tbl_bad.exists)

def test_drop(self):
Expand Down Expand Up @@ -145,14 +145,23 @@ def setUp(self):

def test_data_type(self):

# Test bool
self.mysql.DO_PARSE_BOOLS = True
self.assertEqual(self.mysql.data_type(1, ''), 'bool')
self.assertEqual(self.mysql.data_type(False, ''), 'bool')

self.mysql.DO_PARSE_BOOLS = False
# Test smallint
self.assertEqual(self.mysql.data_type(1, ''), 'smallint')
self.assertEqual(self.mysql.data_type(2, ''), 'smallint')
# Test int
self.assertEqual(self.mysql.data_type(32769, ''), 'mediumint')
# Test bigint
self.assertEqual(self.mysql.data_type(2147483648, ''), 'bigint')
# Test varchar that looks like an int
self.assertEqual(self.mysql.data_type('00001', ''), 'varchar')
# Test varchar that looks like a bool
self.assertEqual(self.mysql.data_type(False, ''), 'varchar')
# Test a float as a decimal
self.assertEqual(self.mysql.data_type(5.001, ''), 'float')
# Test varchar
Expand Down
15 changes: 14 additions & 1 deletion test/test_databases/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def setUp(self):

self.mapping = self.pg.generate_data_types(self.tbl)
self.mapping2 = self.pg.generate_data_types(self.tbl2)
self.pg.DO_PARSE_BOOLS = True
self.mapping3 = self.pg.generate_data_types(self.tbl2)

def test_connection(self):

Expand All @@ -55,15 +57,18 @@ def test_connection(self):
self.assertEqual(pg_env.port, 5432)

def test_data_type(self):

self.pg.DO_PARSE_BOOLS = False
# Test smallint
self.assertEqual(self.pg.data_type(1, ''), 'smallint')
self.assertEqual(self.pg.data_type(2, ''), 'smallint')
# Test int
self.assertEqual(self.pg.data_type(32769, ''), 'int')
# Test bigint
self.assertEqual(self.pg.data_type(2147483648, ''), 'bigint')
# Test varchar that looks like an int
self.assertEqual(self.pg.data_type('00001', ''), 'varchar')
# Test varchar that looks like a bool
self.assertEqual(self.pg.data_type(True, ''), 'varchar')
# Test a float as a decimal
self.assertEqual(self.pg.data_type(5.001, ''), 'decimal')
# Test varchar
Expand All @@ -73,6 +78,11 @@ def test_data_type(self):
# Test int with leading zero
self.assertEqual(self.pg.data_type('01', ''), 'varchar')

# Test bool
self.pg.DO_PARSE_BOOLS = True
self.assertEqual(self.pg.data_type(1, ''), 'bool')
self.assertEqual(self.pg.data_type(True, ''), 'bool')

def test_generate_data_types(self):

# Test correct header labels
Expand All @@ -82,6 +92,9 @@ def test_generate_data_types(self):
self.assertEqual(
self.mapping2['type_list'],
['varchar', 'varchar', 'decimal', 'varchar', "decimal", "smallint", "varchar"])
self.assertEqual(
self.mapping3['type_list'],
['varchar', 'varchar', 'decimal', 'varchar', "decimal", "bool", "varchar"])
# Test correct lengths
self.assertEqual(self.mapping['longest'], [1, 5])

Expand Down
24 changes: 20 additions & 4 deletions test/test_redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ def setUp(self):
])

self.mapping = self.rs.generate_data_types(self.tbl)
self.rs.DO_PARSE_BOOLS = True
self.mapping2 = self.rs.generate_data_types(self.tbl2)
self.rs.DO_PARSE_BOOLS = False
self.mapping3 = self.rs.generate_data_types(self.tbl2)

def test_split_full_table_name(self):
schema, table = Redshift.split_full_table_name('some_schema.some_table')
Expand All @@ -58,8 +61,16 @@ def test_combine_schema_and_table_name(self):

def test_data_type(self):

# Test int
# Test bool
self.rs.DO_PARSE_BOOLS = True
self.assertEqual(self.rs.data_type(1, ''), 'bool')
self.assertEqual(self.rs.data_type(True, ''), 'bool')
self.rs.DO_PARSE_BOOLS = False
self.assertEqual(self.rs.data_type(1, ''), 'int')
self.assertEqual(self.rs.data_type(True, ''), 'varchar')
# Test smallint
# Currently smallints are coded as ints
self.assertEqual(self.rs.data_type(2, ''), 'int')
# Test int
self.assertEqual(self.rs.data_type(32769, ''), 'int')
# Test bigint
Expand All @@ -84,7 +95,11 @@ def test_generate_data_types(self):

self.assertEqual(
self.mapping2['type_list'],
['varchar', 'varchar', 'float', 'varchar', "float", "int", "varchar"])
['varchar', 'varchar', 'float', 'varchar', 'float', 'bool', 'varchar'])

self.assertEqual(
self.mapping3['type_list'],
['varchar', 'varchar', 'float', 'varchar', 'float', 'int', 'varchar'])
# Test correct lengths
self.assertEqual(self.mapping['longest'], [1, 5])

Expand Down Expand Up @@ -132,8 +147,9 @@ def test_column_validate(self):
bad_cols = ['a', 'a', '', 'SELECT', 'asdfjkasjdfklasjdfklajskdfljaskldfjaklsdfjlaksdfjklasj'
'dfklasjdkfljaskldfljkasjdkfasjlkdfjklasdfjklakjsfasjkdfljaslkdfjklasdfjklasjkl'
'dfakljsdfjalsdkfjklasjdfklasjdfklasdkljf']
fixed_cols = ['a', 'a_1', 'col_2', 'col_3', 'asdfjkasjdfklasjdfklajskdfljaskldfjaklsdfjlaks'
'dfjklasjdfklasjdkfljaskldfljkasjdkfasjlkdfjklasdfjklakjsfasjkdfljaslkdfjkl']
fixed_cols = [
'a', 'a_1', 'col_2', 'col_3', 'asdfjkasjdfklasjdfklajskdfljaskldfjaklsdfjlaks'
'dfjklasjdfklasjdkfljaskldfljkasjdkfasjlkdfjklasdfjklakjsfasjkdfljaslkdfjkl']
self.assertEqual(self.rs.column_name_validate(bad_cols), fixed_cols)

def test_create_statement(self):
Expand Down