Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

engine增加escape_string用于处理字符串参数转义 #2107

Merged
merged 1 commit into from
Apr 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions sql/data_dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def table_list(request):
instance_name=instance_name, db_type=db_type
)
query_engine = get_engine(instance=instance)
db_name = query_engine.escape_string(db_name)
data = query_engine.get_group_tables_by_db(db_name=db_name)
res = {"status": 0, "data": data}
except Instance.DoesNotExist:
Expand All @@ -50,13 +51,16 @@ def table_info(request):
db_name = request.GET.get("db_name", "")
tb_name = request.GET.get("tb_name", "")
db_type = request.GET.get("db_type", "")

if instance_name and db_name and tb_name:
data = {}
try:
instance = Instance.objects.get(
instance_name=instance_name, db_type=db_type
)
query_engine = get_engine(instance=instance)
db_name = query_engine.escape_string(db_name)
tb_name = query_engine.escape_string(tb_name)
data["meta_data"] = query_engine.get_table_meta_data(
db_name=db_name, tb_name=tb_name
)
Expand Down Expand Up @@ -91,8 +95,6 @@ def export(request):
"""导出数据字典"""
instance_name = request.GET.get("instance_name", "")
db_name = request.GET.get("db_name", "")
# escape
db_name = MySQLdb.escape_string(db_name).decode("utf-8")

