From 14df184bbbef840da6536bd25fc99c443fed5f5d Mon Sep 17 00:00:00 2001 From: Austin Weisgrau <62900254+austinweisgrau@users.noreply.github.com> Date: Fri, 8 Dec 2023 12:15:10 -0800 Subject: [PATCH] Parse Boolean types by default (#943) * Parse Boolean types by default Commit 766cfaedc5652af8c39fde890067a99b6fa58518 created a feature for parsing boolean types but turned it off by default. This commit turns that feature on by default and adds a comment about how to turn it off and what that does. * Fix test expectations after updating boolean parsing behavior * Only ever interpret python bools as SQL booleans No longer coerce by default any of the following as booleans: "yes", "True", "t", 1, 0, "no", "False", "f" * Fix redshift test parsing bools * Move redshift test into test_databases folder * Remove retired TRUE_VALS and FALSE_VALS configuration variables We now only use python booleans --- parsons/databases/database/constants.py | 4 - parsons/databases/database/database.py | 104 ++++++--------------- test/test_databases/test_database.py | 52 ++--------- test/test_databases/test_mysql.py | 10 +- test/test_databases/test_postgres.py | 18 +--- test/{ => test_databases}/test_redshift.py | 19 +--- 6 files changed, 50 insertions(+), 157 deletions(-) rename test/{ => test_databases}/test_redshift.py (98%) diff --git a/parsons/databases/database/constants.py b/parsons/databases/database/constants.py index 8935a8734d..1b78ffb0af 100644 --- a/parsons/databases/database/constants.py +++ b/parsons/databases/database/constants.py @@ -153,11 +153,7 @@ VARCHAR = "varchar" FLOAT = "float" - -DO_PARSE_BOOLS = False BOOL = "bool" -TRUE_VALS = ("TRUE", "T", "YES", "Y", "1", 1) -FALSE_VALS = ("FALSE", "F", "NO", "N", "0", 0) # The following values are the minimum and maximum values for MySQL int # types. https://dev.mysql.com/doc/refman/8.0/en/integer-types.html diff --git a/parsons/databases/database/database.py b/parsons/databases/database/database.py index 9f369d43c9..bfec754241 100644 --- a/parsons/databases/database/database.py +++ b/parsons/databases/database/database.py @@ -1,5 +1,7 @@ import parsons.databases.database.constants as consts -import ast +import logging + +logger = logging.getLogger(__name__) class DatabaseCreateStatement: @@ -17,11 +19,7 @@ def __init__(self): self.BIGINT = consts.BIGINT self.FLOAT = consts.FLOAT - # Added for backwards compatability - self.DO_PARSE_BOOLS = consts.DO_PARSE_BOOLS self.BOOL = consts.BOOL - self.TRUE_VALS = consts.TRUE_VALS - self.FALSE_VALS = consts.FALSE_VALS self.VARCHAR = consts.VARCHAR self.RESERVED_WORDS = consts.RESERVED_WORDS @@ -117,29 +115,6 @@ def is_valid_sql_num(self, val): except (TypeError, ValueError): return False - def is_sql_bool(self, val): - """Check whether val is a valid sql boolean. - - When inserting data into databases, different values can be accepted - as boolean types. For excample, ``False``, ``'FALSE'``, ``1``. - - `Args`: - val: any - The value to check. - `Returns`: - bool - Whether or not the value is a valid sql boolean. - """ - if not self.DO_PARSE_BOOLS: - return - - if isinstance(val, bool) or ( - type(val) in (int, str) - and str(val).upper() in self.TRUE_VALS + self.FALSE_VALS - ): - return True - return False - def detect_data_type(self, value, cmp_type=None): """Detect the higher of value's type cmp_type. @@ -161,64 +136,45 @@ def detect_data_type(self, value, cmp_type=None): # Stop if the compare type is already a varchar # varchar is the highest data type. if cmp_type == self.VARCHAR: - return cmp_type - - # Attempt to evaluate value as a literal (e.g. '1' => 1, ) If the value - # is just a string, is None, or is empty, it will raise an error. These - # should be considered varchars. - # E.g. - # "" => SyntaxError - # "anystring" => ValueError - try: - val_lit = ast.literal_eval(str(value)) - except (SyntaxError, ValueError): - if self.is_sql_bool(value): - return self.BOOL - return self.VARCHAR - - # Exit early if it's None - # is_valid_sql_num(None) == False - # instead of defaulting to varchar (which is the next test) - # return the compare type - if val_lit is None: - return cmp_type + result = cmp_type + + elif isinstance(value, bool): + result = self.BOOL + + elif value is None: + result = cmp_type # Make sure that it is a valid integer # Python accepts 100_000 as a valid form of 100000, # however a sql engine may throw an error - if not self.is_valid_sql_num(value): - if self.is_sql_bool(val_lit) and cmp_type != self.VARCHAR: - return self.BOOL - else: - return self.VARCHAR + elif not self.is_valid_sql_num(value): + result = self.VARCHAR - if self.is_sql_bool(val_lit) and cmp_type not in self.INT_TYPES + [self.FLOAT]: - return self.BOOL - - type_lit = type(val_lit) - - # If a float, stop here - # float is highest after varchar - if type_lit == float or cmp_type == self.FLOAT: - return self.FLOAT + elif isinstance(value, float) or cmp_type == self.FLOAT: + result = self.FLOAT # The value is very likely an int # let's get its size # If the compare types are empty and use the types of the current value - if type_lit == int and cmp_type in (self.INT_TYPES + [None, "", self.BOOL]): - + elif isinstance(value, int) and cmp_type in ( + self.INT_TYPES + [None, "", self.BOOL] + ): # Use smallest possible int type above TINYINT - if self.SMALLINT_MIN < val_lit < self.SMALLINT_MAX: - return self.get_bigger_int(self.SMALLINT, cmp_type) - elif self.MEDIUMINT_MIN < val_lit < self.MEDIUMINT_MAX: - return self.get_bigger_int(self.MEDIUMINT, cmp_type) - elif self.INT_MIN < val_lit < self.INT_MAX: - return self.get_bigger_int(self.INT, cmp_type) + if self.SMALLINT_MIN < value < self.SMALLINT_MAX: + result = self.get_bigger_int(self.SMALLINT, cmp_type) + elif self.MEDIUMINT_MIN < value < self.MEDIUMINT_MAX: + result = self.get_bigger_int(self.MEDIUMINT, cmp_type) + elif self.INT_MIN < value < self.INT_MAX: + result = self.get_bigger_int(self.INT, cmp_type) else: - return self.BIGINT + result = self.BIGINT + + else: + # Need to determine who makes it all the way down here + logger.debug(f"Unexpected object type: {type(value)}") + result = cmp_type - # Need to determine who makes it all the way down here - return cmp_type + return result def format_column(self, col, index="", replace_chars=None, col_prefix="_"): """Format the column to meet database contraints. diff --git a/test/test_databases/test_database.py b/test/test_databases/test_database.py index b49491e687..c56116e3d1 100644 --- a/test/test_databases/test_database.py +++ b/test/test_databases/test_database.py @@ -3,7 +3,6 @@ MEDIUMINT, INT, BIGINT, - FLOAT, BOOL, VARCHAR, ) @@ -19,13 +18,6 @@ def dcs(): return db -@pytest.fixture -def dcs_bool(): - db = DatabaseCreateStatement() - db.DO_PARSE_BOOLS = True - return db - - @pytest.mark.parametrize( ("int1", "int2", "higher"), ( @@ -95,34 +87,6 @@ def test_is_valid_sql_num(dcs, val, is_valid): assert dcs.is_valid_sql_num(val) == is_valid -@pytest.mark.parametrize( - ("val", "cmp_type", "detected_type"), - ( - (1, None, SMALLINT), - (1, "", SMALLINT), - (1, MEDIUMINT, MEDIUMINT), - (32769, None, MEDIUMINT), - (32769, BIGINT, BIGINT), - (2147483648, None, BIGINT), - (2147483648, FLOAT, FLOAT), - (5.001, None, FLOAT), - (5.001, "", FLOAT), - ("FALSE", VARCHAR, VARCHAR), - ("word", "", VARCHAR), - ("word", INT, VARCHAR), - ("1_2", BIGINT, VARCHAR), - ("01", FLOAT, VARCHAR), - ("00001", None, VARCHAR), - ("word", None, VARCHAR), - ("1_2", None, VARCHAR), - ("01", None, VARCHAR), - ("{}", None, VARCHAR), - ), -) -def test_detect_data_type(dcs, val, cmp_type, detected_type): - assert dcs.detect_data_type(val, cmp_type) == detected_type - - @pytest.mark.parametrize( ("val", "cmp_type", "detected_type"), ( @@ -131,16 +95,16 @@ def test_detect_data_type(dcs, val, cmp_type, detected_type): (1, MEDIUMINT, MEDIUMINT), (2, BOOL, SMALLINT), (True, None, BOOL), - (0, None, BOOL), - (1, None, BOOL), - (1, BOOL, BOOL), - ("F", None, BOOL), - ("FALSE", None, BOOL), - ("Yes", None, BOOL), + (0, None, SMALLINT), + (1, None, SMALLINT), + (1, BOOL, SMALLINT), + ("F", None, VARCHAR), + ("FALSE", None, VARCHAR), + ("Yes", None, VARCHAR), ), ) -def test_detect_data_type_bools(dcs_bool, val, cmp_type, detected_type): - assert dcs_bool.detect_data_type(val, cmp_type) == detected_type +def test_detect_data_type_bools(dcs, val, cmp_type, detected_type): + assert dcs.detect_data_type(val, cmp_type) == detected_type @pytest.mark.parametrize( diff --git a/test/test_databases/test_mysql.py b/test/test_databases/test_mysql.py index 01c156dfe4..323b4ffbf6 100644 --- a/test/test_databases/test_mysql.py +++ b/test/test_databases/test_mysql.py @@ -156,11 +156,9 @@ def setUp(self): def test_data_type(self): # Test bool - self.mysql.DO_PARSE_BOOLS = True - self.assertEqual(self.mysql.data_type(1, ""), "bool") self.assertEqual(self.mysql.data_type(False, ""), "bool") + self.assertEqual(self.mysql.data_type(True, ""), "bool") - self.mysql.DO_PARSE_BOOLS = False # Test smallint self.assertEqual(self.mysql.data_type(1, ""), "smallint") self.assertEqual(self.mysql.data_type(2, ""), "smallint") @@ -170,14 +168,14 @@ def test_data_type(self): self.assertEqual(self.mysql.data_type(2147483648, ""), "bigint") # Test varchar that looks like an int self.assertEqual(self.mysql.data_type("00001", ""), "varchar") - # Test varchar that looks like a bool - self.assertEqual(self.mysql.data_type(False, ""), "varchar") # Test a float as a decimal self.assertEqual(self.mysql.data_type(5.001, ""), "float") # Test varchar self.assertEqual(self.mysql.data_type("word", ""), "varchar") - # Test int with underscore + # Test int with underscore as string self.assertEqual(self.mysql.data_type("1_2", ""), "varchar") + # Test int with underscore + self.assertEqual(self.mysql.data_type(1_2, ""), "smallint") # Test int with leading zero self.assertEqual(self.mysql.data_type("01", ""), "varchar") diff --git a/test/test_databases/test_postgres.py b/test/test_databases/test_postgres.py index 08956c672d..5279c94ccf 100644 --- a/test/test_databases/test_postgres.py +++ b/test/test_databases/test_postgres.py @@ -30,11 +30,8 @@ def setUp(self): ["g", "", 9, "NA", 1.4, 1, 2], ] ) - self.mapping = self.pg.generate_data_types(self.tbl) self.mapping2 = self.pg.generate_data_types(self.tbl2) - self.pg.DO_PARSE_BOOLS = True - self.mapping3 = self.pg.generate_data_types(self.tbl2) def test_connection(self): @@ -56,7 +53,6 @@ def test_connection(self): self.assertEqual(pg_env.port, 5432) def test_data_type(self): - self.pg.DO_PARSE_BOOLS = False # Test smallint self.assertEqual(self.pg.data_type(1, ""), "smallint") self.assertEqual(self.pg.data_type(2, ""), "smallint") @@ -66,20 +62,18 @@ def test_data_type(self): self.assertEqual(self.pg.data_type(2147483648, ""), "bigint") # Test varchar that looks like an int self.assertEqual(self.pg.data_type("00001", ""), "varchar") - # Test varchar that looks like a bool - self.assertEqual(self.pg.data_type(True, ""), "varchar") # Test a float as a decimal self.assertEqual(self.pg.data_type(5.001, ""), "decimal") # Test varchar self.assertEqual(self.pg.data_type("word", ""), "varchar") - # Test int with underscore + # Test int with underscore as string self.assertEqual(self.pg.data_type("1_2", ""), "varchar") - # Test int with leading zero + # Test int with leading zero as string self.assertEqual(self.pg.data_type("01", ""), "varchar") + # Test int with underscore + self.assertEqual(self.pg.data_type(1_2, ""), "smallint") # Test bool - self.pg.DO_PARSE_BOOLS = True - self.assertEqual(self.pg.data_type(1, ""), "bool") self.assertEqual(self.pg.data_type(True, ""), "bool") def test_generate_data_types(self): @@ -100,10 +94,6 @@ def test_generate_data_types(self): "varchar", ], ) - self.assertEqual( - self.mapping3["type_list"], - ["varchar", "varchar", "decimal", "varchar", "decimal", "bool", "varchar"], - ) # Test correct lengths self.assertEqual(self.mapping["longest"], [1, 5]) diff --git a/test/test_redshift.py b/test/test_databases/test_redshift.py similarity index 98% rename from test/test_redshift.py rename to test/test_databases/test_redshift.py index fd664fede5..b811e98857 100644 --- a/test/test_redshift.py +++ b/test/test_databases/test_redshift.py @@ -35,10 +35,7 @@ def setUp(self): ) self.mapping = self.rs.generate_data_types(self.tbl) - self.rs.DO_PARSE_BOOLS = True self.mapping2 = self.rs.generate_data_types(self.tbl2) - self.rs.DO_PARSE_BOOLS = False - self.mapping3 = self.rs.generate_data_types(self.tbl2) def test_split_full_table_name(self): schema, table = Redshift.split_full_table_name("some_schema.some_table") @@ -60,14 +57,9 @@ def test_combine_schema_and_table_name(self): self.assertEqual(full_table_name, "some_schema.some_table") def test_data_type(self): - # Test bool - self.rs.DO_PARSE_BOOLS = True - self.assertEqual(self.rs.data_type(1, ""), "bool") self.assertEqual(self.rs.data_type(True, ""), "bool") - self.rs.DO_PARSE_BOOLS = False self.assertEqual(self.rs.data_type(1, ""), "int") - self.assertEqual(self.rs.data_type(True, ""), "varchar") # Test smallint # Currently smallints are coded as ints self.assertEqual(self.rs.data_type(2, ""), "int") @@ -81,13 +73,14 @@ def test_data_type(self): self.assertEqual(self.rs.data_type(5.001, ""), "float") # Test varchar self.assertEqual(self.rs.data_type("word", ""), "varchar") - # Test int with underscore + # Test int with underscore as varchar self.assertEqual(self.rs.data_type("1_2", ""), "varchar") + # Test int with underscore + self.assertEqual(self.rs.data_type(1_2, ""), "int") # Test int with leading zero self.assertEqual(self.rs.data_type("01", ""), "varchar") def test_generate_data_types(self): - # Test correct header labels self.assertEqual(self.mapping["headers"], ["ID", "Name"]) # Test correct data types @@ -95,13 +88,9 @@ def test_generate_data_types(self): self.assertEqual( self.mapping2["type_list"], - ["varchar", "varchar", "float", "varchar", "float", "bool", "varchar"], - ) - - self.assertEqual( - self.mapping3["type_list"], ["varchar", "varchar", "float", "varchar", "float", "int", "varchar"], ) + # Test correct lengths self.assertEqual(self.mapping["longest"], [1, 5])