Skip to content

Commit

Permalink
feat(Database): Add support for parsing booleans (move-coop#508)
Browse files Browse the repository at this point in the history
* feat(Database): Add support for parsing booleans

* Add attribute to set/unset bool parsing

* Update to get DO_PARSE_BOOLS from constants file

* Fix lint errors

* fix merge from main

* explicitly set boolean parsing in DB tests

Co-authored-by: Chris Cuellar <58723+ChrisC@users.noreply.github.com>
  • Loading branch information
dannyboy15 and ChrisC authored Dec 16, 2021
1 parent e5ff106 commit 766cfae
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 14 deletions.
5 changes: 5 additions & 0 deletions parsons/databases/database/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@
VARCHAR = "varchar"
FLOAT = "float"

# Boolean detection is opt-in: is_sql_bool() reports nothing as a boolean
# unless this flag (copied onto the create-statement instance) is truthy.
DO_PARSE_BOOLS = False
# Type name returned by detect_data_type() for boolean values.
BOOL = "bool"
# Values treated as booleans. is_sql_bool() upper-cases str(val) before the
# membership test, so the string entries are matched case-insensitively.
TRUE_VALS = ("TRUE", "T", "YES", "Y", "1", 1)
FALSE_VALS = ("FALSE", "F", "NO", "N", "0", 0)

# The following values are the minimum and maximum values for MySQL int
# types. https://dev.mysql.com/doc/refman/8.0/en/integer-types.html
SMALLINT = "smallint"
Expand Down
44 changes: 41 additions & 3 deletions parsons/databases/database/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ def __init__(self):
self.INT_MAX = consts.INT_MAX
self.BIGINT = consts.BIGINT
self.FLOAT = consts.FLOAT

# Added for backwards compatibility
self.DO_PARSE_BOOLS = consts.DO_PARSE_BOOLS
self.BOOL = consts.BOOL
self.TRUE_VALS = consts.TRUE_VALS
self.FALSE_VALS = consts.FALSE_VALS

self.VARCHAR = consts.VARCHAR
self.RESERVED_WORDS = consts.RESERVED_WORDS
self.COL_NAME_MAX_LEN = consts.COL_NAME_MAX_LEN
Expand Down Expand Up @@ -97,7 +104,8 @@ def is_valid_sql_num(self, val):
# then it's a valid sql number
# Also check the first character is not zero
try:
if (float(val) or int(val)) and "_" not in val and val[0] != "0":
if ((float(val) or 1) and "_" not in val and
(val in ("0", "0.0") or val[0] != "0")):
return True
else:
return False
Expand All @@ -107,6 +115,28 @@ def is_valid_sql_num(self, val):
except (TypeError, ValueError):
return False

def is_sql_bool(self, val):
    """Check whether val is a valid sql boolean.

    When inserting data into databases, different values can be accepted
    as boolean types. For example, ``False``, ``'FALSE'``, ``1``.

    Detection is opt-in: when ``self.DO_PARSE_BOOLS`` is falsy, no value
    is reported as a boolean and ``False`` is returned.

    `Args`:
        val: any
            The value to check.
    `Returns`:
        bool
            Whether or not the value is a valid sql boolean.
    """
    if not self.DO_PARSE_BOOLS:
        # Always hand back a real bool, as documented, rather than None.
        return False

    if isinstance(val, bool):
        return True

    # Exact type check kept on purpose: only plain ints and strs are
    # compared against the accepted boolean values.
    if type(val) not in (int, str):
        return False

    # Compare case-insensitively, e.g. "Yes" and "f" are accepted.
    return str(val).upper() in self.TRUE_VALS + self.FALSE_VALS

def detect_data_type(self, value, cmp_type=None):
"""Detect the higher of value's type cmp_type.
Expand Down Expand Up @@ -139,6 +169,8 @@ def detect_data_type(self, value, cmp_type=None):
try:
val_lit = ast.literal_eval(str(value))
except (SyntaxError, ValueError):
if self.is_sql_bool(value):
return self.BOOL
return self.VARCHAR

# Exit early if it's None
Expand All @@ -152,7 +184,13 @@ def detect_data_type(self, value, cmp_type=None):
# Python accepts 100_000 as a valid form of 100000,
# however a sql engine may throw an error
if not self.is_valid_sql_num(value):
return self.VARCHAR
if self.is_sql_bool(val_lit) and cmp_type != self.VARCHAR:
return self.BOOL
else:
return self.VARCHAR

if self.is_sql_bool(val_lit) and cmp_type not in self.INT_TYPES + [self.FLOAT]:
return self.BOOL

type_lit = type(val_lit)

Expand All @@ -164,7 +202,7 @@ def detect_data_type(self, value, cmp_type=None):
# The value is very likely an int
# let's get its size
# If the compare types are empty and use the types of the current value
if type_lit == int and cmp_type in (self.INT_TYPES + [None, ""]):
if type_lit == int and cmp_type in (self.INT_TYPES + [None, "", self.BOOL]):

# Use smallest possible int type above TINYINT
if (self.SMALLINT_MIN < val_lit < self.SMALLINT_MAX):
Expand Down
3 changes: 2 additions & 1 deletion parsons/databases/redshift/rs_create_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ def __init__(self):
self.COL_NAME_MAX_LEN = consts.COL_NAME_MAX_LEN
self.REPLACE_CHARS = consts.REPLACE_CHARS

# Redshift doesn't have a medium int
# Currently smallints are coded as ints
self.SMALLINT = self.INT
# Redshift doesn't have a medium int
self.MEDIUMINT = self.INT

# Currently py floats are coded as Redshift decimals
Expand Down
36 changes: 34 additions & 2 deletions test/test_databases/test_database.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from parsons.databases.database.constants import (
SMALLINT, MEDIUMINT, INT, BIGINT, FLOAT, VARCHAR)
SMALLINT, MEDIUMINT, INT, BIGINT, FLOAT, BOOL, VARCHAR)

from parsons.databases.database.database import DatabaseCreateStatement

Expand All @@ -8,7 +8,15 @@

@pytest.fixture
def dcs():
return DatabaseCreateStatement()
db = DatabaseCreateStatement()
return db


@pytest.fixture
def dcs_bool():
    """Provide a ``DatabaseCreateStatement`` with boolean parsing enabled."""
    statement = DatabaseCreateStatement()
    # Boolean parsing is off by default, so switch it on for these tests.
    statement.DO_PARSE_BOOLS = True
    return statement


@pytest.mark.parametrize(
Expand Down Expand Up @@ -49,6 +57,8 @@ def test_get_bigger_int(dcs, int1, int2, higher):
(+1.2, True),
(+1., True),
(+1.0_0, True),
(0, True),
(0.0, True),
("10", True),
("1_0", False),
("+10", True),
Expand All @@ -60,6 +70,8 @@ def test_get_bigger_int(dcs, int1, int2, higher):
("+1.2", True),
("+1.", True),
("+1.0_0", False),
("0", True),
("0.0", True),
(True, False),
("True", False),
("a string", False),
Expand All @@ -83,6 +95,7 @@ def test_is_valid_sql_num(dcs, val, is_valid):
(2147483648, FLOAT, FLOAT),
(5.001, None, FLOAT),
(5.001, "", FLOAT),
("FALSE", VARCHAR, VARCHAR),
("word", "", VARCHAR),
("word", INT, VARCHAR),
("1_2", BIGINT, VARCHAR),
Expand All @@ -91,11 +104,30 @@ def test_is_valid_sql_num(dcs, val, is_valid):
("word", None, VARCHAR),
("1_2", None, VARCHAR),
("01", None, VARCHAR),
("{}", None, VARCHAR),
))
def test_detect_data_type(dcs, val, cmp_type, detected_type):
assert dcs.detect_data_type(val, cmp_type) == detected_type


@pytest.mark.parametrize(
    ("val", "cmp_type", "detected_type"),
    [
        (2, None, SMALLINT),
        (2, "", SMALLINT),
        (1, MEDIUMINT, MEDIUMINT),
        (2, BOOL, SMALLINT),
        (True, None, BOOL),
        (0, None, BOOL),
        (1, None, BOOL),
        (1, BOOL, BOOL),
        ("F", None, BOOL),
        ("FALSE", None, BOOL),
        ("Yes", None, BOOL),
    ],
)
def test_detect_data_type_bools(dcs_bool, val, cmp_type, detected_type):
    """Check type detection when boolean parsing is enabled on the fixture."""
    observed = dcs_bool.detect_data_type(val, cmp_type)
    assert observed == detected_type


@pytest.mark.parametrize(
("col", "renamed"),
(("a", "a"),
Expand Down
15 changes: 12 additions & 3 deletions test/test_databases/test_mysql.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from parsons.databases.mysql.mysql import MySQL
from parsons.databases.table import BaseTable
from parsons.databases.mysql.create_table import MySQLCreateTable
from parsons.etl.table import Table
from test.utils import assert_matching_tables
import unittest
Expand Down Expand Up @@ -64,7 +64,7 @@ def setUp(self):
('you', 'hey', '3')
""")

