Skip to content

Commit

Permalink
Parse Boolean types by default (move-coop#943)
Browse files Browse the repository at this point in the history
* Parse Boolean types by default

Commit 766cfae created a feature for
parsing boolean types but turned it off by default. This commit turns
that feature on by default and adds a comment about how to turn it off
and what that does.

* Fix test expectations after updating boolean parsing behavior

* Only ever interpret Python bools as SQL booleans

No longer coerce by default any of the following as booleans:
"yes", "True", "t", 1, 0, "no", "False", "f"

* Fix redshift test parsing bools

* Move redshift test into test_databases folder

* Remove retired TRUE_VALS and FALSE_VALS configuration variables

We now only use python booleans
  • Loading branch information
austinweisgrau authored Dec 8, 2023
1 parent afca41c commit 14df184
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 157 deletions.
4 changes: 0 additions & 4 deletions parsons/databases/database/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,7 @@

VARCHAR = "varchar"
FLOAT = "float"

DO_PARSE_BOOLS = False
BOOL = "bool"
TRUE_VALS = ("TRUE", "T", "YES", "Y", "1", 1)
FALSE_VALS = ("FALSE", "F", "NO", "N", "0", 0)

# The following values are the minimum and maximum values for MySQL int
# types. https://dev.mysql.com/doc/refman/8.0/en/integer-types.html
Expand Down
104 changes: 30 additions & 74 deletions parsons/databases/database/database.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import parsons.databases.database.constants as consts
import ast
import logging

logger = logging.getLogger(__name__)


class DatabaseCreateStatement:
Expand All @@ -17,11 +19,7 @@ def __init__(self):
self.BIGINT = consts.BIGINT
self.FLOAT = consts.FLOAT

# Added for backwards compatibility
self.DO_PARSE_BOOLS = consts.DO_PARSE_BOOLS
self.BOOL = consts.BOOL
self.TRUE_VALS = consts.TRUE_VALS
self.FALSE_VALS = consts.FALSE_VALS

self.VARCHAR = consts.VARCHAR
self.RESERVED_WORDS = consts.RESERVED_WORDS
Expand Down Expand Up @@ -117,29 +115,6 @@ def is_valid_sql_num(self, val):
except (TypeError, ValueError):
return False

def is_sql_bool(self, val):
"""Check whether val is a valid sql boolean.
When inserting data into databases, different values can be accepted
as boolean types. For example, ``False``, ``'FALSE'``, ``1``.
`Args`:
val: any
The value to check.
`Returns`:
bool
Whether or not the value is a valid sql boolean.
"""
if not self.DO_PARSE_BOOLS:
return

if isinstance(val, bool) or (
type(val) in (int, str)
and str(val).upper() in self.TRUE_VALS + self.FALSE_VALS
):
return True
return False

def detect_data_type(self, value, cmp_type=None):
"""Detect the higher of value's type cmp_type.
Expand All @@ -161,64 +136,45 @@ def detect_data_type(self, value, cmp_type=None):
# Stop if the compare type is already a varchar
# varchar is the highest data type.
if cmp_type == self.VARCHAR:
return cmp_type

# Attempt to evaluate value as a literal (e.g. '1' => 1). If the value
# is just a string, is None, or is empty, it will raise an error. These
# should be considered varchars.
# E.g.
# "" => SyntaxError
# "anystring" => ValueError
try:
val_lit = ast.literal_eval(str(value))
except (SyntaxError, ValueError):
if self.is_sql_bool(value):
return self.BOOL
return self.VARCHAR

# Exit early if it's None
# is_valid_sql_num(None) == False
# instead of defaulting to varchar (which is the next test)
# return the compare type
if val_lit is None:
return cmp_type
result = cmp_type

elif isinstance(value, bool):
result = self.BOOL

elif value is None:
result = cmp_type

# Make sure that it is a valid integer
# Python accepts 100_000 as a valid form of 100000,
# however a sql engine may throw an error
if not self.is_valid_sql_num(value):
if self.is_sql_bool(val_lit) and cmp_type != self.VARCHAR:
return self.BOOL
else:
return self.VARCHAR
elif not self.is_valid_sql_num(value):
result = self.VARCHAR

if self.is_sql_bool(val_lit) and cmp_type not in self.INT_TYPES + [self.FLOAT]:
return self.BOOL

type_lit = type(val_lit)

# If a float, stop here
# float is highest after varchar
if type_lit == float or cmp_type == self.FLOAT:
return self.FLOAT
elif isinstance(value, float) or cmp_type == self.FLOAT:
result = self.FLOAT

# The value is very likely an int
# let's get its size
# If the compare type is empty, use the type of the current value
if type_lit == int and cmp_type in (self.INT_TYPES + [None, "", self.BOOL]):

elif isinstance(value, int) and cmp_type in (
self.INT_TYPES + [None, "", self.BOOL]
):
# Use smallest possible int type above TINYINT
if self.SMALLINT_MIN < val_lit < self.SMALLINT_MAX:
return self.get_bigger_int(self.SMALLINT, cmp_type)
elif self.MEDIUMINT_MIN < val_lit < self.MEDIUMINT_MAX:
return self.get_bigger_int(self.MEDIUMINT, cmp_type)
elif self.INT_MIN < val_lit < self.INT_MAX:
return self.get_bigger_int(self.INT, cmp_type)
if self.SMALLINT_MIN < value < self.SMALLINT_MAX:
result = self.get_bigger_int(self.SMALLINT, cmp_type)
elif self.MEDIUMINT_MIN < value < self.MEDIUMINT_MAX:
result = self.get_bigger_int(self.MEDIUMINT, cmp_type)
elif self.INT_MIN < value < self.INT_MAX:
result = self.get_bigger_int(self.INT, cmp_type)
else:
return self.BIGINT
result = self.BIGINT

else:
# Need to determine who makes it all the way down here
logger.debug(f"Unexpected object type: {type(value)}")
result = cmp_type

# Need to determine who makes it all the way down here
return cmp_type
return result

def format_column(self, col, index="", replace_chars=None, col_prefix="_"):
"""Format the column to meet database contraints.
Expand Down
52 changes: 8 additions & 44 deletions test/test_databases/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
MEDIUMINT,
INT,
BIGINT,
FLOAT,
BOOL,
VARCHAR,
)
Expand All @@ -19,13 +18,6 @@ def dcs():
return db


