diff --git a/deepfence_bootstrapper/controls/controls.go b/deepfence_bootstrapper/controls/controls.go index e49e3d114e..751c79b61a 100644 --- a/deepfence_bootstrapper/controls/controls.go +++ b/deepfence_bootstrapper/controls/controls.go @@ -4,6 +4,7 @@ package controls import ( + "errors" "fmt" "os/exec" "strings" @@ -43,7 +44,9 @@ func SetClusterAgentControls(k8sClusterName string) { func(req ctl.SendAgentDiagnosticLogsRequest) error { log.Info().Msg("Generate Cluster Agent Diagnostic Logs") return SendAgentDiagnosticLogs(req, - []string{"/var/log/supervisor", "/var/log/fenced/compliance-scan-logs", "/var/log/deepfenced"}, + []string{"/var/log/supervisor", + "/var/log/fenced/compliance-scan-logs", + "/var/log/deepfenced"}, []string{}) }) if err != nil { @@ -81,11 +84,16 @@ func SetAgentControls() { if err != nil { return err } - err = scanner.RunComplianceScan() - if err != nil { - log.Error().Msgf("Error from scan: %+v", err) - return err - } + + log.Info().Msg("StartComplianceScan Starting") + //We need to run this in a goroutine else it will block the + //fetch and execution of controls + go func() { + err := scanner.RunComplianceScan() + if err != nil { + log.Error().Msgf("Error from RunComplianceScan: %+v", err) + } + }() return nil }) if err != nil { @@ -155,4 +163,31 @@ func SetAgentControls() { log.Error().Msgf("set controls: %v", err) } + err = router.RegisterControl(ctl.StopVulnerabilityScan, + func(req ctl.StopVulnerabilityScanRequest) error { + log.Info().Msg("StopVulnerabilityScanRequest called") + return router.StopVulnerabilityScan(req) + }) + if err != nil { + log.Error().Msgf("set controls: %v", err) + } + + err = router.RegisterControl(ctl.StopComplianceScan, + func(req ctl.StopComplianceScanRequest) error { + log.Info().Msg("StopComplianceScanRequest called") + scanId, ok := req.BinArgs["scan_id"] + var err error + if ok { + retVal := linuxScanner.StopScan(scanId) + if retVal == false { + err = errors.New("Failed to stop scan") + } + } else { + err = errors.New("Missing scan id in the StopComplianceScanRequest") + } + return err + }) + if err != nil { + log.Error().Msgf("set controls: %v", err) + } } diff --git a/deepfence_bootstrapper/router/common.go b/deepfence_bootstrapper/router/common.go index 24086a8cfb..eccb1bae09 100644 --- a/deepfence_bootstrapper/router/common.go +++ b/deepfence_bootstrapper/router/common.go @@ -16,7 +16,18 @@ const ( var controls map[ctl.ActionID]func(req []byte) error var controls_guard sync.RWMutex -func RegisterControl[T ctl.StartVulnerabilityScanRequest | ctl.StartSecretScanRequest | ctl.StartComplianceScanRequest | ctl.StartMalwareScanRequest | ctl.StartAgentUpgradeRequest | ctl.SendAgentDiagnosticLogsRequest | ctl.DisableAgentPluginRequest | ctl.EnableAgentPluginRequest | ctl.StopSecretScanRequest | ctl.StopMalwareScanRequest](id ctl.ActionID, callback func(req T) error) error { +func RegisterControl[T ctl.StartVulnerabilityScanRequest | + ctl.StartSecretScanRequest | + ctl.StartComplianceScanRequest | + ctl.StartMalwareScanRequest | + ctl.StartAgentUpgradeRequest | + ctl.SendAgentDiagnosticLogsRequest | + ctl.DisableAgentPluginRequest | + ctl.EnableAgentPluginRequest | + ctl.StopSecretScanRequest | + ctl.StopMalwareScanRequest | + ctl.StopVulnerabilityScanRequest | + ctl.StopComplianceScanRequest](id ctl.ActionID, callback func(req T) error) error { controls_guard.Lock() defer controls_guard.Unlock() diff --git a/deepfence_bootstrapper/router/generate_sbom.go b/deepfence_bootstrapper/router/generate_sbom.go index 51d7d0fe9d..19c1a6ff82 100644 --- a/deepfence_bootstrapper/router/generate_sbom.go +++ b/deepfence_bootstrapper/router/generate_sbom.go @@ -3,6 +3,7 @@ package router import ( "context" "errors" + "fmt" "os" ctl "github.com/deepfence/ThreatMapper/deepfence_utils/controls" @@ -172,3 +173,19 @@ func GetPackageScannerJobCount() int32 { } return jobReport.RunningJobs } + +func StopVulnerabilityScan(req ctl.StopVulnerabilityScanRequest) error { + fmt.Printf("Stop Vulnerability Scan : %v\n", req) + conn, err := createPackageScannerConn() + if err != nil { + fmt.Printf("StopVulnerabilityScanJob::error in creating Vulnerability scanner client: %s\n", err.Error()) + return err + } + defer conn.Close() + client := pb.NewScannersClient(conn) + var greq pb.StopScanRequest + greq.ScanId = req.BinArgs["scan_id"] + + _, err = client.StopScan(context.Background(), &greq) + return err +} diff --git a/deepfence_ctl/cmd/scan.go b/deepfence_ctl/cmd/scan.go index d297b0450a..e1d5541ff7 100644 --- a/deepfence_ctl/cmd/scan.go +++ b/deepfence_ctl/cmd/scan.go @@ -431,6 +431,27 @@ var scanStopSubCmd = &cobra.Command{ ScanType: "MalwareScan", }) res, err = http.Client().MalwareScanAPI.StopMalwareScanExecute(req) + case "vulnerability": + req := http.Client().VulnerabilityAPI.StopVulnerabilityScan(context.Background()) + req = req.ModelStopScanRequest(deepfence_server_client.ModelStopScanRequest{ + ScanId: scan_id, + ScanType: "VulnerabilityScan", + }) + res, err = http.Client().VulnerabilityAPI.StopVulnerabilityScanExecute(req) + case "compliance": + req := http.Client().ComplianceAPI.StopComplianceScan(context.Background()) + req = req.ModelStopScanRequest(deepfence_server_client.ModelStopScanRequest{ + ScanId: scan_id, + ScanType: "ComplianceScan", + }) + res, err = http.Client().ComplianceAPI.StopComplianceScanExecute(req) + case "cloudcompliance": + req := http.Client().ComplianceAPI.StopComplianceScan(context.Background()) + req = req.ModelStopScanRequest(deepfence_server_client.ModelStopScanRequest{ + ScanId: scan_id, + ScanType: "CloudComplianceScan", + }) + res, err = http.Client().ComplianceAPI.StopComplianceScanExecute(req) default: log.Fatal().Msg("Unsupported") } diff --git a/deepfence_server/apiDocs/operation.go b/deepfence_server/apiDocs/operation.go index c8ec8c1e60..43932f9877 100644 --- a/deepfence_server/apiDocs/operation.go +++ b/deepfence_server/apiDocs/operation.go @@ -489,10 +489,10 @@ func (d *OpenApiDocs) AddScansOperations() { // Stop scan d.AddOperation("stopVulnerabilityScan", http.MethodPost, "/deepfence/scan/stop/vulnerability", "Stop Vulnerability Scan", "Stop Vulnerability Scan on agent or registry", - http.StatusAccepted, []string{tagVulnerability}, bearerToken, new(VulnerabilityScanTriggerReq), nil) + http.StatusAccepted, []string{tagVulnerability}, bearerToken, new(StopScanRequest), nil) d.AddOperation("stopComplianceScan", http.MethodPost, "/deepfence/scan/stop/compliance", "Stop Compliance Scan", "Stop Compliance Scan on agent or registry", - http.StatusAccepted, []string{tagCompliance}, bearerToken, new(ComplianceScanTriggerReq), nil) + http.StatusAccepted, []string{tagCompliance}, bearerToken, new(StopScanRequest), nil) d.AddOperation("stopMalwareScan", http.MethodPost, "/deepfence/scan/stop/malware", "Stop Malware Scan", "Stop Malware Scan on agent or registry", http.StatusAccepted, []string{tagMalwareScan}, bearerToken, new(StopScanRequest), nil) diff --git a/deepfence_server/controls/agent.go b/deepfence_server/controls/agent.go index d45c381d2a..a0182ab91c 100644 --- a/deepfence_server/controls/agent.go +++ b/deepfence_server/controls/agent.go @@ -386,6 +386,10 @@ func ExtractStoppingAgentScans(ctx context.Context, nodeId string, action.ID = controls.StopSecretScan case controls.StartMalwareScan: action.ID = controls.StopMalwareScan + case controls.StartVulnerabilityScan: + action.ID = controls.StopVulnerabilityScan + case controls.StartComplianceScan: + action.ID = controls.StopComplianceScan default: log.Info().Msgf("Stop functionality not implemented for action: %d", action.ID) continue diff --git a/deepfence_server/handler/cloud_node.go b/deepfence_server/handler/cloud_node.go index 0963905648..26c6893524 100644 --- a/deepfence_server/handler/cloud_node.go +++ b/deepfence_server/handler/cloud_node.go @@ -88,11 +88,17 @@ func (h *Handler) RegisterCloudNodeAccountHandler(w http.ResponseWriter, r *http if err != nil { log.Error().Msgf("Error getting controls for compliance type: %+v", scan.BenchmarkTypes) } + stopRequested := false + if scan.Status == utils.SCAN_STATUS_CANCELLING { + stopRequested = true + } + scanDetail := model.CloudComplianceScanDetails{ - ScanId: scan.ScanId, - ScanTypes: scan.BenchmarkTypes, - AccountId: monitoredAccountId, - Benchmarks: benchmarks, + ScanId: scan.ScanId, + ScanTypes: scan.BenchmarkTypes, + AccountId: monitoredAccountId, + Benchmarks: benchmarks, + StopRequested: stopRequested, } scanList[scan.ScanId] = scanDetail } @@ -124,11 +130,17 @@ func (h *Handler) RegisterCloudNodeAccountHandler(w http.ResponseWriter, r *http if err != nil { log.Error().Msgf("Error getting controls for compliance type: %+v", scan.BenchmarkTypes) } + + stopRequested := false + if scan.Status == utils.SCAN_STATUS_CANCELLING { + stopRequested = true + } scanDetail := model.CloudComplianceScanDetails{ - ScanId: scan.ScanId, - ScanTypes: scan.BenchmarkTypes, - AccountId: req.CloudAccount, - Benchmarks: benchmarks, + ScanId: scan.ScanId, + ScanTypes: scan.BenchmarkTypes, + AccountId: req.CloudAccount, + Benchmarks: benchmarks, + StopRequested: stopRequested, } scanList[scan.ScanId] = scanDetail } diff --git a/deepfence_server/handler/scan_reports.go b/deepfence_server/handler/scan_reports.go index 35814828bb..2c0c910a55 100644 --- a/deepfence_server/handler/scan_reports.go +++ b/deepfence_server/handler/scan_reports.go @@ -528,19 +528,19 @@ func (h *Handler) SendScanStatus( } func (h *Handler) StopVulnerabilityScanHandler(w http.ResponseWriter, r *http.Request) { - stopScan(w, r, ctl.StartVulnerabilityScan) + h.stopScan(w, r, "StopVulnerabilityScan") } func (h *Handler) StopSecretScanHandler(w http.ResponseWriter, r *http.Request) { - h.stopSecretScan(w, r) + h.stopScan(w, r, "StopSecretScan") } func (h *Handler) StopComplianceScanHandler(w http.ResponseWriter, r *http.Request) { - stopScan(w, r, ctl.StartComplianceScan) + h.stopScan(w, r, "StopComplianceScan") } func (h *Handler) StopMalwareScanHandler(w http.ResponseWriter, r *http.Request) { - h.stopMalwareScan(w, r) + h.stopScan(w, r, "StopMalwareScan") } func (h *Handler) IngestCloudResourcesReportHandler(w http.ResponseWriter, r *http.Request) { @@ -784,18 +784,13 @@ func ingest_scan_report_kafka[T any]( httpext.JSON(respWrite, http.StatusOK, map[string]string{"status": "ok"}) } -func stopScan(w http.ResponseWriter, r *http.Request, action ctl.ActionID) { - -} - -func (h *Handler) stopSecretScan(w http.ResponseWriter, r *http.Request) { +func (h *Handler) stopScan(w http.ResponseWriter, r *http.Request, tag string) { // Stopping scan is on best-effort basis, not guaranteed - defer r.Body.Close() var req model.StopScanRequest err := httpext.DecodeJSON(r, httpext.NoQueryParams, MaxPostRequestSize, &req) if err != nil { - log.Error().Msgf("Failed to DecodeJSON: %v", err) + log.Error().Msgf("%s Failed to DecodeJSON: %v", tag, err) h.respondError(err, w) return } @@ -807,41 +802,16 @@ func (h *Handler) stopSecretScan(w http.ResponseWriter, r *http.Request) { return } - log.Info().Msgf("StopSecretScan request, type: %s, scanid: %s", req.ScanType, req.ScanID) - - err = reporters_scan.StopScan(r.Context(), req.ScanType, req.ScanID) - if err != nil { - log.Error().Msgf("Error in StopScan: %v", err) - h.respondError(&ValidatorError{err: err}, w) - return - } - - h.AuditUserActivity(r, req.ScanType, ACTION_STOP, req, true) - - w.WriteHeader(http.StatusAccepted) -} - -func (h *Handler) stopMalwareScan(w http.ResponseWriter, r *http.Request) { - // Stopping scan is on best-effort basis, not guaranteed - defer r.Body.Close() - var req model.StopScanRequest - err := httpext.DecodeJSON(r, httpext.NoQueryParams, MaxPostRequestSize, &req) - if err != nil { - log.Error().Msgf("StopMalwareScan Failed to DecodeJSON: %v", err) - h.respondError(err, w) - return - } - - err = h.Validator.Struct(req) - if err != nil { - log.Error().Msgf("Failed to validate the request: %v", err) - h.respondError(&ValidatorError{err: err}, w) - return + if req.ScanType == "CloudComplianceScan" { + log.Info().Msgf("CloudComplianceScan request, type: %s, scanid: %s", + req.ScanType, req.ScanID) + err = reporters_scan.StopCloudComplianceScan(r.Context(), req.ScanType, req.ScanID) + } else { + log.Info().Msgf("%s request, type: %s, scanid: %s", + tag, req.ScanType, req.ScanID) + err = reporters_scan.StopScan(r.Context(), req.ScanType, req.ScanID) } - log.Info().Msgf("StopMalwareScan request, type: %s, scanid: %s", req.ScanType, req.ScanID) - - err = reporters_scan.StopScan(r.Context(), req.ScanType, req.ScanID) if err != nil { log.Error().Msgf("Error in StopScan: %v", err) h.respondError(&ValidatorError{err: err}, w) @@ -2119,6 +2089,7 @@ func StartMultiCloudComplianceScan(ctx context.Context, reqs []model.NodeIdentif reqs[0].NodeType) if err != nil { + log.Info().Msgf("Error in AddNewCloudComplianceScan:%v", err) if e, is := err.(*ingesters.AlreadyRunningScanError); is { scanIds = append(scanIds, e.ScanId) continue diff --git a/deepfence_server/ingesters/scan_status.go b/deepfence_server/ingesters/scan_status.go index b136b7325d..1db3f26ecd 100644 --- a/deepfence_server/ingesters/scan_status.go +++ b/deepfence_server/ingesters/scan_status.go @@ -264,12 +264,14 @@ func AddNewCloudComplianceScan(tx WriteDBTransaction, OPTIONAL MATCH (n:%s)-[:SCANNED]->(m) WHERE NOT n.status = $complete AND NOT n.status = $failed + AND NOT n.status = $cancelled AND n.benchmark_types = $benchmark_types RETURN n.node_id, m.agent_running`, neo4jNodeType, scanType), map[string]interface{}{ "node_id": nodeId, "complete": utils.SCAN_STATUS_SUCCESS, "failed": utils.SCAN_STATUS_FAILED, + "cancelled": utils.SCAN_STATUS_CANCELLED, "benchmark_types": benchmarkTypes, }) if err != nil { diff --git a/deepfence_server/model/cloud_node.go b/deepfence_server/model/cloud_node.go index d3eb521405..cf894a544e 100644 --- a/deepfence_server/model/cloud_node.go +++ b/deepfence_server/model/cloud_node.go @@ -137,10 +137,11 @@ type CloudComplianceBenchmark struct { } type CloudComplianceScanDetails struct { - ScanId string `json:"scan_id"` - ScanTypes []string `json:"scan_types"` - AccountId string `json:"account_id"` - Benchmarks []CloudComplianceBenchmark `json:"benchmarks"` + ScanId string `json:"scan_id"` + ScanTypes []string `json:"scan_types"` + AccountId string `json:"account_id"` + Benchmarks []CloudComplianceBenchmark `json:"benchmarks"` + StopRequested bool `json:"stop_requested"` } type CloudNodeCloudtrailTrail struct { diff --git a/deepfence_server/reporters/scan/scan_reporters.go b/deepfence_server/reporters/scan/scan_reporters.go index 4fa79a69a8..86846cea96 100644 --- a/deepfence_server/reporters/scan/scan_reporters.go +++ b/deepfence_server/reporters/scan/scan_reporters.go @@ -537,7 +537,7 @@ func GetCloudCompliancePendingScansList(ctx context.Context, scanType utils.Neo4 return model.CloudComplianceScanListResp{}, err } - session := driver.NewSession(neo4j.SessionConfig{AccessMode: neo4j.AccessModeRead}) + session := driver.NewSession(neo4j.SessionConfig{AccessMode: neo4j.AccessModeWrite}) if err != nil { return model.CloudComplianceScanListResp{}, err } @@ -563,9 +563,56 @@ func GetCloudCompliancePendingScansList(ctx context.Context, scanType utils.Neo4 return model.CloudComplianceScanListResp{}, err } - return model.CloudComplianceScanListResp{ - ScansInfo: extractStatusesWithBenchmarks(recs), - }, nil + scansInfo := extractStatusesWithBenchmarks(recs) + + //Get the list of stopping scans + { + res, err := tx.Run(` + MATCH (m:`+string(scanType)+`) -[:SCANNED]-> (n:CloudNode{node_id: $node_id}) + WHERE m.status=$cancel_pending + SET m.status = $cancelling, m.updated_at = TIMESTAMP() + WITH m,n + RETURN m.node_id, m.status, m.status_message, + n.node_id, m.updated_at, n.node_name ORDER BY m.updated_at`, + map[string]interface{}{"node_id": nodeId, + "cancel_pending": utils.SCAN_STATUS_CANCEL_PENDING, + "cancelling": utils.SCAN_STATUS_CANCELLING}) + if err != nil { + log.Info().Msgf("Failed to get stopping scan list for node:%s, error is:%v", nodeId, err) + } else { + recs, err := res.Collect() + if err != nil { + return model.CloudComplianceScanListResp{}, err + } + + //_ = extractStatusesWithBenchmarks(recs, true) + stoppingScansInfo := make([]model.ComplianceScanInfo, 0, len(recs)) + for _, rec := range recs { + tmp := model.ComplianceScanInfo{ + ScanInfo: model.ScanInfo{ + ScanId: rec.Values[0].(string), + Status: rec.Values[1].(string), + StatusMessage: rec.Values[2].(string), + NodeId: rec.Values[3].(string), + NodeType: controls.ResourceTypeToString(controls.CloudAccount), + UpdatedAt: rec.Values[4].(int64), + NodeName: rec.Values[5].(string), + }, + BenchmarkTypes: nil, + } + stoppingScansInfo = append(stoppingScansInfo, tmp) + } + + if len(stoppingScansInfo) != 0 { + scansInfo = append(scansInfo, stoppingScansInfo...) + } + } + } + + err = tx.Commit() + pendScanResp := model.CloudComplianceScanListResp{ScansInfo: scansInfo} + + return pendScanResp, err } func GetScanResultDiff[T any](ctx context.Context, scan_type utils.Neo4jScanType, baseScanID, compareToScanID string, ff reporters.FieldsFilters, fw model.FetchWindow) ([]T, error) { diff --git a/deepfence_server/reporters/scan/scan_result_actions.go b/deepfence_server/reporters/scan/scan_result_actions.go index 96bccc16b0..7ab20a04bd 100644 --- a/deepfence_server/reporters/scan/scan_result_actions.go +++ b/deepfence_server/reporters/scan/scan_result_actions.go @@ -223,6 +223,39 @@ func DeleteScan(ctx context.Context, scanType utils.Neo4jScanType, scanId string return nil } +func StopCloudComplianceScan(ctx context.Context, scanType, scanId string) error { + + driver, err := directory.Neo4jClient(ctx) + if err != nil { + return err + } + session := driver.NewSession(neo4j.SessionConfig{AccessMode: neo4j.AccessModeWrite}) + defer session.Close() + + tx, err := session.BeginTransaction(neo4j.WithTxTimeout(15 * time.Second)) + if err != nil { + return err + } + defer tx.Close() + + query := `MATCH (n:%s) -[:SCANNED]-> () + WHERE n.node_id = $scan_id + AND n.status = $in_progress + SET n.status=$cancel_pending` + + if _, err = tx.Run(fmt.Sprintf(query, scanType), + map[string]interface{}{ + "scan_id": scanId, + "in_progress": utils.SCAN_STATUS_INPROGRESS, + "cancel_pending": utils.SCAN_STATUS_CANCEL_PENDING, + }); err != nil { + log.Error().Msgf("StopCloudComplianceScan: Error in setting the state in neo4j: %v", err) + return err + } + + return tx.Commit() +} + func StopScan(ctx context.Context, scanType, scanId string) error { driver, err := directory.Neo4jClient(ctx) @@ -241,14 +274,14 @@ func StopScan(ctx context.Context, scanType, scanId string) error { query := `MATCH (n:%s) -[:SCANNED]-> () WHERE n.node_id = $scan_id AND n.status IN [$starting, $in_progress] - SET n.status=$cancelling` + SET n.status=$cancel_pending` if _, err = tx.Run(fmt.Sprintf(query, scanType), map[string]interface{}{ - "scan_id": scanId, - "starting": utils.SCAN_STATUS_STARTING, - "in_progress": utils.SCAN_STATUS_INPROGRESS, - "cancelling": utils.SCAN_STATUS_CANCEL_PENDING, + "scan_id": scanId, + "starting": utils.SCAN_STATUS_STARTING, + "in_progress": utils.SCAN_STATUS_INPROGRESS, + "cancel_pending": utils.SCAN_STATUS_CANCEL_PENDING, }); err != nil { log.Error().Msgf("StopScan: Error in setting the state in neo4j: %v", err) return err diff --git a/deepfence_utils/controls/agent.go b/deepfence_utils/controls/agent.go index 9b7d2cc1e4..f7cab1ac22 100644 --- a/deepfence_utils/controls/agent.go +++ b/deepfence_utils/controls/agent.go @@ -18,6 +18,8 @@ const ( UpgradeAgentPlugin StopSecretScan StopMalwareScan + StopVulnerabilityScan + StopComplianceScan ) type ScanResource int @@ -118,17 +120,10 @@ type StartMalwareScanRequest struct { BinArgs map[string]string `json:"bin_args" required:"true"` } -type StopSecretScanRequest struct { - NodeId string `json:"node_id" required:"true"` - NodeType ScanResource `json:"node_type" required:"true"` - BinArgs map[string]string `json:"bin_args" required:"true"` -} - -type StopMalwareScanRequest struct { - NodeId string `json:"node_id" required:"true"` - NodeType ScanResource `json:"node_type" required:"true"` - BinArgs map[string]string `json:"bin_args" required:"true"` -} +type StopSecretScanRequest StartSecretScanRequest +type StopMalwareScanRequest StartSecretScanRequest +type StopVulnerabilityScanRequest StartSecretScanRequest +type StopComplianceScanRequest StartSecretScanRequest type SendAgentDiagnosticLogsRequest struct { NodeId string `json:"node_id" required:"true"` @@ -170,3 +165,23 @@ type AgentControls struct { func (ac AgentControls) ToBytes() ([]byte, error) { return json.Marshal(ac) } + +func GetBinArgs(T interface{}) map[string]string { + switch T.(type) { + case StartVulnerabilityScanRequest: + return T.(StartVulnerabilityScanRequest).BinArgs + case StartSecretScanRequest: + return T.(StartSecretScanRequest).BinArgs + case StartComplianceScanRequest: + return T.(StartComplianceScanRequest).BinArgs + case StartMalwareScanRequest: + return T.(StartMalwareScanRequest).BinArgs + case StopSecretScanRequest: + return T.(StopSecretScanRequest).BinArgs + case StopMalwareScanRequest: + return T.(StopVulnerabilityScanRequest).BinArgs + case StopVulnerabilityScanRequest: + return T.(StopVulnerabilityScanRequest).BinArgs + } + return nil +} diff --git a/deepfence_utils/utils/constants.go b/deepfence_utils/utils/constants.go index 22d95d60b7..99c0c6898d 100644 --- a/deepfence_utils/utils/constants.go +++ b/deepfence_utils/utils/constants.go @@ -48,6 +48,7 @@ const ( LinkNodesTask = "link_nodes" StopSecretScanTask = "task_stop_secret_scan" StopMalwareScanTask = "task_stop_malware_scan" + StopVulnerabilityScanTask = "task_stop_vulnerability_scan" ) const ( @@ -199,6 +200,7 @@ var Tasks = []string{ LinkNodesTask, StopSecretScanTask, StopMalwareScanTask, + StopVulnerabilityScanTask, } type ReportType string diff --git a/deepfence_worker/controls/controls.go b/deepfence_worker/controls/controls.go index f51b2a0c52..c8d0c49f6f 100644 --- a/deepfence_worker/controls/controls.go +++ b/deepfence_worker/controls/controls.go @@ -19,7 +19,8 @@ var controls_guard sync.RWMutex func RegisterControl[T ctl.StartVulnerabilityScanRequest | ctl.StartSecretScanRequest | ctl.StartComplianceScanRequest | ctl.StartMalwareScanRequest | ctl.StartAgentUpgradeRequest | ctl.StopSecretScanRequest | - ctl.StopMalwareScanRequest](id ctl.ActionID, callback func(namespace string, req T) error) error { + ctl.StopMalwareScanRequest | ctl.StopVulnerabilityScanRequest](id ctl.ActionID, + callback func(namespace string, req T) error) error { controls_guard.Lock() defer controls_guard.Unlock() @@ -54,101 +55,72 @@ func init() { func ConsoleActionSetup(pub *kafka.Publisher) error { // for vulnerability scan err := RegisterControl(ctl.StartVulnerabilityScan, - func(namespace string, req ctl.StartVulnerabilityScanRequest) error { - metadata := map[string]string{directory.NamespaceKey: namespace} - log.Info().Msgf("payload: %+v", req.BinArgs) - data, err := json.Marshal(req.BinArgs) - if err != nil { - log.Error().Msg(err.Error()) - return err - } - if err := utils.PublishNewJob(pub, metadata, sdkUtils.GenerateSBOMTask, data); err != nil { - log.Error().Msg(err.Error()) - return err - } - return nil - }) + GetRegisterControlFunc[ctl.StartVulnerabilityScanRequest](pub, + sdkUtils.GenerateSBOMTask)) if err != nil { return err } // for secret scan err = RegisterControl(ctl.StartSecretScan, - func(namespace string, req ctl.StartSecretScanRequest) error { - metadata := map[string]string{directory.NamespaceKey: namespace} - log.Info().Msgf("payload: %+v", req.BinArgs) - data, err := json.Marshal(req.BinArgs) - if err != nil { - log.Error().Msg(err.Error()) - return err - } - if err := utils.PublishNewJob(pub, metadata, sdkUtils.SecretScanTask, data); err != nil { - log.Error().Msg(err.Error()) - return err - } - return nil - }) + GetRegisterControlFunc[ctl.StartSecretScanRequest](pub, + sdkUtils.SecretScanTask)) if err != nil { return err } - err = RegisterControl(ctl.StopSecretScan, - func(namespace string, req ctl.StopSecretScanRequest) error { - metadata := map[string]string{directory.NamespaceKey: namespace} - log.Info().Msgf("StopSecretScan payload: %+v", req.BinArgs) - data, err := json.Marshal(req.BinArgs) - if err != nil { - log.Error().Msg(err.Error()) - return err - } - if err := utils.PublishNewJob(pub, metadata, sdkUtils.StopSecretScanTask, data); err != nil { - log.Error().Msg(err.Error()) - return err - } - return nil - }) + // for malware scan + err = RegisterControl(ctl.StartMalwareScan, + GetRegisterControlFunc[ctl.StartMalwareScanRequest](pub, + sdkUtils.MalwareScanTask)) if err != nil { return err } - // for malware scan - err = RegisterControl(ctl.StartMalwareScan, - func(namespace string, req ctl.StartMalwareScanRequest) error { - metadata := map[string]string{directory.NamespaceKey: namespace} - log.Info().Msgf("payload: %+v", req.BinArgs) - data, err := json.Marshal(req.BinArgs) - if err != nil { - log.Error().Msg(err.Error()) - return err - } - if err := utils.PublishNewJob(pub, metadata, sdkUtils.MalwareScanTask, data); err != nil { - log.Error().Msg(err.Error()) - return err - } - return nil - }) + err = RegisterControl(ctl.StopSecretScan, + GetRegisterControlFunc[ctl.StopSecretScanRequest](pub, + sdkUtils.StopSecretScanTask)) if err != nil { return err } err = RegisterControl(ctl.StopMalwareScan, - func(namespace string, req ctl.StopMalwareScanRequest) error { - metadata := map[string]string{directory.NamespaceKey: namespace} - log.Info().Msgf("StopMalwareScan payload: %+v", req.BinArgs) - data, err := json.Marshal(req.BinArgs) - if err != nil { - log.Error().Msg(err.Error()) - return err - } - if err := utils.PublishNewJob(pub, metadata, sdkUtils.StopMalwareScanTask, data); err != nil { - log.Error().Msg(err.Error()) - return err - } - return nil - }) + GetRegisterControlFunc[ctl.StopMalwareScanRequest](pub, + sdkUtils.StopMalwareScanTask)) + if err != nil { + return err + } + + err = RegisterControl(ctl.StopVulnerabilityScan, + GetRegisterControlFunc[ctl.StopVulnerabilityScanRequest](pub, + sdkUtils.StopVulnerabilityScanTask)) if err != nil { return err } return nil } + +func GetRegisterControlFunc[T ctl.StartVulnerabilityScanRequest | ctl.StartSecretScanRequest | + ctl.StartComplianceScanRequest | ctl.StartMalwareScanRequest | + ctl.StopSecretScanRequest | ctl.StopMalwareScanRequest | + ctl.StopVulnerabilityScanRequest](pub *kafka.Publisher, + task string) func(namespace string, req T) error { + + controlFunc := func(namespace string, req T) error { + metadata := map[string]string{directory.NamespaceKey: namespace} + BinArgs := ctl.GetBinArgs(req) + log.Info().Msgf("%s payload: %+v", task, BinArgs) + data, err := json.Marshal(BinArgs) + if err != nil { + log.Error().Msg(err.Error()) + return err + } + if err := utils.PublishNewJob(pub, metadata, task, data); err != nil { + log.Error().Msg(err.Error()) + return err + } + return nil + } + return controlFunc +} diff --git a/deepfence_worker/ingesters/common.go b/deepfence_worker/ingesters/common.go index f49f0c26eb..1a6048a441 100644 --- a/deepfence_worker/ingesters/common.go +++ b/deepfence_worker/ingesters/common.go @@ -36,22 +36,48 @@ func CommitFuncStatus[Status any](ts utils.Neo4jScanType) func(ns string, data [ } defer tx.Close() - query := "" + var in_progress_query, others_query string switch ts { default: - query = ` + in_progress_query = ` UNWIND $batch as row - MERGE (n:` + string(ts) + `{node_id: row.scan_id}) + MATCH (n:` + string(ts) + `{node_id: row.scan_id}) + WHERE NOT n.status IN $cancel_states SET n.status = row.scan_status, n.status_message = row.scan_message, n.updated_at = TIMESTAMP() WITH n OPTIONAL MATCH (n) -[:DETECTED]- (m) WITH n, count(m) as count MATCH (n) -[:SCANNED]- (r) SET r.` + ingestersUtil.ScanCountField[ts] + `=count, r.` + ingestersUtil.ScanStatusField[ts] + `=n.status, r.` + ingestersUtil.LatestScanIdField[ts] + `=n.node_id` + + others_query = ` + UNWIND $batch as row + MATCH (n:` + string(ts) + `{node_id: row.scan_id}) + SET n.status = row.scan_status, n.status_message = row.scan_message, n.updated_at = TIMESTAMP() + WITH n + OPTIONAL MATCH (n) -[:DETECTED]- (m) + WITH n, count(m) as count + MATCH (n) -[:SCANNED]- (r) + SET r.` + ingestersUtil.ScanCountField[ts] + `=count, r.` + ingestersUtil.ScanStatusField[ts] + `=n.status, r.` + ingestersUtil.LatestScanIdField[ts] + `=n.node_id` + case utils.NEO4J_CLOUD_COMPLIANCE_SCAN: - query = ` + in_progress_query = ` + UNWIND $batch as row + MATCH (n:` + string(ts) + `{node_id: row.scan_id}) + WHERE NOT n.status IN $cancel_states + SET n.status = row.scan_status, n.status_message = row.scan_message, n.updated_at = TIMESTAMP() + WITH n + OPTIONAL MATCH (n) -[:DETECTED]- (m) + WITH n, count(m) as total_count + OPTIONAL MATCH (n) -[:DETECTED]- (m) + WITH n, total_count, m.resource as arn, count(m) as count + OPTIONAL MATCH (n) -[:SCANNED]- (cn) -[:OWNS]- (cr:CloudResource{arn: arn}) + SET cn.` + ingestersUtil.ScanCountField[ts] + `=total_count, cn.` + ingestersUtil.ScanStatusField[ts] + `=n.status, cn.` + ingestersUtil.LatestScanIdField[ts] + `=n.node_id + SET cr.` + ingestersUtil.ScanCountField[ts] + `=count, cr.` + ingestersUtil.ScanStatusField[ts] + `=n.status, cr.` + ingestersUtil.LatestScanIdField[ts] + `=n.node_id` + + others_query = ` UNWIND $batch as row - MERGE (n:` + string(ts) + `{node_id: row.scan_id}) + MATCH (n:` + string(ts) + `{node_id: row.scan_id}) SET n.status = row.scan_status, n.status_message = row.scan_message, n.updated_at = TIMESTAMP() WITH n OPTIONAL MATCH (n) -[:DETECTED]- (m) @@ -64,7 +90,15 @@ func CommitFuncStatus[Status any](ts utils.Neo4jScanType) func(ns string, data [ } recordMap := statusesToMaps(data) - if _, err = tx.Run(query, map[string]interface{}{"batch": statusesToMaps(data)}); err != nil { + in_progress, others := splitInprogressStatus(recordMap) + if _, err = tx.Run(in_progress_query, map[string]interface{}{ + "batch": in_progress, + "cancel_states": []string{utils.SCAN_STATUS_CANCELLING, utils.SCAN_STATUS_CANCEL_PENDING}}); err != nil { + log.Error().Msgf("Error while updating scan status: %+v", err) + return err + } + + if _, err = tx.Run(others_query, map[string]interface{}{"batch": others}); err != nil { log.Error().Msgf("Error while updating scan status: %+v", err) return err } @@ -124,7 +158,8 @@ func statusesToMaps[T any](data []T) []map[string]interface{} { } else { old_status := old["scan_status"].(string) if new_status != old_status { - if new_status == utils.SCAN_STATUS_SUCCESS || new_status == utils.SCAN_STATUS_FAILED { + if new_status == utils.SCAN_STATUS_SUCCESS || + new_status == utils.SCAN_STATUS_FAILED || new_status == utils.SCAN_STATUS_CANCELLED { statusBuff[scan_id] = new } } @@ -138,6 +173,20 @@ func statusesToMaps[T any](data []T) []map[string]interface{} { return statuses } +func splitInprogressStatus(data []map[string]interface{}) ([]map[string]interface{}, []map[string]interface{}) { + in_progress := []map[string]interface{}{} + others := []map[string]interface{}{} + + for i := range data { + if data[i]["scan_status"].(string) == utils.SCAN_STATUS_INPROGRESS { + in_progress = append(in_progress, data[i]) + } else { + others = append(others, data[i]) + } + } + return in_progress, others +} + func ToMap[T any](data T) map[string]interface{} { out, err := json.Marshal(data) if err != nil { diff --git a/deepfence_worker/tasks/sbom/generate_sbom.go b/deepfence_worker/tasks/sbom/generate_sbom.go index c1acd0bf2c..8b63193727 100644 --- a/deepfence_worker/tasks/sbom/generate_sbom.go +++ b/deepfence_worker/tasks/sbom/generate_sbom.go @@ -3,6 +3,7 @@ package sbom import ( "bytes" "compress/gzip" + "context" "encoding/json" "os" "path" @@ -25,6 +26,7 @@ import ( var ( syftBin = "syft" + scanMap = sync.Map{} ) type SbomGenerator struct { @@ -35,6 +37,29 @@ func NewSbomGenerator(ingest chan *kgo.Record) SbomGenerator { return SbomGenerator{ingestC: ingest} } +func StopVulnerabilityScan(msg *message.Message) error { + log.Info().Msgf("StopVulnerabilityScan, uuid: %s payload: %s ", msg.UUID, string(msg.Payload)) + var params utils.SbomParameters + if err := json.Unmarshal(msg.Payload, ¶ms); err != nil { + log.Error().Msgf("StopVulnerabilityScan, error in Unmarshal: %s", err.Error()) + return nil + } + + scanID := params.ScanId + cancelFnObj, found := scanMap.Load(scanID) + logMsg := "" + if found { + cancelFn := cancelFnObj.(context.CancelFunc) + cancelFn() + logMsg = "Stop GenerateSBOM request submitted" + } else { + logMsg = "Failed to Stop scan, SBOM may have already generated or errored out" + } + + log.Info().Msgf("%s, scan_id: %s", logMsg, scanID) + return nil +} + func (s SbomGenerator) GenerateSbom(msg *message.Message) ([]*message.Message, error) { defer cronjobs.ScanWorkloadAllocator.Free() @@ -85,6 +110,14 @@ func (s SbomGenerator) GenerateSbom(msg *message.Message) ([]*message.Message, e return nil, nil } + log.Info().Msgf("Adding scanid to map:%s", params.ScanId) + ctxSbom, cancel := context.WithCancel(context.Background()) + scanMap.Store(params.ScanId, cancel) + defer func(scanId string) { + log.Info().Msgf("Removing scaind from map:%s", scanId) + scanMap.Delete(scanId) + }(params.ScanId) + defer func() { log.Info().Msgf("remove auth directory %s", authFile) if authFile == "" { @@ -127,10 +160,15 @@ func (s SbomGenerator) GenerateSbom(msg *message.Message) ([]*message.Message, e statusChan <- NewSbomScanStatus(params, utils.SCAN_STATUS_INPROGRESS, "", nil) - rawSbom, err := syft.GenerateSBOM(ctx, cfg) + rawSbom, err := syft.GenerateSBOM(ctxSbom, cfg) if err != nil { - log.Error().Msg(err.Error()) - statusChan <- NewSbomScanStatus(params, utils.SCAN_STATUS_FAILED, err.Error(), nil) + if ctxSbom.Err() == context.Canceled { + log.Error().Msgf("Stopping GenerateSBOM as per user request, scanID:%s", params.ScanId) + statusChan <- NewSbomScanStatus(params, utils.SCAN_STATUS_CANCELLED, err.Error(), nil) + } else { + log.Error().Msg(err.Error()) + statusChan <- NewSbomScanStatus(params, utils.SCAN_STATUS_FAILED, err.Error(), nil) + } return nil, nil } diff --git a/deepfence_worker/tasks/sbom/utils.go b/deepfence_worker/tasks/sbom/utils.go index 4e9e9173c4..e8c1541223 100644 --- a/deepfence_worker/tasks/sbom/utils.go +++ b/deepfence_worker/tasks/sbom/utils.go @@ -63,7 +63,9 @@ func StartStatusReporter(title string, statusChan chan SbomScanStatus, ingestC c log.Error().Msgf("error sending scan status: %s, scanid: %s", err.Error(), params.ScanId) } - if status == utils.SCAN_STATUS_SUCCESS || status == utils.SCAN_STATUS_FAILED { + if status == utils.SCAN_STATUS_SUCCESS || + status == utils.SCAN_STATUS_FAILED || + status == utils.SCAN_STATUS_CANCELLED { break loop } case <-ticker.C: diff --git a/deepfence_worker/worker.go b/deepfence_worker/worker.go index a11a5d1d59..ab50190a8a 100644 --- a/deepfence_worker/worker.go +++ b/deepfence_worker/worker.go @@ -339,6 +339,9 @@ func startWorker(wml watermill.LoggerAdapter, cfg config) error { worker.AddNoPublisherHandler(utils.LinkNodesTask, cronjobs.LinkNodes, true) + worker.AddNoPublisherHandler(utils.StopVulnerabilityScanTask, + sbom.StopVulnerabilityScan, false) + go worker.pollHandlers() log.Info().Msg("Starting the consumer")