Skip to content

Commit

Permalink
初始化 notify 时自动根据 workflow audit 取 workflow (#2363)
Browse files Browse the repository at this point in the history
* 初始化 notify 时自动根据 workflow audit 取 workflow

* 新增初始化时不传 workflow 的初始化
  • Loading branch information
LeoQuote authored Nov 6, 2023
1 parent ba985e9 commit f959849
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 41 deletions.
10 changes: 10 additions & 0 deletions sql/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,16 @@ class WorkflowAudit(models.Model):
create_time = models.DateTimeField("申请时间", auto_now_add=True)
sys_time = models.DateTimeField("系统时间", auto_now=True)

def get_workflow(self):
"""尝试从 audit 中取出 workflow"""
if self.workflow_type == WorkflowType.QUERY:
return QueryPrivilegesApply.objects.get(apply_id=self.workflow_id)
elif self.workflow_type == WorkflowType.SQL_REVIEW:
return SqlWorkflow.objects.get(id=self.workflow_id)
elif self.workflow_type == WorkflowType.ARCHIVE:
return ArchiveConfig.objects.get(id=self.workflow_id)
raise ValueError("无法获取到关联工单")

def __int__(self):
return self.audit_id

Expand Down
54 changes: 22 additions & 32 deletions sql/notify.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,23 +52,25 @@ class My2SqlResult:
error: str = ""


@dataclass
class Notifier:
name = "base"
sys_config_key: str = ""

def __init__(
self,
workflow: Union[SqlWorkflow, ArchiveConfig, QueryPrivilegesApply, My2SqlResult],
sys_config: SysConfig,
audit: WorkflowAudit = None,
audit_detail: WorkflowAuditDetail = None,
event_type: EventType = EventType.AUDIT,
):
self.workflow = workflow
self.audit = audit
self.audit_detail = audit_detail
self.event_type = event_type
self.sys_config = sys_config
workflow: Union[SqlWorkflow, ArchiveConfig, QueryPrivilegesApply, My2SqlResult]
sys_config: SysConfig = None
# init false, class property, 不是 instance property
name: str = field(init=False, default="base")
sys_config_key: str = field(init=False, default="")
event_type: EventType = EventType.AUDIT
audit: WorkflowAudit = None
audit_detail: WorkflowAuditDetail = None

def __post_init__(self):
if not self.workflow:
if not self.audit:
raise ValueError("需要提供 WorkflowAudit 或 workflow")
self.workflow = self.audit.get_workflow()
# 防止 get_auditor 显式的传了个 None
if not self.sys_config:
self.sys_config = SysConfig()

def render(self):
raise NotImplementedError
Expand All @@ -91,12 +93,9 @@ def run(self):


class GenericWebhookNotifier(Notifier):
name = "generic_webhook"
name: str = "generic_webhook"
sys_config_key: str = "generic_webhook_url"

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.request_data = None
request_data: dict = None

def render(self):
self.request_data = {}
Expand Down Expand Up @@ -133,13 +132,9 @@ class LegacyMessage:
msg_cc: List[Users] = field(default_factory=list)


@dataclass
class LegacyRender(Notifier):
messages: List[LegacyMessage]
sys_config_key: str = ""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.messages = []
messages: List[LegacyMessage] = field(default_factory=list)

def render_audit(self):
# 获取审核信息
Expand Down Expand Up @@ -476,11 +471,6 @@ def auto_notify(
加载所有的 notifier, 调用 notifier 的 render 和 send 方法
内部方法, 有数据库查询, 为了方便测试, 请勿使用 async_task 调用, 防止 patch 后调用失败
"""
if not workflow and event_type == EventType.AUDIT:
if audit.workflow_type == 1:
workflow = QueryPrivilegesApply.objects.get(apply_id=audit.workflow_id)
if audit.workflow_type == 2:
workflow = SqlWorkflow.objects.get(id=audit.workflow_id)
for notifier in settings.ENABLED_NOTIFIERS:
file, _class = notifier.split(":")
try:
Expand Down
26 changes: 25 additions & 1 deletion sql/test_notify.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def test_base_notifier(self):
@patch("sql.notify.FeishuWebhookNotifier.run")
def test_auto_notify(self, mock_run):
with self.settings(ENABLED_NOTIFIERS=("sql.notify:FeishuWebhookNotifier",)):
auto_notify(self.sys_config, event_type=EventType.EXECUTE)
auto_notify(self.sys_config, event_type=EventType.EXECUTE, workflow=self.wf)
mock_run.assert_called_once()

@patch("sql.notify.auto_notify")
Expand Down Expand Up @@ -280,6 +280,17 @@ def test_legacy_render_audit(self):
notifier.render()
self.assertEqual(len(notifier.messages), 1)
self.assertIn("新的工单申请", notifier.messages[0].msg_title)
# 测试一下不传 workflow
notifier = LegacyRender(
event_type=EventType.AUDIT,
workflow=None,
audit=self.audit_wf,
audit_detail=self.audit_wf_detail,
sys_config=self.sys_config,
)
notifier.render()
self.assertEqual(len(notifier.messages), 1)
self.assertIn("新的工单申请", notifier.messages[0].msg_title)

def test_legacy_render_query_audit(self):
# 默认是库权限的
Expand Down Expand Up @@ -494,10 +505,13 @@ def tearDownClass(cls):
def setUp(self):
self.patcher = patch("sql.notify.MsgSender")
self.mock_msg_sender = self.patcher.start()
self.get_workflow_patcher = patch("sql.models.WorkflowAudit.get_workflow")
self.mock_get_workflow = self.get_workflow_patcher.start()
self.sys_config = SysConfig()

def tearDown(self):
self.patcher.stop()
self.get_workflow_patcher.stop()

def generate_notifier(self, module) -> Notifier:
return module(workflow=None, audit=self.audit_wf, sys_config=self.sys_config)
Expand Down Expand Up @@ -561,3 +575,13 @@ def test_mail(self):
]
notifier.send()
mocker.assert_called_once()


def test_override_sys_key():
"""dataclass 的继承有时候让人有点困惑, 在这里补一个测试确认可以正常覆盖一些值"""

class OverrideNotifier(Notifier):
sys_config_key = "test"

n = OverrideNotifier(workflow="test")
assert n.sys_config_key == "test"
10 changes: 2 additions & 8 deletions sql/utils/workflow_audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,14 +136,8 @@ def review_info(self) -> (str, str):

def get_workflow(self):
"""尝试从 audit 中取出 workflow"""
if self.audit.workflow_type == WorkflowType.QUERY:
self.workflow = QueryPrivilegesApply.objects.get(
apply_id=self.audit.workflow_id
)
elif self.audit.workflow_type == WorkflowType.SQL_REVIEW:
self.workflow = SqlWorkflow.objects.get(id=self.audit.workflow_id)
elif self.audit.workflow_type == WorkflowType.ARCHIVE:
self.workflow = ArchiveConfig.objects.get(id=self.audit.workflow_id)
self.workflow = self.audit.get_workflow()
if self.audit.workflow_type == WorkflowType.ARCHIVE:
self.resource_group = self.audit.group_name
self.resource_group_id = self.audit.group_id

Expand Down

0 comments on commit f959849

Please sign in to comment.