@pytest.fixture
def dcs_bool():
db = DatabaseCreateStatement()
db.DO_PARSE_BOOLS = True
return db


@pytest.mark.parametrize(
("int1", "int2", "higher"),
(
Expand Down Expand Up @@ -95,34 +87,6 @@ def test_is_valid_sql_num(dcs, val, is_valid):
assert dcs.is_valid_sql_num(val) == is_valid


@pytest.mark.parametrize(
("val", "cmp_type", "detected_type"),
(
(1, None, SMALLINT),
(1, "", SMALLINT),
(1, MEDIUMINT, MEDIUMINT),
(32769, None, MEDIUMINT),
(32769, BIGINT, BIGINT),
(2147483648, None, BIGINT),
(2147483648, FLOAT, FLOAT),
(5.001, None, FLOAT),
(5.001, "", FLOAT),
("FALSE", VARCHAR, VARCHAR),
("word", "", VARCHAR),
("word", INT, VARCHAR),
("1_2", BIGINT, VARCHAR),
("01", FLOAT, VARCHAR),
("00001", None, VARCHAR),
("word", None, VARCHAR),
("1_2", None, VARCHAR),
("01", None, VARCHAR),
("{}", None, VARCHAR),
),
)
def test_detect_data_type(dcs, val, cmp_type, detected_type):
assert dcs.detect_data_type(val, cmp_type) == detected_type


@pytest.mark.parametrize(
("val", "cmp_type", "detected_type"),
(
Expand All @@ -131,16 +95,16 @@ def test_detect_data_type(dcs, val, cmp_type, detected_type):
(1, MEDIUMINT, MEDIUMINT),
(2, BOOL, SMALLINT),
(True, None, BOOL),
(0, None, BOOL),
(1, None, BOOL),
(1, BOOL, BOOL),
("F", None, BOOL),
("FALSE", None, BOOL),
("Yes", None, BOOL),
(0, None, SMALLINT),
(1, None, SMALLINT),
(1, BOOL, SMALLINT),
("F", None, VARCHAR),
("FALSE", None, VARCHAR),
("Yes", None, VARCHAR),
),
)
def test_detect_data_type_bools(dcs_bool, val, cmp_type, detected_type):
assert dcs_bool.detect_data_type(val, cmp_type) == detected_type
def test_detect_data_type_bools(dcs, val, cmp_type, detected_type):
assert dcs.detect_data_type(val, cmp_type) == detected_type


@pytest.mark.parametrize(
Expand Down
10 changes: 4 additions & 6 deletions test/test_databases/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,11 +156,9 @@ def setUp(self):
def test_data_type(self):

# Test bool
self.mysql.DO_PARSE_BOOLS = True
self.assertEqual(self.mysql.data_type(1, ""), "bool")
self.assertEqual(self.mysql.data_type(False, ""), "bool")
self.assertEqual(self.mysql.data_type(True, ""), "bool")

self.mysql.DO_PARSE_BOOLS = False
# Test smallint
self.assertEqual(self.mysql.data_type(1, ""), "smallint")
self.assertEqual(self.mysql.data_type(2, ""), "smallint")
Expand All @@ -170,14 +168,14 @@ def test_data_type(self):
self.assertEqual(self.mysql.data_type(2147483648, ""), "bigint")
# Test varchar that looks like an int
self.assertEqual(self.mysql.data_type("00001", ""), "varchar")
# Test varchar that looks like a bool
self.assertEqual(self.mysql.data_type(False, ""), "varchar")
# Test a float as a decimal
self.assertEqual(self.mysql.data_type(5.001, ""), "float")
# Test varchar
self.assertEqual(self.mysql.data_type("word", ""), "varchar")
# Test int with underscore
# Test int with underscore as string
self.assertEqual(self.mysql.data_type("1_2", ""), "varchar")
# Test int with underscore
self.assertEqual(self.mysql.data_type(1_2, ""), "smallint")
# Test int with leading zero
self.assertEqual(self.mysql.data_type("01", ""), "varchar")

Expand Down
18 changes: 4 additions & 14 deletions test/test_databases/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,8 @@ def setUp(self):
["g", "", 9, "NA", 1.4, 1, 2],
]
)

self.mapping = self.pg.generate_data_types(self.tbl)
self.mapping2 = self.pg.generate_data_types(self.tbl2)
self.pg.DO_PARSE_BOOLS = True
self.mapping3 = self.pg.generate_data_types(self.tbl2)

def test_connection(self):

Expand All @@ -56,7 +53,6 @@ def test_connection(self):
self.assertEqual(pg_env.port, 5432)

def test_data_type(self):
self.pg.DO_PARSE_BOOLS = False
# Test smallint
self.assertEqual(self.pg.data_type(1, ""), "smallint")
self.assertEqual(self.pg.data_type(2, ""), "smallint")
Expand All @@ -66,20 +62,18 @@ def test_data_type(self):
self.assertEqual(self.pg.data_type(2147483648, ""), "bigint")
# Test varchar that looks like an int
self.assertEqual(self.pg.data_type("00001", ""), "varchar")
# Test varchar that looks like a bool
self.assertEqual(self.pg.data_type(True, ""), "varchar")
# Test a float as a decimal
self.assertEqual(self.pg.data_type(5.001, ""), "decimal")
# Test varchar
self.assertEqual(self.pg.data_type("word", ""), "varchar")
# Test int with underscore
# Test int with underscore as string
self.assertEqual(self.pg.data_type("1_2", ""), "varchar")
# Test int with leading zero
# Test int with leading zero as string
self.assertEqual(self.pg.data_type("01", ""), "varchar")
# Test int with underscore
self.assertEqual(self.pg.data_type(1_2, ""), "smallint")

# Test bool
self.pg.DO_PARSE_BOOLS = True
self.assertEqual(self.pg.data_type(1, ""), "bool")
self.assertEqual(self.pg.data_type(True, ""), "bool")

def test_generate_data_types(self):
Expand All @@ -100,10 +94,6 @@ def test_generate_data_types(self):
"varchar",
],
)
self.assertEqual(
self.mapping3["type_list"],
["varchar", "varchar", "decimal", "varchar", "decimal", "bool", "varchar"],
)
# Test correct lengths
self.assertEqual(self.mapping["longest"], [1, 5])

