Skip to content

Commit

Permalink
add parameters in clickhouse and mysql engine to avoid sql injection (#…
Browse files Browse the repository at this point in the history
…2062)

* add parameters in clickhouse and mysql engine to avoid sql injection

* black
  • Loading branch information
LeoQuote authored Mar 13, 2023
1 parent 96094b4 commit 69c78c0
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 37 deletions.
10 changes: 9 additions & 1 deletion sql/engines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,15 @@ def filter_sql(self, sql="", limit_num=0):
"""给查询语句增加结果级限制或者改写语句, 返回修改后的语句"""
return sql.strip()

def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs):
def query(
self,
db_name=None,
sql="",
parameters=None,
limit_num=0,
close_conn=True,
**kwargs
):
"""实际查询 返回一个ResultSet"""
return ResultSet()

Expand Down
29 changes: 19 additions & 10 deletions sql/engines/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,12 @@ def server_version(self):

def get_table_engine(self, tb_name):
"""获取某个table的engine type"""
[database, name] = tb_name.split(".")
sql = f"""select engine
from system.tables
where database='{tb_name.split('.')[0]}'
and name='{tb_name.split('.')[1]}'"""
query_result = self.query(sql=sql)
where database=%s
and name=%s"""
query_result = self.query(sql=sql, parameters=(database, name))
if query_result.rows:
result = {"status": 1, "engine": query_result.rows[0][0]}
else:
Expand Down Expand Up @@ -104,30 +105,38 @@ def get_all_columns_by_tb(self, db_name, tb_name, **kwargs):
from
system.columns
where
database = '{db_name}'
and table = '{tb_name}';"""
result = self.query(db_name=db_name, sql=sql)
database = %s
and table = %s;"""
result = self.query(db_name=db_name, sql=sql, parameters=(db_name, tb_name))
column_list = [row[0] for row in result.rows]
result.rows = column_list
return result

def describe_table(self, db_name, tb_name, **kwargs):
"""return ResultSet 类似查询"""
sql = f"show create table `{tb_name}`;"
result = self.query(db_name=db_name, sql=sql)
sql = f"show create table %s;"
result = self.query(db_name=db_name, sql=sql, parameters=(tb_name,))

result.rows[0] = (tb_name,) + (
result.rows[0][0].replace("(", "(\n ").replace(",", ",\n "),
)
return result

def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs):
def query(
self,
db_name=None,
sql="",
parameters=None,
limit_num=0,
close_conn=True,
**kwargs,
):
"""返回 ResultSet"""
result_set = ResultSet(full_sql=sql)
try:
conn = self.get_connection(db_name=db_name)
cursor = conn.cursor()
cursor.execute(sql)
cursor.execute(sql, parameters=parameters)
if int(limit_num) > 0:
rows = cursor.fetchmany(size=int(limit_num))
else:
Expand Down
64 changes: 38 additions & 26 deletions sql/engines/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,8 @@ def get_group_tables_by_db(self, db_name):
FROM
information_schema.TABLES
WHERE
TABLE_SCHEMA='{db_name}';"""
result = self.query(db_name=db_name, sql=sql)
TABLE_SCHEMA=%s;"""
result = self.query(db_name=db_name, sql=sql, parameters=(db_name,))
for row in result.rows:
table_name, table_cmt = row[0], row[1]
if table_name[0] not in data:
Expand Down Expand Up @@ -208,9 +208,9 @@ def get_table_meta_data(self, db_name, tb_name, **kwargs):
FROM
information_schema.TABLES
WHERE
TABLE_SCHEMA='{db_name}'
AND TABLE_NAME='{tb_name}'"""
_meta_data = self.query(db_name, sql)
TABLE_SCHEMA=%s
AND TABLE_NAME=%s"""
_meta_data = self.query(db_name, sql, parameters=(db_name, tb_name))
return {"column_list": _meta_data.column_list, "rows": _meta_data.rows[0]}

def get_table_desc_data(self, db_name, tb_name, **kwargs):
Expand All @@ -227,10 +227,10 @@ def get_table_desc_data(self, db_name, tb_name, **kwargs):
FROM
information_schema.COLUMNS
WHERE
TABLE_SCHEMA = '{db_name}'
AND TABLE_NAME = '{tb_name}'
TABLE_SCHEMA = %s
AND TABLE_NAME = %s
ORDER BY ORDINAL_POSITION;"""
_desc_data = self.query(db_name, sql)
_desc_data = self.query(db_name, sql, parameters=(db_name, tb_name))
return {"column_list": _desc_data.column_list, "rows": _desc_data.rows}