self.tbl = BaseTable(self.mysql, 'test')
self.tbl = MySQLCreateTable(self.mysql, 'test')

def tearDown(self):

Expand All @@ -91,7 +91,7 @@ def test_exists(self):

self.assertTrue(self.tbl.exists)

tbl_bad = BaseTable(self.mysql, 'bad_test')
tbl_bad = MySQLCreateTable(self.mysql, 'bad_test')
self.assertFalse(tbl_bad.exists)

def test_drop(self):
Expand Down Expand Up @@ -145,14 +145,23 @@ def setUp(self):

def test_data_type(self):

# Test bool
self.mysql.DO_PARSE_BOOLS = True
self.assertEqual(self.mysql.data_type(1, ''), 'bool')
self.assertEqual(self.mysql.data_type(False, ''), 'bool')

self.mysql.DO_PARSE_BOOLS = False
# Test smallint
self.assertEqual(self.mysql.data_type(1, ''), 'smallint')
self.assertEqual(self.mysql.data_type(2, ''), 'smallint')
# Test int
self.assertEqual(self.mysql.data_type(32769, ''), 'mediumint')
# Test bigint
self.assertEqual(self.mysql.data_type(2147483648, ''), 'bigint')
# Test varchar that looks like an int
self.assertEqual(self.mysql.data_type('00001', ''), 'varchar')
# Test varchar that looks like a bool
self.assertEqual(self.mysql.data_type(False, ''), 'varchar')
# Test a float as a decimal
self.assertEqual(self.mysql.data_type(5.001, ''), 'float')
# Test varchar
Expand Down
15 changes: 14 additions & 1 deletion test/test_databases/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def setUp(self):