try:
instance = user_instances(
Expand All @@ -104,7 +106,7 @@ def export(request):

# 普通用户仅可以获取指定数据库的字典信息
if db_name:
dbs = [db_name]
dbs = [query_engine.escape_string(db_name)]
# 管理员可以导出整个实例的字典信息
elif request.user.is_superuser:
dbs = query_engine.get_all_databases().rows
Expand Down
4 changes: 4 additions & 0 deletions sql/engines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ def info(self):
"""返回引擎简介"""
return "Base engine"

def escape_string(self, value: str) -> str:
"""参数转义"""
return value

@property
def auto_backup(self):
"""是否支持备份"""
Expand Down
5 changes: 5 additions & 0 deletions sql/engines/clickhouse.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# -*- coding: UTF-8 -*-
from clickhouse_driver import connect
from clickhouse_driver.util.escape import escape_chars_map
from sql.utils.sql_utils import get_syntax_type
from .models import ResultSet, ReviewResult, ReviewSet
from common.utils.timer import FuncTimer
Expand Down Expand Up @@ -49,6 +50,10 @@ def name(self):
def info(self):
return "ClickHouse engine"

def escape_string(self, value: str) -> str:
"""字符串参数转义"""
return "'%s'" % "".join(escape_chars_map.get(c, c) for c in value)

@property
def auto_backup(self):
"""是否支持备份"""
Expand Down
10 changes: 7 additions & 3 deletions sql/engines/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ def name(self):
def info(self):
return "MySQL engine"

def escape_string(self, value: str) -> str:
"""字符串参数转义"""
return MySQLdb.escape_string(value).decode("utf-8")

@property
def auto_backup(self):
"""是否支持备份"""
Expand Down Expand Up @@ -167,7 +171,7 @@ def get_all_tables(self, db_name, **kwargs):

def get_group_tables_by_db(self, db_name):
# escape
db_name = MySQLdb.escape_string(db_name).decode("utf-8")
db_name = self.escape_string(db_name)
data = {}
sql = f"""SELECT TABLE_NAME,
TABLE_COMMENT
Expand All @@ -186,8 +190,8 @@ def get_group_tables_by_db(self, db_name):
def get_table_meta_data(self, db_name, tb_name, **kwargs):
"""数据字典页面使用:获取表格的元信息,返回一个dict{column_list: [], rows: []}"""
# escape
db_name = MySQLdb.escape_string(db_name).decode("utf-8")
tb_name = MySQLdb.escape_string(tb_name).decode("utf-8")
db_name = self.escape_string(db_name)
tb_name = self.escape_string(tb_name)
sql = f"""SELECT
TABLE_NAME as table_name,
ENGINE as engine,
Expand Down
15 changes: 10 additions & 5 deletions sql/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,9 @@ def param_edit(request):
instance_id = request.POST.get("instance_id")
variable_name = request.POST.get("variable_name")
variable_value = request.POST.get("runtime_value")
# escape
variable_name = MySQLdb.escape_string(variable_name).decode("utf-8")
variable_value = MySQLdb.escape_string(variable_value).decode("utf-8")

try:
ins = Instance.objects.get(id=instance_id)
Expand Down Expand Up @@ -320,12 +323,10 @@ def instance_resource(request):
result = {"status": 0, "msg": "ok", "data": []}

try:
# escape
db_name = MySQLdb.escape_string(db_name).decode("utf-8")
schema_name = MySQLdb.escape_string(schema_name).decode("utf-8")
tb_name = MySQLdb.escape_string(tb_name).decode("utf-8")

query_engine = get_engine(instance=instance)
db_name = query_engine.escape_string(db_name)
schema_name = query_engine.escape_string(schema_name)
tb_name = query_engine.escape_string(tb_name)
if resource_type == "database":
resource = query_engine.get_all_databases()
elif resource_type == "schema" and db_name:
Expand Down Expand Up @@ -363,10 +364,14 @@ def describe(request):
db_name = request.POST.get("db_name")
schema_name = request.POST.get("schema_name")
tb_name = request.POST.get("tb_name")

result = {"status": 0, "msg": "ok", "data": []}

try:
query_engine = get_engine(instance=instance)
db_name = query_engine.escape_string(db_name)
schema_name = query_engine.escape_string(schema_name)
tb_name = query_engine.escape_string(tb_name)
query_result = query_engine.describe_table(
db_name, tb_name, schema_name=schema_name
)
Expand Down
4 changes: 1 addition & 3 deletions sql/instance_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,8 @@ def create(request):
except Users.DoesNotExist:
return JsonResponse({"status": 1, "msg": "负责人不存在", "data": []})

# escape
db_name = MySQLdb.escape_string(db_name).decode("utf-8")

engine = get_engine(instance=instance)
db_name = engine.escape_string(db_name)
exec_result = engine.execute(
db_name="information_schema", sql=f"create database {db_name};"
)
Expand Down
4 changes: 2 additions & 2 deletions sql/sql_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,6 @@ def optimize_sqltuning(request):
except Instance.DoesNotExist:
result = {"status": 1, "msg": "你所在组未关联该实例!", "data": []}
return HttpResponse(json.dumps(result), content_type="application/json")
# escape
db_name = MySQLdb.escape_string(db_name).decode("utf-8")

sql_tunning = SqlTuning(
instance_name=instance_name, db_name=db_name, sqltext=sqltext
Expand Down Expand Up @@ -235,6 +233,7 @@ def explain(request):

# 执行获取执行计划语句
query_engine = get_engine(instance=instance)
db_name = query_engine.escape_string(db_name)
sql_result = query_engine.query(str(db_name), sql_content).to_sep_dict()
result["data"] = sql_result

Expand Down Expand Up @@ -287,6 +286,7 @@ def optimize_sqltuningadvisor(request):

# 执行获取优化报告
query_engine = get_engine(instance=instance)
db_name = query_engine.escape_string(db_name)
sql_result = query_engine.sqltuningadvisor(str(db_name), sql_content).to_sep_dict()
result["data"] = sql_result

Expand Down
2 changes: 1 addition & 1 deletion sql/sql_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def __init__(self, instance_name, db_name, sqltext):
instance = Instance.objects.get(instance_name=instance_name)
query_engine = get_engine(instance=instance)
self.engine = query_engine
self.db_name = db_name
self.db_name = self.engine.escape_string(db_name)
self.sqltext = sqltext
self.sql_variable = """
select
Expand Down
2 changes: 1 addition & 1 deletion sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2539,7 +2539,7 @@ def test_param_edit_variable_not_config(
data = {
"instance_id": self.master.id,
"variable_name": "1",
"variable_value": "false",
"runtime_value": "false",
}
r = self.client.post(path="/param/edit/", data=data)
self.assertEqual(
Expand Down
8 changes: 3 additions & 5 deletions sql_api/api_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,12 +187,10 @@ def post(self, request):
instance = Instance.objects.get(pk=instance_id)

try:
# escape
db_name = MySQLdb.escape_string(db_name).decode("utf-8")
schema_name = MySQLdb.escape_string(schema_name).decode("utf-8")
tb_name = MySQLdb.escape_string(tb_name).decode("utf-8")

query_engine = get_engine(instance=instance)
db_name = query_engine.escape_string(db_name)
schema_name = query_engine.escape_string(schema_name)
tb_name = query_engine.escape_string(tb_name)
if resource_type == "database":
resource = query_engine.get_all_databases()
elif resource_type == "schema" and db_name:
Expand Down
5 changes: 4 additions & 1 deletion sql_api/api_workflow.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import MySQLdb
from django.contrib.auth.decorators import permission_required
from django.utils.decorators import method_decorator
from rest_framework import views, generics, status, serializers, permissions
Expand Down Expand Up @@ -60,9 +61,11 @@ def post(self, request):
instance = serializer.get_instance()
# 交给engine进行检测
try:
db_name = request.data["db_name"]
check_engine = get_engine(instance=instance)
db_name = check_engine.escape_string(db_name)
check_result = check_engine.execute_check(
db_name=request.data["db_name"], sql=request.data["full_sql"].strip()
db_name=db_name, sql=request.data["full_sql"].strip()
)
except Exception as e:
raise serializers.ValidationError({"errors": f"{e}"})
Expand Down