def get_table_index_data(self, db_name, tb_name, **kwargs):
Expand All @@ -247,18 +247,19 @@ def get_table_index_data(self, db_name, tb_name, **kwargs):
FROM
information_schema.STATISTICS
WHERE
TABLE_SCHEMA = '{db_name}'
AND TABLE_NAME = '{tb_name}';"""
_index_data = self.query(db_name, sql)
TABLE_SCHEMA = %s
AND TABLE_NAME = %s;"""
_index_data = self.query(db_name, sql, parameters=(db_name, tb_name))
return {"column_list": _index_data.column_list, "rows": _index_data.rows}

def get_tables_metas_data(self, db_name, **kwargs):
"""获取数据库所有表格信息,用作数据字典导出接口"""
sql_tbs = (
f"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA='{db_name}';"
)
sql_tbs = f"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA=%s;"
tbs = self.query(
sql=sql_tbs, cursorclass=MySQLdb.cursors.DictCursor, close_conn=False
sql=sql_tbs,
parameters=(db_name,),
cursorclass=MySQLdb.cursors.DictCursor,
close_conn=False,
).rows
table_metas = []
for tb in tbs:
Expand All @@ -275,9 +276,12 @@ def get_tables_metas_data(self, db_name, **kwargs):
_meta["ENGINE_KEYS"] = engine_keys
_meta["TABLE_INFO"] = tb
sql_cols = f"""SELECT * FROM INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA='{tb['TABLE_SCHEMA']}' AND TABLE_NAME='{tb['TABLE_NAME']}';"""
WHERE TABLE_SCHEMA=%s AND TABLE_NAME=%s;"""
_meta["COLUMNS"] = self.query(
sql=sql_cols, cursorclass=MySQLdb.cursors.DictCursor, close_conn=False
sql=sql_cols,
parameters=(tb["TABLE_SCHEMA"], tb["TABLE_NAME"]),
cursorclass=MySQLdb.cursors.DictCursor,
close_conn=False,
).rows
table_metas.append(_meta)
return table_metas
Expand All @@ -295,18 +299,18 @@ def get_all_columns_by_tb(self, db_name, tb_name, **kwargs):
FROM
information_schema.COLUMNS
WHERE
TABLE_SCHEMA = '{db_name}'
AND TABLE_NAME = '{tb_name}'
TABLE_SCHEMA = %s
AND TABLE_NAME = %s
ORDER BY ORDINAL_POSITION;"""
result = self.query(db_name=db_name, sql=sql)
result = self.query(db_name=db_name, sql=sql, parameters=(db_name, tb_name))
column_list = [row[0] for row in result.rows]
result.rows = column_list
return result

def describe_table(self, db_name, tb_name, **kwargs):
"""return ResultSet 类似查询"""
sql = f"show create table `{tb_name}`;"
result = self.query(db_name=db_name, sql=sql)
sql = f"show create table %s;"
result = self.query(db_name=db_name, sql=sql, parameters=(tb_name,))
return result

@staticmethod
Expand All @@ -325,7 +329,15 @@ def result_set_binary_as_hex(result_set):
result_set.rows = tuple(new_rows)
return result_set

def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs):
def query(
self,
db_name=None,
sql="",
parameters=None,
limit_num=0,
close_conn=True,
**kwargs,
):
"""返回 ResultSet"""
result_set = ResultSet(full_sql=sql)
max_execution_time = kwargs.get("max_execution_time", 0)
Expand All @@ -338,7 +350,7 @@ def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs):
cursor.execute(f"set session max_execution_time={max_execution_time};")
except MySQLdb.OperationalError:
pass
effect_row = cursor.execute(sql)
effect_row = cursor.execute(sql, args=parameters)
if int(limit_num) > 0:
rows = cursor.fetchmany(size=int(limit_num))
else:
Expand Down Expand Up @@ -518,14 +530,14 @@ def execute_workflow(self, workflow):
# inception执行
return self.inc_engine.execute(workflow)

def execute(self, db_name=None, sql="", close_conn=True):
def execute(self, db_name=None, sql="", parameters=None, close_conn=True):
"""原生执行语句"""
result = ResultSet(full_sql=sql)
conn = self.get_connection(db_name=db_name)
try:
cursor = conn.cursor()
for statement in sqlparse.split(sql):
cursor.execute(statement)
cursor.execute(statement, args=parameters)
conn.commit()
cursor.close()
except Exception as e:
Expand Down

0 comments on commit 69c78c0

Please sign in to comment.