self.mapping = self.pg.generate_data_types(self.tbl)
self.mapping2 = self.pg.generate_data_types(self.tbl2)
self.pg.DO_PARSE_BOOLS = True
self.mapping3 = self.pg.generate_data_types(self.tbl2)

def test_connection(self):

Expand All @@ -55,15 +57,18 @@ def test_connection(self):
self.assertEqual(pg_env.port, 5432)

def test_data_type(self):

self.pg.DO_PARSE_BOOLS = False
# Test smallint
self.assertEqual(self.pg.data_type(1, ''), 'smallint')
self.assertEqual(self.pg.data_type(2, ''), 'smallint')
# Test int
self.assertEqual(self.pg.data_type(32769, ''), 'int')
# Test bigint
self.assertEqual(self.pg.data_type(2147483648, ''), 'bigint')
# Test varchar that looks like an int
self.assertEqual(self.pg.data_type('00001', ''), 'varchar')
# Test varchar that looks like a bool
self.assertEqual(self.pg.data_type(True, ''), 'varchar')
# Test a float as a decimal
self.assertEqual(self.pg.data_type(5.001, ''), 'decimal')
# Test varchar
Expand All @@ -73,6 +78,11 @@ def test_data_type(self):
# Test int with leading zero
self.assertEqual(self.pg.data_type('01', ''), 'varchar')

