From 69c78c0801061af539b3d7cc4e208e3898a02d1a Mon Sep 17 00:00:00 2001 From: Leo Q Date: Tue, 14 Mar 2023 07:26:06 +0800 Subject: [PATCH] add parameters in clickhouse and mysql engine to avoid sql injection (#2062) * add parameters in clickhouse and mysql engine to avoid sql injection * black --- sql/engines/__init__.py | 10 +++++- sql/engines/clickhouse.py | 29 ++++++++++++------ sql/engines/mysql.py | 64 +++++++++++++++++++++++---------------- 3 files changed, 66 insertions(+), 37 deletions(-) diff --git a/sql/engines/__init__.py b/sql/engines/__init__.py index a101abed13..50e04d9318 100644 --- a/sql/engines/__init__.py +++ b/sql/engines/__init__.py @@ -147,7 +147,15 @@ def filter_sql(self, sql="", limit_num=0): """给查询语句增加结果级限制或者改写语句, 返回修改后的语句""" return sql.strip() - def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs): + def query( + self, + db_name=None, + sql="", + parameters=None, + limit_num=0, + close_conn=True, + **kwargs + ): """实际查询 返回一个ResultSet""" return ResultSet() diff --git a/sql/engines/clickhouse.py b/sql/engines/clickhouse.py index 22216c4be5..481e7c69c3 100644 --- a/sql/engines/clickhouse.py +++ b/sql/engines/clickhouse.py @@ -63,11 +63,12 @@ def server_version(self): def get_table_engine(self, tb_name): """获取某个table的engine type""" + [database, name] = tb_name.split(".") sql = f"""select engine from system.tables - where database='{tb_name.split('.')[0]}' - and name='{tb_name.split('.')[1]}'""" - query_result = self.query(sql=sql) + where database=%s + and name=%s""" + query_result = self.query(sql=sql, parameters=(database, name)) if query_result.rows: result = {"status": 1, "engine": query_result.rows[0][0]} else: @@ -104,30 +105,38 @@ def get_all_columns_by_tb(self, db_name, tb_name, **kwargs): from system.columns where - database = '{db_name}' - and table = '{tb_name}';""" - result = self.query(db_name=db_name, sql=sql) + database = %s + and table = %s;""" + result = self.query(db_name=db_name, sql=sql, parameters=(db_name, tb_name)) column_list = [row[0] for row in result.rows] result.rows = column_list return result def describe_table(self, db_name, tb_name, **kwargs): """return ResultSet 类似查询""" - sql = f"show create table `{tb_name}`;" - result = self.query(db_name=db_name, sql=sql) + sql = f"show create table %s;" + result = self.query(db_name=db_name, sql=sql, parameters=(tb_name,)) result.rows[0] = (tb_name,) + ( result.rows[0][0].replace("(", "(\n ").replace(",", ",\n "), ) return result - def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs): + def query( + self, + db_name=None, + sql="", + parameters=None, + limit_num=0, + close_conn=True, + **kwargs, + ): """返回 ResultSet""" result_set = ResultSet(full_sql=sql) try: conn = self.get_connection(db_name=db_name) cursor = conn.cursor() - cursor.execute(sql) + cursor.execute(sql, parameters=parameters) if int(limit_num) > 0: rows = cursor.fetchmany(size=int(limit_num)) else: diff --git a/sql/engines/mysql.py b/sql/engines/mysql.py index ee47fd4240..da4245ac20 100644 --- a/sql/engines/mysql.py +++ b/sql/engines/mysql.py @@ -174,8 +174,8 @@ def get_group_tables_by_db(self, db_name): FROM information_schema.TABLES WHERE - TABLE_SCHEMA='{db_name}';""" - result = self.query(db_name=db_name, sql=sql) + TABLE_SCHEMA=%s;""" + result = self.query(db_name=db_name, sql=sql, parameters=(db_name,)) for row in result.rows: table_name, table_cmt = row[0], row[1] if table_name[0] not in data: @@ -208,9 +208,9 @@ def get_table_meta_data(self, db_name, tb_name, **kwargs): FROM information_schema.TABLES WHERE - TABLE_SCHEMA='{db_name}' - AND TABLE_NAME='{tb_name}'""" - _meta_data = self.query(db_name, sql) + TABLE_SCHEMA=%s + AND TABLE_NAME=%s""" + _meta_data = self.query(db_name, sql, parameters=(db_name, tb_name)) return {"column_list": _meta_data.column_list, "rows": _meta_data.rows[0]} def get_table_desc_data(self, db_name, tb_name, **kwargs): @@ -227,10 +227,10 @@ def get_table_desc_data(self, db_name, tb_name, **kwargs): FROM information_schema.COLUMNS WHERE - TABLE_SCHEMA = '{db_name}' - AND TABLE_NAME = '{tb_name}' + TABLE_SCHEMA = %s + AND TABLE_NAME = %s ORDER BY ORDINAL_POSITION;""" - _desc_data = self.query(db_name, sql) + _desc_data = self.query(db_name, sql, parameters=(db_name, tb_name)) return {"column_list": _desc_data.column_list, "rows": _desc_data.rows} def get_table_index_data(self, db_name, tb_name, **kwargs): @@ -247,18 +247,19 @@ def get_table_index_data(self, db_name, tb_name, **kwargs): FROM information_schema.STATISTICS WHERE - TABLE_SCHEMA = '{db_name}' - AND TABLE_NAME = '{tb_name}';""" - _index_data = self.query(db_name, sql) + TABLE_SCHEMA = %s + AND TABLE_NAME = %s;""" + _index_data = self.query(db_name, sql, parameters=(db_name, tb_name)) return {"column_list": _index_data.column_list, "rows": _index_data.rows} def get_tables_metas_data(self, db_name, **kwargs): """获取数据库所有表格信息,用作数据字典导出接口""" - sql_tbs = ( - f"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA='{db_name}';" - ) + sql_tbs = f"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA=%s;" tbs = self.query( - sql=sql_tbs, cursorclass=MySQLdb.cursors.DictCursor, close_conn=False + sql=sql_tbs, + parameters=(db_name,), + cursorclass=MySQLdb.cursors.DictCursor, + close_conn=False, ).rows table_metas = [] for tb in tbs: @@ -275,9 +276,12 @@ def get_tables_metas_data(self, db_name, **kwargs): _meta["ENGINE_KEYS"] = engine_keys _meta["TABLE_INFO"] = tb sql_cols = f"""SELECT * FROM INFORMATION_SCHEMA.COLUMNS - WHERE TABLE_SCHEMA='{tb['TABLE_SCHEMA']}' AND TABLE_NAME='{tb['TABLE_NAME']}';""" + WHERE TABLE_SCHEMA=%s AND TABLE_NAME=%s;""" _meta["COLUMNS"] = self.query( - sql=sql_cols, cursorclass=MySQLdb.cursors.DictCursor, close_conn=False + sql=sql_cols, + parameters=(tb["TABLE_SCHEMA"], tb["TABLE_NAME"]), + cursorclass=MySQLdb.cursors.DictCursor, + close_conn=False, ).rows table_metas.append(_meta) return table_metas @@ -295,18 +299,18 @@ def get_all_columns_by_tb(self, db_name, tb_name, **kwargs): FROM information_schema.COLUMNS WHERE - TABLE_SCHEMA = '{db_name}' - AND TABLE_NAME = '{tb_name}' + TABLE_SCHEMA = %s + AND TABLE_NAME = %s ORDER BY ORDINAL_POSITION;""" - result = self.query(db_name=db_name, sql=sql) + result = self.query(db_name=db_name, sql=sql, parameters=(db_name, tb_name)) column_list = [row[0] for row in result.rows] result.rows = column_list return result def describe_table(self, db_name, tb_name, **kwargs): """return ResultSet 类似查询""" - sql = f"show create table `{tb_name}`;" - result = self.query(db_name=db_name, sql=sql) + sql = f"show create table %s;" + result = self.query(db_name=db_name, sql=sql, parameters=(tb_name,)) return result @staticmethod @@ -325,7 +329,15 @@ def result_set_binary_as_hex(result_set): result_set.rows = tuple(new_rows) return result_set - def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs): + def query( + self, + db_name=None, + sql="", + parameters=None, + limit_num=0, + close_conn=True, + **kwargs, + ): """返回 ResultSet""" result_set = ResultSet(full_sql=sql) max_execution_time = kwargs.get("max_execution_time", 0) @@ -338,7 +350,7 @@ def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs): cursor.execute(f"set session max_execution_time={max_execution_time};") except MySQLdb.OperationalError: pass - effect_row = cursor.execute(sql) + effect_row = cursor.execute(sql, args=parameters) if int(limit_num) > 0: rows = cursor.fetchmany(size=int(limit_num)) else: @@ -518,14 +530,14 @@ def execute_workflow(self, workflow): # inception执行 return self.inc_engine.execute(workflow) - def execute(self, db_name=None, sql="", close_conn=True): + def execute(self, db_name=None, sql="", parameters=None, close_conn=True): """原生执行语句""" result = ResultSet(full_sql=sql) conn = self.get_connection(db_name=db_name) try: cursor = conn.cursor() for statement in sqlparse.split(sql): - cursor.execute(statement) + cursor.execute(statement, args=parameters) conn.commit() cursor.close() except Exception as e: