From ddfc51b4c2749b082fab893cdd65a522ab8fbfa4 Mon Sep 17 00:00:00 2001 From: Daniel Bravo Date: Thu, 25 Mar 2021 16:02:55 -0700 Subject: [PATCH 1/6] feat(Database): Add support for parsing booleans --- parsons/databases/database/constants.py | 4 ++ parsons/databases/database/database.py | 37 +++++++++++++++++-- parsons/databases/redshift/rs_create_table.py | 3 +- test/test_databases/test_database.py | 20 ++++++++-- test/test_databases/test_mysql.py | 11 ++++-- test/test_databases/test_postgres.py | 7 +++- test/test_redshift.py | 15 +++++--- 7 files changed, 79 insertions(+), 18 deletions(-) diff --git a/parsons/databases/database/constants.py b/parsons/databases/database/constants.py index 8a1d371c60..c879c1480e 100644 --- a/parsons/databases/database/constants.py +++ b/parsons/databases/database/constants.py @@ -27,6 +27,10 @@ VARCHAR = "varchar" FLOAT = "float" +BOOL = "bool" +TRUE_VALS = ("TRUE", "T", "YES", "Y", "1", 1) +FALSE_VALS = ("FALSE", "F", "NO", "N", "0", 0) + # The following values are the minimum and maximum values for MySQL int # types. 
https://dev.mysql.com/doc/refman/8.0/en/integer-types.html SMALLINT = "smallint" diff --git a/parsons/databases/database/database.py b/parsons/databases/database/database.py index 2852734b61..16fbcb478b 100644 --- a/parsons/databases/database/database.py +++ b/parsons/databases/database/database.py @@ -17,6 +17,9 @@ def __init__(self): self.INT_MAX = consts.INT_MAX self.BIGINT = consts.BIGINT self.FLOAT = consts.FLOAT + self.BOOL = consts.BOOL + self.TRUE_VALS = consts.TRUE_VALS + self.FALSE_VALS = consts.FALSE_VALS self.VARCHAR = consts.VARCHAR self.RESERVED_WORDS = consts.RESERVED_WORDS self.COL_NAME_MAX_LEN = consts.COL_NAME_MAX_LEN @@ -97,7 +100,8 @@ def is_valid_sql_num(self, val): # then it's a valid sql number # Also check the first character is not zero try: - if (float(val) or int(val)) and "_" not in val and val[0] != "0": + if ((float(val) or 1) and "_" not in val and + (val in ("0", "0.0") or val[0] != "0")): return True else: return False @@ -107,6 +111,25 @@ def is_valid_sql_num(self, val): except (TypeError, ValueError): return False + def is_sql_bool(self, val): + """Check whether val is a valid sql boolean. + + When inserting data into databases, different values can be accepted + as boolean types. For example, ``False``, ``'FALSE'``, ``1``. + + `Args`: + val: any + The value to check. + `Returns`: + bool + Whether or not the value is a valid sql boolean. + """ + if isinstance(val, bool) or ( + type(val) in (int, str) and + str(val).upper() in self.TRUE_VALS + self.FALSE_VALS): + return True + return False + def detect_data_type(self, value, cmp_type=None): """Detect the higher of value's type cmp_type. 
@@ -139,6 +162,8 @@ def detect_data_type(self, value, cmp_type=None): try: val_lit = ast.literal_eval(str(value)) except (SyntaxError, ValueError): + if self.is_sql_bool(value): + return self.BOOL return self.VARCHAR # Exit early if it's None @@ -152,7 +177,13 @@ def detect_data_type(self, value, cmp_type=None): # Python accepts 100_000 as a valid form of 100000, # however a sql engine may throw an error if not self.is_valid_sql_num(value): - return self.VARCHAR + if self.is_sql_bool(val_lit) and cmp_type != self.VARCHAR: + return self.BOOL + else: + return self.VARCHAR + + if self.is_sql_bool(val_lit) and cmp_type not in self.INT_TYPES + [self.FLOAT]: + return self.BOOL type_lit = type(val_lit) @@ -164,7 +195,7 @@ def detect_data_type(self, value, cmp_type=None): # The value is very likely an int # let's get its size # If the compare types are empty and use the types of the current value - if type_lit == int and cmp_type in (self.INT_TYPES + [None, ""]): + if type_lit == int and cmp_type in (self.INT_TYPES + [None, "", self.BOOL]): # Use smallest possible int type above TINYINT if (self.SMALLINT_MIN < val_lit < self.SMALLINT_MAX): diff --git a/parsons/databases/redshift/rs_create_table.py b/parsons/databases/redshift/rs_create_table.py index 84571108a8..57808423f3 100644 --- a/parsons/databases/redshift/rs_create_table.py +++ b/parsons/databases/redshift/rs_create_table.py @@ -15,8 +15,9 @@ def __init__(self): self.COL_NAME_MAX_LEN = consts.COL_NAME_MAX_LEN self.REPLACE_CHARS = consts.REPLACE_CHARS - # Redshift doesn't have a medium int + # Currently smallints are coded as ints self.SMALLINT = self.INT + # Redshift doesn't have a medium int self.MEDIUMINT = self.INT # Currently py floats are coded as Redshift decimals diff --git a/test/test_databases/test_database.py b/test/test_databases/test_database.py index f3de5fe5db..a6818c906c 100644 --- a/test/test_databases/test_database.py +++ b/test/test_databases/test_database.py @@ -1,5 +1,5 @@ from 
parsons.databases.database.constants import ( - SMALLINT, MEDIUMINT, INT, BIGINT, FLOAT, VARCHAR) + SMALLINT, MEDIUMINT, INT, BIGINT, FLOAT, BOOL, VARCHAR) from parsons.databases.database.database import DatabaseCreateStatement @@ -49,6 +49,8 @@ def test_get_bigger_int(dcs, int1, int2, higher): (+1.2, True), (+1., True), (+1.0_0, True), + (0, True), + (0.0, True), ("10", True), ("1_0", False), ("+10", True), @@ -60,6 +62,8 @@ def test_get_bigger_int(dcs, int1, int2, higher): ("+1.2", True), ("+1.", True), ("+1.0_0", False), + ("0", True), + ("0.0", True), (True, False), ("True", False), ("a string", False), @@ -74,8 +78,8 @@ def test_is_valid_sql_num(dcs, val, is_valid): @pytest.mark.parametrize( ("val", "cmp_type", "detected_type"), - ((1, None, SMALLINT), - (1, "", SMALLINT), + ((2, None, SMALLINT), + (2, "", SMALLINT), (1, MEDIUMINT, MEDIUMINT), (32769, None, MEDIUMINT), (32769, BIGINT, BIGINT), @@ -83,6 +87,15 @@ def test_is_valid_sql_num(dcs, val, is_valid): (2147483648, FLOAT, FLOAT), (5.001, None, FLOAT), (5.001, "", FLOAT), + (2, BOOL, SMALLINT), + (True, None, BOOL), + (0, None, BOOL), + (1, None, BOOL), + (1, BOOL, BOOL), + ("F", None, BOOL), + ("FALSE", None, BOOL), + ("Yes", None, BOOL), + ("FALSE", VARCHAR, VARCHAR), ("word", "", VARCHAR), ("word", INT, VARCHAR), ("1_2", BIGINT, VARCHAR), @@ -91,6 +104,7 @@ def test_is_valid_sql_num(dcs, val, is_valid): ("word", None, VARCHAR), ("1_2", None, VARCHAR), ("01", None, VARCHAR), + ("{}", None, VARCHAR), )) def test_detect_data_type(dcs, val, cmp_type, detected_type): assert dcs.detect_data_type(val, cmp_type) == detected_type diff --git a/test/test_databases/test_mysql.py b/test/test_databases/test_mysql.py index 965a97f961..fb8fb2d75f 100644 --- a/test/test_databases/test_mysql.py +++ b/test/test_databases/test_mysql.py @@ -1,5 +1,5 @@ from parsons.databases.mysql.mysql import MySQL -from parsons.databases.table import BaseTable +from parsons.databases.mysql.create_table import MySQLCreateTable from 
parsons.etl.table import Table from test.utils import assert_matching_tables import unittest @@ -64,7 +64,7 @@ def setUp(self): ('you', 'hey', '3') """) - self.tbl = BaseTable(self.mysql, 'test') + self.tbl = MySQLCreateTable(self.mysql, 'test') def tearDown(self): @@ -91,7 +91,7 @@ def test_exists(self): self.assertTrue(self.tbl.exists) - tbl_bad = BaseTable(self.mysql, 'bad_test') + tbl_bad = MySQLCreateTable(self.mysql, 'bad_test') self.assertFalse(tbl_bad.exists) def test_drop(self): @@ -145,8 +145,11 @@ def setUp(self): def test_data_type(self): + # Test bool + self.assertEqual(self.mysql.data_type(1, ''), 'bool') + self.assertEqual(self.mysql.data_type(False, ''), 'bool') # Test smallint - self.assertEqual(self.mysql.data_type(1, ''), 'smallint') + self.assertEqual(self.mysql.data_type(2, ''), 'smallint') # Test int self.assertEqual(self.mysql.data_type(32769, ''), 'mediumint') # Test bigint diff --git a/test/test_databases/test_postgres.py b/test/test_databases/test_postgres.py index 0effdb2441..b5acd9f050 100644 --- a/test/test_databases/test_postgres.py +++ b/test/test_databases/test_postgres.py @@ -56,8 +56,11 @@ def test_connection(self): def test_data_type(self): + # Test bool + self.assertEqual(self.pg.data_type(1, ''), 'bool') + self.assertEqual(self.pg.data_type(True, ''), 'bool') # Test smallint - self.assertEqual(self.pg.data_type(1, ''), 'smallint') + self.assertEqual(self.pg.data_type(2, ''), 'smallint') # Test int self.assertEqual(self.pg.data_type(32769, ''), 'int') # Test bigint @@ -81,7 +84,7 @@ def test_generate_data_types(self): self.assertEqual(self.mapping['type_list'], ['smallint', 'varchar']) self.assertEqual( self.mapping2['type_list'], - ['varchar', 'varchar', 'decimal', 'varchar', "decimal", "smallint", "varchar"]) + ['varchar', 'varchar', 'decimal', 'varchar', "decimal", "bool", "varchar"]) # Test correct lengths self.assertEqual(self.mapping['longest'], [1, 5]) diff --git a/test/test_redshift.py b/test/test_redshift.py index 
372f6cfa31..c9dcf87463 100644 --- a/test/test_redshift.py +++ b/test/test_redshift.py @@ -58,8 +58,12 @@ def test_combine_schema_and_table_name(self): def test_data_type(self): - # Test int - self.assertEqual(self.rs.data_type(1, ''), 'int') + # Test bool + self.assertEqual(self.rs.data_type(1, ''), 'bool') + self.assertEqual(self.rs.data_type(True, ''), 'bool') + # Test smallint + # Currently smallints are coded as ints + self.assertEqual(self.rs.data_type(2, ''), 'int') # Test int self.assertEqual(self.rs.data_type(32769, ''), 'int') # Test bigint @@ -84,7 +88,7 @@ def test_generate_data_types(self): self.assertEqual( self.mapping2['type_list'], - ['varchar', 'varchar', 'decimal', 'varchar', "decimal", "int", "varchar"]) + ['varchar', 'varchar', 'decimal', 'varchar', "decimal", "bool", "varchar"]) # Test correct lengths self.assertEqual(self.mapping['longest'], [1, 5]) @@ -120,8 +124,9 @@ def test_column_validate(self): bad_cols = ['a', 'a', '', 'SELECT', 'asdfjkasjdfklasjdfklajskdfljaskldfjaklsdfjlaksdfjklasj' 'dfklasjdkfljaskldfljkasjdkfasjlkdfjklasdfjklakjsfasjkdfljaslkdfjklasdfjklasjkl' 'dfakljsdfjalsdkfjklasjdfklasjdfklasdkljf'] - fixed_cols = ['a', 'a_1', 'col_2', 'col_3', 'asdfjkasjdfklasjdfklajskdfljaskldfjaklsdfjlaks' - 'dfjklasjdfklasjdkfljaskldfljkasjdkfasjlkdfjklasdfjklakjsfasjkdfljaslkdfjkl'] + fixed_cols = [ + 'a', 'a_1', 'col_2', 'col_3', 'asdfjkasjdfklasjdfklajskdfljaskldfjaklsdfjlaks' + 'dfjklasjdfklasjdkfljaskldfljkasjdkfasjlkdfjklasdfjklakjsfasjkdfljaslkdfjkl'] self.assertEqual(self.rs.column_name_validate(bad_cols), fixed_cols) def test_create_statement(self): From c44729cda48dbdfbfa2f760bf3474ce68135c03e Mon Sep 17 00:00:00 2001 From: Daniel Bravo Date: Thu, 25 Mar 2021 17:44:14 -0700 Subject: [PATCH 2/6] Add attribute to set/unset bool parsing --- parsons/databases/database/database.py | 7 +++++++ test/test_databases/test_database.py | 4 +++- test/test_databases/test_mysql.py | 1 + test/test_databases/test_postgres.py | 1 + 
test/test_redshift.py | 1 + 5 files changed, 13 insertions(+), 1 deletion(-) diff --git a/parsons/databases/database/database.py b/parsons/databases/database/database.py index 16fbcb478b..3fd788569d 100644 --- a/parsons/databases/database/database.py +++ b/parsons/databases/database/database.py @@ -17,9 +17,13 @@ def __init__(self): self.INT_MAX = consts.INT_MAX self.BIGINT = consts.BIGINT self.FLOAT = consts.FLOAT + + # Added for backwards compatibility + self.DO_PARSE_BOOLS = False self.BOOL = consts.BOOL self.TRUE_VALS = consts.TRUE_VALS self.FALSE_VALS = consts.FALSE_VALS + self.VARCHAR = consts.VARCHAR self.RESERVED_WORDS = consts.RESERVED_WORDS self.COL_NAME_MAX_LEN = consts.COL_NAME_MAX_LEN @@ -124,6 +128,9 @@ def is_sql_bool(self, val): bool Whether or not the value is a valid sql boolean. """ + if not self.DO_PARSE_BOOLS: + return + if isinstance(val, bool) or ( type(val) in (int, str) and str(val).upper() in self.TRUE_VALS + self.FALSE_VALS): diff --git a/test/test_databases/test_database.py b/test/test_databases/test_database.py index a6818c906c..869fce594e 100644 --- a/test/test_databases/test_database.py +++ b/test/test_databases/test_database.py @@ -8,7 +8,9 @@ @pytest.fixture def dcs(): - return DatabaseCreateStatement() + db = DatabaseCreateStatement() + db.DO_PARSE_BOOLS = True + return db @pytest.mark.parametrize( diff --git a/test/test_databases/test_mysql.py b/test/test_databases/test_mysql.py index fb8fb2d75f..700ec152e0 100644 --- a/test/test_databases/test_mysql.py +++ b/test/test_databases/test_mysql.py @@ -137,6 +137,7 @@ class TestMySQL(unittest.TestCase): # noqa def setUp(self): self.mysql = MySQL(username='test', password='test', host='test', db='test', port=123) + self.mysql.DO_PARSE_BOOLS = True self.tbl = Table([['ID', 'Name', 'Score'], [1, 'Jim', 1.9], diff --git a/test/test_databases/test_postgres.py b/test/test_databases/test_postgres.py index b5acd9f050..b8b9871cb1 100644 --- a/test/test_databases/test_postgres.py +++ 
b/test/test_databases/test_postgres.py @@ -15,6 +15,7 @@ class TestPostgresCreateStatement(unittest.TestCase): def setUp(self): self.pg = Postgres(username='test', password='test', host='test', db='test', port=123) + self.pg.DO_PARSE_BOOLS = True self.tbl = Table([['ID', 'Name'], [1, 'Jim'], diff --git a/test/test_redshift.py b/test/test_redshift.py index c9dcf87463..238b025518 100644 --- a/test/test_redshift.py +++ b/test/test_redshift.py @@ -19,6 +19,7 @@ class TestRedshift(unittest.TestCase): def setUp(self): self.rs = Redshift(username='test', password='test', host='test', db='test', port=123) + self.rs.DO_PARSE_BOOLS = True self.tbl = Table([['ID', 'Name'], [1, 'Jim'], From 38eab97143633ebf88d95012a670738025ea37d5 Mon Sep 17 00:00:00 2001 From: Daniel Bravo Date: Thu, 25 Mar 2021 17:56:46 -0700 Subject: [PATCH 3/6] Update to get DO_PARSE_BOOLS from constants file --- parsons/databases/database/constants.py | 1 + parsons/databases/database/database.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/parsons/databases/database/constants.py b/parsons/databases/database/constants.py index c879c1480e..eeb8a5513f 100644 --- a/parsons/databases/database/constants.py +++ b/parsons/databases/database/constants.py @@ -27,6 +27,7 @@ VARCHAR = "varchar" FLOAT = "float" +DO_PARSE_BOOLS= False BOOL = "bool" TRUE_VALS = ("TRUE", "T", "YES", "Y", "1", 1) FALSE_VALS = ("FALSE", "F", "NO", "N", "0", 0) diff --git a/parsons/databases/database/database.py b/parsons/databases/database/database.py index 3fd788569d..1b71726266 100644 --- a/parsons/databases/database/database.py +++ b/parsons/databases/database/database.py @@ -19,7 +19,7 @@ def __init__(self): self.FLOAT = consts.FLOAT # Added for backwards compatibility - self.DO_PARSE_BOOLS = False + self.DO_PARSE_BOOLS = consts.DO_PARSE_BOOLS self.BOOL = consts.BOOL self.TRUE_VALS = consts.TRUE_VALS self.FALSE_VALS = consts.FALSE_VALS From e3778ffc0acb323127f5a9902cd7221cf6f5f7c2 Mon Sep 17 00:00:00 2001 From: 
Daniel Bravo Date: Thu, 25 Mar 2021 18:28:07 -0700 Subject: [PATCH 4/6] Fix lint errors --- parsons/databases/database/constants.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parsons/databases/database/constants.py b/parsons/databases/database/constants.py index eeb8a5513f..ec707ba455 100644 --- a/parsons/databases/database/constants.py +++ b/parsons/databases/database/constants.py @@ -27,7 +27,7 @@ VARCHAR = "varchar" FLOAT = "float" -DO_PARSE_BOOLS= False +DO_PARSE_BOOLS = False BOOL = "bool" TRUE_VALS = ("TRUE", "T", "YES", "Y", "1", 1) FALSE_VALS = ("FALSE", "F", "NO", "N", "0", 0) From 5e0fc12a488735043316b5653583fcb096f35c62 Mon Sep 17 00:00:00 2001 From: Chris Cuellar <58723+ChrisC@users.noreply.github.com> Date: Thu, 16 Dec 2021 15:11:15 -0800 Subject: [PATCH 5/6] fix merge from main --- test/test_redshift.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_redshift.py b/test/test_redshift.py index 7a4e9201c0..f2c2f70e98 100644 --- a/test/test_redshift.py +++ b/test/test_redshift.py @@ -89,7 +89,7 @@ def test_generate_data_types(self): self.assertEqual( self.mapping2['type_list'], - ['varchar', 'varchar', 'decimal', 'varchar', "float", "int", "bool", "varchar"]) + ['varchar', 'varchar', 'float', 'varchar', 'float', 'bool', 'varchar']) # Test correct lengths self.assertEqual(self.mapping['longest'], [1, 5]) From 2dc00b8c3361de7d0cdd52092c9b9cf7ff2c2779 Mon Sep 17 00:00:00 2001 From: Chris Cuellar <58723+ChrisC@users.noreply.github.com> Date: Thu, 16 Dec 2021 15:42:45 -0800 Subject: [PATCH 6/6] explicitly set boolean parsing in DB tests --- test/test_databases/test_database.py | 36 ++++++++++++++++++++-------- test/test_databases/test_mysql.py | 7 +++++- test/test_databases/test_postgres.py | 19 +++++++++++---- test/test_redshift.py | 12 +++++++++- 4 files changed, 57 insertions(+), 17 deletions(-) diff --git a/test/test_databases/test_database.py b/test/test_databases/test_database.py index 
869fce594e..93ac82fd49 100644 --- a/test/test_databases/test_database.py +++ b/test/test_databases/test_database.py @@ -8,6 +8,12 @@ @pytest.fixture def dcs(): + db = DatabaseCreateStatement() + return db + + +@pytest.fixture +def dcs_bool(): db = DatabaseCreateStatement() db.DO_PARSE_BOOLS = True return db @@ -80,8 +86,8 @@ def test_is_valid_sql_num(dcs, val, is_valid): @pytest.mark.parametrize( ("val", "cmp_type", "detected_type"), - ((2, None, SMALLINT), - (2, "", SMALLINT), + ((1, None, SMALLINT), + (1, "", SMALLINT), (1, MEDIUMINT, MEDIUMINT), (32769, None, MEDIUMINT), (32769, BIGINT, BIGINT), @@ -89,14 +95,6 @@ def test_is_valid_sql_num(dcs, val, is_valid): (2147483648, FLOAT, FLOAT), (5.001, None, FLOAT), (5.001, "", FLOAT), - (2, BOOL, SMALLINT), - (True, None, BOOL), - (0, None, BOOL), - (1, None, BOOL), - (1, BOOL, BOOL), - ("F", None, BOOL), - ("FALSE", None, BOOL), - ("Yes", None, BOOL), ("FALSE", VARCHAR, VARCHAR), ("word", "", VARCHAR), ("word", INT, VARCHAR), @@ -112,6 +110,24 @@ def test_detect_data_type(dcs, val, cmp_type, detected_type): assert dcs.detect_data_type(val, cmp_type) == detected_type +@pytest.mark.parametrize( + ("val", "cmp_type", "detected_type"), + ((2, None, SMALLINT), + (2, "", SMALLINT), + (1, MEDIUMINT, MEDIUMINT), + (2, BOOL, SMALLINT), + (True, None, BOOL), + (0, None, BOOL), + (1, None, BOOL), + (1, BOOL, BOOL), + ("F", None, BOOL), + ("FALSE", None, BOOL), + ("Yes", None, BOOL) + )) +def test_detect_data_type_bools(dcs_bool, val, cmp_type, detected_type): + assert dcs_bool.detect_data_type(val, cmp_type) == detected_type + + @pytest.mark.parametrize( ("col", "renamed"), (("a", "a"), diff --git a/test/test_databases/test_mysql.py b/test/test_databases/test_mysql.py index 700ec152e0..361a3c229e 100644 --- a/test/test_databases/test_mysql.py +++ b/test/test_databases/test_mysql.py @@ -137,7 +137,6 @@ class TestMySQL(unittest.TestCase): # noqa def setUp(self): self.mysql = MySQL(username='test', password='test', host='test', 
db='test', port=123) - self.mysql.DO_PARSE_BOOLS = True self.tbl = Table([['ID', 'Name', 'Score'], [1, 'Jim', 1.9], @@ -147,9 +146,13 @@ def setUp(self): def test_data_type(self): # Test bool + self.mysql.DO_PARSE_BOOLS = True self.assertEqual(self.mysql.data_type(1, ''), 'bool') self.assertEqual(self.mysql.data_type(False, ''), 'bool') + + self.mysql.DO_PARSE_BOOLS = False # Test smallint + self.assertEqual(self.mysql.data_type(1, ''), 'smallint') self.assertEqual(self.mysql.data_type(2, ''), 'smallint') # Test int self.assertEqual(self.mysql.data_type(32769, ''), 'mediumint') @@ -157,6 +160,8 @@ def test_data_type(self): self.assertEqual(self.mysql.data_type(2147483648, ''), 'bigint') # Test varchar that looks like an int self.assertEqual(self.mysql.data_type('00001', ''), 'varchar') + # Test varchar that looks like a bool + self.assertEqual(self.mysql.data_type(False, ''), 'varchar') # Test a float as a decimal self.assertEqual(self.mysql.data_type(5.001, ''), 'float') # Test varchar diff --git a/test/test_databases/test_postgres.py b/test/test_databases/test_postgres.py index b8b9871cb1..44cf60d0fb 100644 --- a/test/test_databases/test_postgres.py +++ b/test/test_databases/test_postgres.py @@ -15,7 +15,6 @@ class TestPostgresCreateStatement(unittest.TestCase): def setUp(self): self.pg = Postgres(username='test', password='test', host='test', db='test', port=123) - self.pg.DO_PARSE_BOOLS = True self.tbl = Table([['ID', 'Name'], [1, 'Jim'], @@ -35,6 +34,8 @@ def setUp(self): self.mapping = self.pg.generate_data_types(self.tbl) self.mapping2 = self.pg.generate_data_types(self.tbl2) + self.pg.DO_PARSE_BOOLS = True + self.mapping3 = self.pg.generate_data_types(self.tbl2) def test_connection(self): @@ -56,11 +57,9 @@ def test_connection(self): self.assertEqual(pg_env.port, 5432) def test_data_type(self): - - # Test bool - self.assertEqual(self.pg.data_type(1, ''), 'bool') - self.assertEqual(self.pg.data_type(True, ''), 'bool') + self.pg.DO_PARSE_BOOLS = False # Test 
smallint + self.assertEqual(self.pg.data_type(1, ''), 'smallint') self.assertEqual(self.pg.data_type(2, ''), 'smallint') # Test int self.assertEqual(self.pg.data_type(32769, ''), 'int') @@ -68,6 +67,8 @@ def test_data_type(self): self.assertEqual(self.pg.data_type(2147483648, ''), 'bigint') # Test varchar that looks like an int self.assertEqual(self.pg.data_type('00001', ''), 'varchar') + # Test varchar that looks like a bool + self.assertEqual(self.pg.data_type(True, ''), 'varchar') # Test a float as a decimal self.assertEqual(self.pg.data_type(5.001, ''), 'decimal') # Test varchar @@ -77,6 +78,11 @@ def test_data_type(self): # Test int with leading zero self.assertEqual(self.pg.data_type('01', ''), 'varchar') + # Test bool + self.pg.DO_PARSE_BOOLS = True + self.assertEqual(self.pg.data_type(1, ''), 'bool') + self.assertEqual(self.pg.data_type(True, ''), 'bool') + def test_generate_data_types(self): # Test correct header labels @@ -85,6 +91,9 @@ def test_generate_data_types(self): self.assertEqual(self.mapping['type_list'], ['smallint', 'varchar']) self.assertEqual( self.mapping2['type_list'], + ['varchar', 'varchar', 'decimal', 'varchar', "decimal", "smallint", "varchar"]) + self.assertEqual( + self.mapping3['type_list'], ['varchar', 'varchar', 'decimal', 'varchar', "decimal", "bool", "varchar"]) # Test correct lengths self.assertEqual(self.mapping['longest'], [1, 5]) diff --git a/test/test_redshift.py b/test/test_redshift.py index f2c2f70e98..8eddc58ebf 100644 --- a/test/test_redshift.py +++ b/test/test_redshift.py @@ -19,7 +19,6 @@ class TestRedshift(unittest.TestCase): def setUp(self): self.rs = Redshift(username='test', password='test', host='test', db='test', port=123) - self.rs.DO_PARSE_BOOLS = True self.tbl = Table([['ID', 'Name'], [1, 'Jim'], @@ -38,7 +37,10 @@ def setUp(self): ]) self.mapping = self.rs.generate_data_types(self.tbl) + self.rs.DO_PARSE_BOOLS = True self.mapping2 = self.rs.generate_data_types(self.tbl2) + self.rs.DO_PARSE_BOOLS = False + 
self.mapping3 = self.rs.generate_data_types(self.tbl2) def test_split_full_table_name(self): schema, table = Redshift.split_full_table_name('some_schema.some_table') @@ -60,8 +62,12 @@ def test_combine_schema_and_table_name(self): def test_data_type(self): # Test bool + self.rs.DO_PARSE_BOOLS = True self.assertEqual(self.rs.data_type(1, ''), 'bool') self.assertEqual(self.rs.data_type(True, ''), 'bool') + self.rs.DO_PARSE_BOOLS = False + self.assertEqual(self.rs.data_type(1, ''), 'int') + self.assertEqual(self.rs.data_type(True, ''), 'varchar') # Test smallint # Currently smallints are coded as ints self.assertEqual(self.rs.data_type(2, ''), 'int') @@ -90,6 +96,10 @@ def test_generate_data_types(self): self.assertEqual( self.mapping2['type_list'], ['varchar', 'varchar', 'float', 'varchar', 'float', 'bool', 'varchar']) + + self.assertEqual( + self.mapping3['type_list'], + ['varchar', 'varchar', 'float', 'varchar', 'float', 'int', 'varchar']) # Test correct lengths self.assertEqual(self.mapping['longest'], [1, 5])