# Test bool
self.pg.DO_PARSE_BOOLS = True
self.assertEqual(self.pg.data_type(1, ''), 'bool')
self.assertEqual(self.pg.data_type(True, ''), 'bool')

def test_generate_data_types(self):

# Test correct header labels
Expand All @@ -82,6 +92,9 @@ def test_generate_data_types(self):
self.assertEqual(
self.mapping2['type_list'],
['varchar', 'varchar', 'decimal', 'varchar', "decimal", "smallint", "varchar"])
self.assertEqual(
self.mapping3['type_list'],
['varchar', 'varchar', 'decimal', 'varchar', "decimal", "bool", "varchar"])
# Test correct lengths
self.assertEqual(self.mapping['longest'], [1, 5])

Expand Down
24 changes: 20 additions & 4 deletions test/test_redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ def setUp(self):
])

self.mapping = self.rs.generate_data_types(self.tbl)
self.rs.DO_PARSE_BOOLS = True
self.mapping2 = self.rs.generate_data_types(self.tbl2)
self.rs.DO_PARSE_BOOLS = False
self.mapping3 = self.rs.generate_data_types(self.tbl2)

def test_split_full_table_name(self):
schema, table = Redshift.split_full_table_name('some_schema.some_table')
Expand All @@ -58,8 +61,16 @@ def test_combine_schema_and_table_name(self):

def test_data_type(self):

# Test int
# Test bool
self.rs.DO_PARSE_BOOLS = True
self.assertEqual(self.rs.data_type(1, ''), 'bool')
self.assertEqual(self.rs.data_type(True, ''), 'bool')
self.rs.DO_PARSE_BOOLS = False
self.assertEqual(self.rs.data_type(1, ''), 'int')
self.assertEqual(self.rs.data_type(True, ''), 'varchar')
# Test smallint
# Currently smallints are coded as ints
self.assertEqual(self.rs.data_type(2, ''), 'int')
# Test int
self.assertEqual(self.rs.data_type(32769, ''), 'int')
# Test bigint
Expand All @@ -84,7 +95,11 @@ def test_generate_data_types(self):

self.assertEqual(
self.mapping2['type_list'],
['varchar', 'varchar', 'float', 'varchar', "float", "int", "varchar"])
['varchar', 'varchar', 'float', 'varchar', 'float', 'bool', 'varchar'])

self.assertEqual(
self.mapping3['type_list'],
['varchar', 'varchar', 'float', 'varchar', 'float', 'int', 'varchar'])
# Test correct lengths
self.assertEqual(self.mapping['longest'], [1, 5])

Expand Down Expand Up @@ -132,8 +147,9 @@ def test_column_validate(self):
bad_cols = ['a', 'a', '', 'SELECT', 'asdfjkasjdfklasjdfklajskdfljaskldfjaklsdfjlaksdfjklasj'
'dfklasjdkfljaskldfljkasjdkfasjlkdfjklasdfjklakjsfasjkdfljaslkdfjklasdfjklasjkl'
'dfakljsdfjalsdkfjklasjdfklasjdfklasdkljf']
fixed_cols = ['a', 'a_1', 'col_2', 'col_3', 'asdfjkasjdfklasjdfklajskdfljaskldfjaklsdfjlaks'
'dfjklasjdfklasjdkfljaskldfljkasjdkfasjlkdfjklasdfjklakjsfasjkdfljaslkdfjkl']
fixed_cols = [
'a', 'a_1', 'col_2', 'col_3', 'asdfjkasjdfklasjdfklajskdfljaskldfjaklsdfjlaks'
'dfjklasjdfklasjdkfljaskldfljkasjdkfasjlkdfjklasdfjklakjsfasjkdfljaslkdfjkl']
self.assertEqual(self.rs.column_name_validate(bad_cols), fixed_cols)

def test_create_statement(self):
Expand Down

0 comments on commit 766cfae

Please sign in to comment.