Expand Down
19 changes: 4 additions & 15 deletions test/test_redshift.py → test/test_databases/test_redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,7 @@ def setUp(self):
)

self.mapping = self.rs.generate_data_types(self.tbl)
self.rs.DO_PARSE_BOOLS = True
self.mapping2 = self.rs.generate_data_types(self.tbl2)
self.rs.DO_PARSE_BOOLS = False
self.mapping3 = self.rs.generate_data_types(self.tbl2)

def test_split_full_table_name(self):
schema, table = Redshift.split_full_table_name("some_schema.some_table")
Expand All @@ -60,14 +57,9 @@ def test_combine_schema_and_table_name(self):
self.assertEqual(full_table_name, "some_schema.some_table")

def test_data_type(self):

# Test bool
self.rs.DO_PARSE_BOOLS = True
self.assertEqual(self.rs.data_type(1, ""), "bool")
self.assertEqual(self.rs.data_type(True, ""), "bool")
self.rs.DO_PARSE_BOOLS = False
self.assertEqual(self.rs.data_type(1, ""), "int")
self.assertEqual(self.rs.data_type(True, ""), "varchar")
# Test smallint
# Currently smallints are coded as ints
self.assertEqual(self.rs.data_type(2, ""), "int")
Expand All @@ -81,27 +73,24 @@ def test_data_type(self):
self.assertEqual(self.rs.data_type(5.001, ""), "float")
# Test varchar
self.assertEqual(self.rs.data_type("word", ""), "varchar")
# Test int with underscore
# Test int with underscore as varchar
self.assertEqual(self.rs.data_type("1_2", ""), "varchar")
# Test int with underscore
self.assertEqual(self.rs.data_type(1_2, ""), "int")
# Test int with leading zero
self.assertEqual(self.rs.data_type("01", ""), "varchar")

def test_generate_data_types(self):

# Test correct header labels
self.assertEqual(self.mapping["headers"], ["ID", "Name"])
# Test correct data types
self.assertEqual(self.mapping["type_list"], ["int", "varchar"])

self.assertEqual(
self.mapping2["type_list"],
["varchar", "varchar", "float", "varchar", "float", "bool", "varchar"],
)

self.assertEqual(
self.mapping3["type_list"],
["varchar", "varchar", "float", "varchar", "float", "int", "varchar"],
)

# Test correct lengths
self.assertEqual(self.mapping["longest"], [1, 5])

Expand Down

0 comments on commit 14df184

Please sign in to comment.