diff --git a/pandas/tests/io/test_sql.py b/pandas/tests/io/test_sql.py index 9d0bce3b342b4..806bd7f2b7c93 100644 --- a/pandas/tests/io/test_sql.py +++ b/pandas/tests/io/test_sql.py @@ -704,45 +704,29 @@ def test_complex(self): # Complex data type should raise error pytest.raises(ValueError, df.to_sql, 'test_complex', self.conn) - def test_to_sql_index_label(self): - temp_frame = DataFrame({'col1': range(4)}) - + @pytest.mark.parametrize("index_name,index_label,expected", [ # no index name, defaults to 'index' - sql.to_sql(temp_frame, 'test_index_label', self.conn) - frame = sql.read_sql_query('SELECT * FROM test_index_label', self.conn) - assert frame.columns[0] == 'index' - + (None, None, "index"), # specifying index_label - sql.to_sql(temp_frame, 'test_index_label', self.conn, - if_exists='replace', index_label='other_label') - frame = sql.read_sql_query('SELECT * FROM test_index_label', self.conn) - assert frame.columns[0] == "other_label" - + (None, "other_label", "other_label"), # using the index name - temp_frame.index.name = 'index_name' - sql.to_sql(temp_frame, 'test_index_label', self.conn, - if_exists='replace') - frame = sql.read_sql_query('SELECT * FROM test_index_label', self.conn) - assert frame.columns[0] == "index_name" - + ("index_name", None, "index_name"), # has index name, but specifying index_label - sql.to_sql(temp_frame, 'test_index_label', self.conn, - if_exists='replace', index_label='other_label') - frame = sql.read_sql_query('SELECT * FROM test_index_label', self.conn) - assert frame.columns[0] == "other_label" - + ("index_name", "other_label", "other_label"), # index name is integer - temp_frame.index.name = 0 - sql.to_sql(temp_frame, 'test_index_label', self.conn, - if_exists='replace') - frame = sql.read_sql_query('SELECT * FROM test_index_label', self.conn) - assert frame.columns[0] == "0" - - temp_frame.index.name = None + (0, None, "0"), + # index name is None but index label is integer + (None, 0, "0"), + ]) + def test_to_sql_index_label(self, index_name, + index_label, expected): + temp_frame = DataFrame({'col1': range(4)}) + temp_frame.index.name = index_name + query = 'SELECT * FROM test_index_label' sql.to_sql(temp_frame, 'test_index_label', self.conn, - if_exists='replace', index_label=0) - frame = sql.read_sql_query('SELECT * FROM test_index_label', self.conn) - assert frame.columns[0] == "0" + index_label=index_label) + frame = sql.read_sql_query(query, self.conn) + assert frame.columns[0] == expected def test_to_sql_index_label_multiindex(self): temp_frame = DataFrame({'col1': range(4)},