diff --git a/sql/models.py b/sql/models.py index 0da4dcf1b8..aaa0b00a3e 100755 --- a/sql/models.py +++ b/sql/models.py @@ -1,4 +1,6 @@ # -*- coding: UTF-8 -*- +from typing import Optional + from django.db import models from django.contrib.auth.models import AbstractUser from mirage import fields @@ -240,7 +242,36 @@ class Meta: ) -class SqlWorkflow(models.Model): +class WorkflowAuditMixin: + @property + def workflow_type(self): + if isinstance(self, SqlWorkflow): + return WorkflowType.SQL_REVIEW + elif isinstance(self, ArchiveConfig): + return WorkflowType.ARCHIVE + elif isinstance(self, QueryPrivilegesApply): + return WorkflowType.QUERY + + @property + def workflow_pk_field(self): + if isinstance(self, SqlWorkflow): + return "id" + elif isinstance(self, ArchiveConfig): + return "id" + elif isinstance(self, QueryPrivilegesApply): + return "apply_id" + + def get_audit(self) -> Optional["WorkflowAudit"]: + try: + return WorkflowAudit.objects.get( + workflow_type=self.workflow_type, + workflow_id=getattr(self, self.workflow_pk_field), + ) + except WorkflowAudit.DoesNotExist: + return None + + +class SqlWorkflow(models.Model, WorkflowAuditMixin): """ 存放各个SQL上线工单的基础内容 """ @@ -419,7 +450,7 @@ class Meta: verbose_name_plural = "工作流日志" -class QueryPrivilegesApply(models.Model): +class QueryPrivilegesApply(models.Model, WorkflowAuditMixin): """ 查询权限申请记录表 """ @@ -687,7 +718,7 @@ class Meta: verbose_name_plural = "实例参数修改历史" -class ArchiveConfig(models.Model): +class ArchiveConfig(models.Model, WorkflowAuditMixin): """ 归档配置表 """ diff --git a/sql/notify.py b/sql/notify.py index c8e703aa12..92c2a56d88 100755 --- a/sql/notify.py +++ b/sql/notify.py @@ -54,7 +54,9 @@ class My2SqlResult: @dataclass class Notifier: - workflow: Union[SqlWorkflow, ArchiveConfig, QueryPrivilegesApply, My2SqlResult] + workflow: Union[ + SqlWorkflow, ArchiveConfig, QueryPrivilegesApply, My2SqlResult + ] = None sys_config: SysConfig = None # init false, class property, 不是 instance property name: str = field(init=False, default="base") @@ -64,10 +66,12 @@ class Notifier: audit_detail: WorkflowAuditDetail = None def __post_init__(self): + if not self.audit and not self.workflow: + raise ValueError("需要提供 WorkflowAudit 或 workflow") if not self.workflow: - if not self.audit: - raise ValueError("需要提供 WorkflowAudit 或 workflow") self.workflow = self.audit.get_workflow() + if not self.audit and not isinstance(self.workflow, My2SqlResult): + self.audit = self.workflow.get_audit() # 防止 get_auditor 显式的传了个 None if not self.sys_config: self.sys_config = SysConfig() @@ -495,7 +499,7 @@ def auto_notify( def notify_for_execute(workflow: SqlWorkflow, sys_config: SysConfig = None): if not sys_config: sys_config = SysConfig() - auto_notify(workflow=workflow, sys_config=sys_config) + auto_notify(workflow=workflow, sys_config=sys_config, event_type=EventType.EXECUTE) def notify_for_audit( @@ -514,6 +518,7 @@ def notify_for_audit( audit=workflow_audit, audit_detail=workflow_audit_detail, sys_config=sys_config, + event_type=EventType.AUDIT, ) diff --git a/sql/test_notify.py b/sql/test_notify.py index ecb598939c..3d5bb6b648 100644 --- a/sql/test_notify.py +++ b/sql/test_notify.py @@ -196,6 +196,10 @@ def test_base_notifier(self): n.sys_config_key = "not-foo" self.assertFalse(n.should_run()) + def test_no_workflow_and_audit(self): + with self.assertRaises(ValueError): + Notifier(workflow=None, audit=None) + @patch("sql.notify.FeishuWebhookNotifier.run") def test_auto_notify(self, mock_run): with self.settings(ENABLED_NOTIFIERS=("sql.notify:FeishuWebhookNotifier",)): @@ -206,7 +210,9 @@ def test_auto_notify(self, mock_run): def test_notify_for_execute(self, mock_auto_notify: Mock): """测试适配器""" notify_for_execute(self.wf) - mock_auto_notify.assert_called_once_with(workflow=self.wf, sys_config=ANY) + mock_auto_notify.assert_called_once_with( + workflow=self.wf, sys_config=ANY, event_type=EventType.EXECUTE + ) @patch("sql.notify.auto_notify") def test_notify_for_audit(self, mock_auto_notify: Mock): @@ -216,6 +222,7 @@ def test_notify_for_audit(self, mock_auto_notify: Mock): ) mock_auto_notify.assert_called_once_with( workflow=None, + event_type=EventType.AUDIT, sys_config=ANY, audit=self.audit_wf, audit_detail=self.audit_wf_detail, @@ -583,5 +590,5 @@ def test_override_sys_key(): class OverrideNotifier(Notifier): sys_config_key = "test" - n = OverrideNotifier(workflow="test") + n = OverrideNotifier(workflow=Mock()) assert n.sys_config_key == "test" diff --git a/sql/utils/workflow_audit.py b/sql/utils/workflow_audit.py index 8ef367bad8..261d64b313 100644 --- a/sql/utils/workflow_audit.py +++ b/sql/utils/workflow_audit.py @@ -9,7 +9,6 @@ from django.contrib.auth.models import Group from django.utils import timezone -from django.core.exceptions import ObjectDoesNotExist from django.conf import settings from sql.engines.models import ReviewResult @@ -76,7 +75,6 @@ class AuditV2: sys_config: SysConfig = field(default_factory=SysConfig) audit: WorkflowAudit = None workflow_type: WorkflowType = WorkflowType.SQL_REVIEW - workflow_pk_field: str = "id" # 归档表中没有下面两个参数, 所以对归档表来说一下两参数必传 resource_group: str = "" resource_group_id: int = 0 @@ -86,22 +84,17 @@ def __post_init__(self): if not self.audit: raise ValueError("需要提供 WorkflowAudit 或 workflow") self.get_workflow() + self.workflow_type = self.workflow.workflow_type if isinstance(self.workflow, SqlWorkflow): - self.workflow_type = WorkflowType.SQL_REVIEW - self.workflow_pk_field = "id" self.resource_group = self.workflow.group_name self.resource_group_id = self.workflow.group_id elif isinstance(self.workflow, ArchiveConfig): - self.workflow_type = WorkflowType.ARCHIVE - self.workflow_pk_field = "id" try: group_in_db = ResourceGroup.objects.get(group_name=self.resource_group) self.resource_group_id = group_in_db.group_id except ResourceGroup.DoesNotExist: raise AuditException(f"参数错误, 未发现资源组 {self.resource_group}") elif isinstance(self.workflow, QueryPrivilegesApply): - self.workflow_type = WorkflowType.QUERY - self.workflow_pk_field = "apply_id" self.resource_group = self.workflow.group_name self.resource_group_id = self.workflow.group_id # 该方法可能获取不到相关的审批流, 但是也不要报错, 因为有的时候是新建工单, 此时还没有审批流 @@ -241,7 +234,7 @@ def create_audit(self) -> str: self.audit = WorkflowAudit( group_id=group_id, group_name=group_name, - workflow_id=self.workflow.__getattribute__(self.workflow_pk_field), + workflow_id=self.workflow.pk, workflow_type=self.workflow_type, workflow_title=workflow_title, audit_auth_groups=audit_setting.audit_auth_group_in_db, @@ -365,17 +358,8 @@ def get_audit_info(self) -> Optional[WorkflowAudit]: """尝试根据 workflow 取出审批工作流""" if self.audit: return self.audit - try: - self.audit = WorkflowAudit.objects.get( - workflow_type=self.workflow_type, - workflow_id=getattr(self.workflow, self.workflow_pk_field), - ) - if self.audit.workflow_type == WorkflowType.ARCHIVE: - self.resource_group = self.audit.group_name - self.resource_group_id = self.audit.group_id - return self.audit - except ObjectDoesNotExist: - return None + self.audit = self.workflow.get_audit() + return self.audit def operate_pass(self, actor: Users, remark: str) -> WorkflowAuditDetail: # 判断是否还有下一级审核 @@ -667,7 +651,6 @@ def get_auditor( sys_config: SysConfig = None, audit: WorkflowAudit = None, workflow_type: WorkflowType = WorkflowType.SQL_REVIEW, - workflow_pk_field: str = "id", # 归档表中没有下面两个参数, 所以对归档表来说一下两参数必传 resource_group: str = "", resource_group_id: int = 0, @@ -678,7 +661,6 @@ def get_auditor( return auditor( workflow=workflow, workflow_type=workflow_type, - workflow_pk_field=workflow_pk_field, sys_config=sys_config, audit=audit, resource_group=resource_group,