diff --git a/server/modules/detections/integrity_check.go b/server/modules/detections/integrity_check.go index 220ab08a..9651e0a2 100644 --- a/server/modules/detections/integrity_check.go +++ b/server/modules/detections/integrity_check.go @@ -18,7 +18,7 @@ var ErrIntCheckerStopped = fmt.Errorf("integrity checker has stopped running") var ErrIntCheckFailed = fmt.Errorf("integrity check failed; discrepancies found") type IntegrityChecked interface { - IntegrityCheck(bool) error + IntegrityCheck(bool) ([]string, []string, error) InterruptSync(forceFull bool, notify bool) IsRunning() bool } @@ -56,7 +56,7 @@ func IntegrityChecker(engName model.EngineName, eng IntegrityChecked, data *Inte continue } - err := eng.IntegrityCheck(true) + _, _, err := eng.IntegrityCheck(true) if err != nil { if err != ErrIntCheckerStopped { failCount++ diff --git a/server/modules/elastalert/elastalert.go b/server/modules/elastalert/elastalert.go index fa471de8..86f573a5 100644 --- a/server/modules/elastalert/elastalert.go +++ b/server/modules/elastalert/elastalert.go @@ -676,7 +676,7 @@ func (e *ElastAlertEngine) startCommunityRuleImport() { }) } - err = e.IntegrityCheck(false) + _, _, err = e.IntegrityCheck(false) e.EngineState.IntegrityFailure = err != nil lastSyncSuccess = util.Ptr(err == nil) @@ -790,7 +790,7 @@ func (e *ElastAlertEngine) startCommunityRuleImport() { }) } - err = e.IntegrityCheck(false) + _, _, err = e.IntegrityCheck(false) e.EngineState.IntegrityFailure = err != nil lastSyncSuccess = util.Ptr(err == nil) @@ -1568,10 +1568,10 @@ func wrapRule(det *model.Detection, rule string, additionalAlerters []string) (s return string(rawYaml), nil } -func (e *ElastAlertEngine) IntegrityCheck(canInterrupt bool) error { +func (e *ElastAlertEngine) IntegrityCheck(canInterrupt bool) (deployedButNotEnabled []string, enabledButNotDeployed []string, err error) { // escape if canInterrupt && !e.IntegrityCheckerData.IsRunning { - return detections.ErrIntCheckerStopped + return nil, nil, detections.ErrIntCheckerStopped } logger := log.WithFields(log.Fields{ @@ -1582,7 +1582,7 @@ func (e *ElastAlertEngine) IntegrityCheck(canInterrupt bool) error { deployed, err := e.getDeployedPublicIds() if err != nil { logger.WithError(err).Error("unable to get deployed publicIds") - return detections.ErrIntCheckFailed + return nil, nil, detections.ErrIntCheckFailed } logger.WithField("deployedPublicIdsCount", len(deployed)).Debug("deployed publicIds") @@ -1590,18 +1590,18 @@ func (e *ElastAlertEngine) IntegrityCheck(canInterrupt bool) error { // escape if canInterrupt && !e.IntegrityCheckerData.IsRunning { logger.Info("integrity checker stopped") - return detections.ErrIntCheckerStopped + return nil, nil, detections.ErrIntCheckerStopped } ret, err := e.srv.Detectionstore.GetAllDetections(e.srv.Context, model.WithEngine(model.EngineNameElastAlert), model.WithEnabled(true)) if err != nil { logger.WithError(err).Error("unable to query for enabled detections") - return detections.ErrIntCheckFailed + return nil, nil, detections.ErrIntCheckFailed } enabled := make([]string, 0, len(ret)) - for _, d := range ret { - enabled = append(enabled, d.PublicID) + for pid := range ret { + enabled = append(enabled, pid) } logger.WithField("enabledDetectionsCount", len(enabled)).Debug("enabled detections") @@ -1609,10 +1609,10 @@ func (e *ElastAlertEngine) IntegrityCheck(canInterrupt bool) error { // escape if canInterrupt && !e.IntegrityCheckerData.IsRunning { logger.Info("integrity checker stopped") - return detections.ErrIntCheckerStopped + return nil, nil, detections.ErrIntCheckerStopped } - deployedButNotEnabled, enabledButNotDeployed, _ := detections.DiffLists(deployed, enabled) + deployedButNotEnabled, enabledButNotDeployed, _ = detections.DiffLists(deployed, enabled) logger.WithFields(log.Fields{ "deployedButNotEnabled": deployedButNotEnabled, @@ -1621,12 +1621,12 @@ func (e *ElastAlertEngine) IntegrityCheck(canInterrupt bool) error { if len(deployedButNotEnabled) > 0 || len(enabledButNotDeployed) > 0 { logger.Info("integrity check failed") - return detections.ErrIntCheckFailed + return deployedButNotEnabled, enabledButNotDeployed, detections.ErrIntCheckFailed } logger.Info("integrity check passed") - return nil + return deployedButNotEnabled, enabledButNotDeployed, nil } func (e *ElastAlertEngine) getDeployedPublicIds() (publicIds []string, err error) { diff --git a/server/modules/elastalert/elastalert_test.go b/server/modules/elastalert/elastalert_test.go index 58c5c195..d2f750bc 100644 --- a/server/modules/elastalert/elastalert_test.go +++ b/server/modules/elastalert/elastalert_test.go @@ -27,6 +27,7 @@ import ( "github.com/security-onion-solutions/securityonion-soc/model" "github.com/security-onion-solutions/securityonion-soc/module" "github.com/security-onion-solutions/securityonion-soc/server" + servermock "github.com/security-onion-solutions/securityonion-soc/server/mock" "github.com/security-onion-solutions/securityonion-soc/server/modules/detections" "github.com/security-onion-solutions/securityonion-soc/server/modules/elastalert/mock" "github.com/security-onion-solutions/securityonion-soc/util" @@ -1138,3 +1139,137 @@ func TestBuildHttpClient(t *testing.T) { }) } } + +func TestIntegrityCheck(t *testing.T) { + tests := []struct { + Name string + InitMock func(*mock.MockIOManager, *servermock.MockDetectionstore) + DbnE []string + EbnD []string + ExpError error + }{ + { + Name: "No Rules", + InitMock: func(iom *mock.MockIOManager, detStore *servermock.MockDetectionstore) { + iom.EXPECT().ReadDir("rules/folder").Return([]fs.DirEntry{}, nil) + + detStore.EXPECT().GetAllDetections(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, opts ...model.GetAllOption) (map[string]*model.Detection, error) { + expected := []string{ + `query AND so_detection.engine:"elastalert"`, + `query AND so_detection.isEnabled:"true"`, + } + + for i, opt := range opts { + value := opt("query", "so_") + assert.Equal(t, expected[i], value) + } + + return map[string]*model.Detection{}, nil + }) + }, + DbnE: []string{}, + EbnD: []string{}, + }, + { + Name: "1 Deployed, 0 Enabled", + InitMock: func(iom *mock.MockIOManager, detStore *servermock.MockDetectionstore) { + iom.EXPECT().ReadDir("rules/folder").Return([]fs.DirEntry{ + &MockDirEntry{ + name: "00000000-0000-0000-0000-000000000000.yml", + }, + }, nil) + + detStore.EXPECT().GetAllDetections(gomock.Any(), gomock.Any()).Return(map[string]*model.Detection{}, nil) + }, + DbnE: []string{"00000000-0000-0000-0000-000000000000"}, + EbnD: []string{}, + ExpError: detections.ErrIntCheckFailed, + }, + { + Name: "0 Deployed, 1 Enabled", + InitMock: func(iom *mock.MockIOManager, detStore *servermock.MockDetectionstore) { + iom.EXPECT().ReadDir("rules/folder").Return([]fs.DirEntry{}, nil) + + detStore.EXPECT().GetAllDetections(gomock.Any(), gomock.Any()).Return(map[string]*model.Detection{ + "00000000-0000-0000-0000-000000000000": {}, + }, nil) + }, + DbnE: []string{}, + EbnD: []string{"00000000-0000-0000-0000-000000000000"}, + ExpError: detections.ErrIntCheckFailed, + }, + { + Name: "Multiple Fail", + InitMock: func(iom *mock.MockIOManager, detStore *servermock.MockDetectionstore) { + iom.EXPECT().ReadDir("rules/folder").Return([]fs.DirEntry{ + &MockDirEntry{ + name: "00000000-0000-0000-0000-000000000000.yml", + }, + &MockDirEntry{ + name: "11111111-1111-1111-1111-111111111111.yml", + }, + }, nil) + + detStore.EXPECT().GetAllDetections(gomock.Any(), gomock.Any()).Return(map[string]*model.Detection{ + "00000000-0000-0000-0000-000000000000": {}, + "22222222-2222-2222-2222-222222222222": {}, + }, nil) + }, + DbnE: []string{"11111111-1111-1111-1111-111111111111"}, + EbnD: []string{"22222222-2222-2222-2222-222222222222"}, + ExpError: detections.ErrIntCheckFailed, + }, + { + Name: "Multiple Success", + InitMock: func(iom *mock.MockIOManager, detStore *servermock.MockDetectionstore) { + iom.EXPECT().ReadDir("rules/folder").Return([]fs.DirEntry{ + &MockDirEntry{ + name: "00000000-0000-0000-0000-000000000000.yml", + }, + &MockDirEntry{ + name: "11111111-1111-1111-1111-111111111111.yml", + }, + }, nil) + + detStore.EXPECT().GetAllDetections(gomock.Any(), gomock.Any()).Return(map[string]*model.Detection{ + "00000000-0000-0000-0000-000000000000": {}, + "11111111-1111-1111-1111-111111111111": {}, + }, nil) + }, + DbnE: []string{}, + EbnD: []string{}, + }, + } + + for _, test := range tests { + test := test + t.Run(test.Name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + detStore := servermock.NewMockDetectionstore(ctrl) + iom := mock.NewMockIOManager(ctrl) + test.InitMock(iom, detStore) + + e := &ElastAlertEngine{ + srv: &server.Server{ + Detectionstore: detStore, + }, + elastAlertRulesFolder: "rules/folder", + IOManager: iom, + } + + DbnE, EbnD, err := e.IntegrityCheck(false) + + if test.ExpError != nil { + assert.Error(t, err) + assert.Equal(t, err, test.ExpError) + } else { + assert.NoError(t, err) + } + + assert.Equal(t, test.DbnE, DbnE) + assert.Equal(t, test.EbnD, EbnD) + }) + } +} diff --git a/server/modules/strelka/strelka.go b/server/modules/strelka/strelka.go index bfd41038..84360dd3 100644 --- a/server/modules/strelka/strelka.go +++ b/server/modules/strelka/strelka.go @@ -422,7 +422,7 @@ func (e *StrelkaEngine) startCommunityRuleImport() { }) } - err = e.IntegrityCheck(false) + _, _, err = e.IntegrityCheck(false) e.EngineState.IntegrityFailure = err != nil lastSyncSuccess = util.Ptr(err == nil) @@ -656,7 +656,7 @@ func (e *StrelkaEngine) startCommunityRuleImport() { } } - err = e.IntegrityCheck(false) + _, _, err = e.IntegrityCheck(false) e.EngineState.IntegrityFailure = err != nil lastSyncSuccess = util.Ptr(err == nil) @@ -1015,10 +1015,10 @@ func (e *StrelkaEngine) GenerateUnusedPublicId(ctx context.Context) (string, err return "", fmt.Errorf("not implemented") } -func (e *StrelkaEngine) IntegrityCheck(canInterrupt bool) error { +func (e *StrelkaEngine) IntegrityCheck(canInterrupt bool) (deployedButNotEnabled []string, enabledButNotDeployed []string, err error) { // escape if canInterrupt && !e.IntegrityCheckerData.IsRunning { - return detections.ErrIntCheckerStopped + return nil, nil, detections.ErrIntCheckerStopped } logger := log.WithFields(log.Fields{ @@ -1030,13 +1030,13 @@ func (e *StrelkaEngine) IntegrityCheck(canInterrupt bool) error { report, err := e.getCompilationReport() if err != nil { logger.WithError(err).Error("unable to get compilation report") - return detections.ErrIntCheckFailed + return nil, nil, detections.ErrIntCheckFailed } err = e.verifyCompiledHash(report.CompiledRulesHash) if err != nil { logger.WithError(err).Error("compiled rules hash mismatch, this report is not for the latest compiled rules") - return detections.ErrIntCheckFailed + return nil, nil, detections.ErrIntCheckFailed } logger.WithFields(log.Fields{ @@ -1055,35 +1055,35 @@ func (e *StrelkaEngine) IntegrityCheck(canInterrupt bool) error { logger.WithField("failedPublicIDs", problemSample).Error("integrity check failed because some rules failed to deploy") - return detections.ErrIntCheckFailed + return nil, nil, detections.ErrIntCheckFailed } // escape if canInterrupt && !e.IntegrityCheckerData.IsRunning { - return detections.ErrIntCheckerStopped + return nil, nil, detections.ErrIntCheckerStopped } deployed := getDeployed(report) // escape if canInterrupt && !e.IntegrityCheckerData.IsRunning { - return detections.ErrIntCheckerStopped + return nil, nil, detections.ErrIntCheckerStopped } ret, err := e.srv.Detectionstore.GetAllDetections(e.srv.Context, model.WithEngine(model.EngineNameStrelka), model.WithEnabled(true)) if err != nil { logger.WithError(err).Error("unable to query for enabled detections") - return detections.ErrIntCheckFailed + return nil, nil, detections.ErrIntCheckFailed } // escape if canInterrupt && !e.IntegrityCheckerData.IsRunning { - return detections.ErrIntCheckerStopped + return nil, nil, detections.ErrIntCheckerStopped } enabled := make([]string, 0, len(ret)) - for _, d := range ret { - enabled = append(enabled, d.PublicID) + for pid := range ret { + enabled = append(enabled, pid) } logger.WithField("enabledDetectionsCount", len(enabled)).Debug("enabled detections") @@ -1091,10 +1091,10 @@ func (e *StrelkaEngine) IntegrityCheck(canInterrupt bool) error { // escape if canInterrupt && !e.IntegrityCheckerData.IsRunning { logger.Info("integrity checker stopped") - return detections.ErrIntCheckerStopped + return nil, nil, detections.ErrIntCheckerStopped } - deployedButNotEnabled, enabledButNotDeployed, _ := detections.DiffLists(deployed, enabled) + deployedButNotEnabled, enabledButNotDeployed, _ = detections.DiffLists(deployed, enabled) logger.WithFields(log.Fields{ "deployedButNotEnabled": deployedButNotEnabled, @@ -1103,12 +1103,12 @@ func (e *StrelkaEngine) IntegrityCheck(canInterrupt bool) error { if len(deployedButNotEnabled) > 0 || len(enabledButNotDeployed) > 0 { logger.Info("integrity check failed") - return detections.ErrIntCheckFailed + return deployedButNotEnabled, enabledButNotDeployed, detections.ErrIntCheckFailed } logger.Info("integrity check passed") - return nil + return deployedButNotEnabled, enabledButNotDeployed, nil } func (e *StrelkaEngine) getCompilationReport() (*model.CompilationReport, error) { diff --git a/server/modules/strelka/strelka_test.go b/server/modules/strelka/strelka_test.go index 177b778e..ad476e55 100644 --- a/server/modules/strelka/strelka_test.go +++ b/server/modules/strelka/strelka_test.go @@ -7,6 +7,7 @@ package strelka import ( "context" + "encoding/json" "io/fs" "os" "os/exec" @@ -20,6 +21,7 @@ import ( "github.com/security-onion-solutions/securityonion-soc/module" "github.com/security-onion-solutions/securityonion-soc/server" servermock "github.com/security-onion-solutions/securityonion-soc/server/mock" + "github.com/security-onion-solutions/securityonion-soc/server/modules/detections" "github.com/security-onion-solutions/securityonion-soc/server/modules/strelka/mock" "github.com/security-onion-solutions/securityonion-soc/util" @@ -788,3 +790,183 @@ func TestVerifyCompiledHash(t *testing.T) { err = eng.verifyCompiledHash("") assert.NoError(t, err) } + +func TestIntegrityCheck(t *testing.T) { + tests := []struct { + Name string + InitMock func(*mock.MockIOManager, *servermock.MockDetectionstore) + DbnE []string + EbnD []string + ExpError error + }{ + { + Name: "No Rules", + InitMock: func(iom *mock.MockIOManager, detStore *servermock.MockDetectionstore) { + report := model.CompilationReport{} + + jsonReport, _ := json.Marshal(report) + iom.EXPECT().ReadFile("/opt/so/state/detections_yara_compilation-total.log").Return(jsonReport, nil) + + iom.EXPECT().ReadFile("/opt/so/saltstack/local/salt/strelka/rules/compiled/rules.compiled").Return(nil, os.ErrNotExist) + + detStore.EXPECT().GetAllDetections(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, opts ...model.GetAllOption) (map[string]*model.Detection, error) { + expected := []string{ + `query AND so_detection.engine:"strelka"`, + `query AND so_detection.isEnabled:"true"`, + } + + for i, opt := range opts { + value := opt("query", "so_") + assert.Equal(t, expected[i], value) + } + + return map[string]*model.Detection{}, nil + }) + }, + DbnE: []string{}, + EbnD: []string{}, + }, + { + Name: "Bad Compilation Report Hash", + InitMock: func(iom *mock.MockIOManager, detStore *servermock.MockDetectionstore) { + report := model.CompilationReport{ + CompiledRulesHash: "bad hash", + } + + jsonReport, _ := json.Marshal(report) + iom.EXPECT().ReadFile("/opt/so/state/detections_yara_compilation-total.log").Return(jsonReport, nil) + + iom.EXPECT().ReadFile("/opt/so/saltstack/local/salt/strelka/rules/compiled/rules.compiled").Return([]byte("abc"), nil) + }, + ExpError: detections.ErrIntCheckFailed, + }, + { + Name: "Compilation Report Failures", + InitMock: func(iom *mock.MockIOManager, detStore *servermock.MockDetectionstore) { + report := model.CompilationReport{ + CompiledRulesHash: "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad", + Failure: []string{"MyYARARule1", "MyYARARule2", "MyYARARule3", "MyYARARule4", "MyYARARule5", "MyYARARule6"}, + } + + jsonReport, _ := json.Marshal(report) + iom.EXPECT().ReadFile("/opt/so/state/detections_yara_compilation-total.log").Return(jsonReport, nil) + + iom.EXPECT().ReadFile("/opt/so/saltstack/local/salt/strelka/rules/compiled/rules.compiled").Return([]byte("abc"), nil) + }, + ExpError: detections.ErrIntCheckFailed, + }, + { + Name: "1 Deployed, 0 Enabled", + InitMock: func(iom *mock.MockIOManager, detStore *servermock.MockDetectionstore) { + report := model.CompilationReport{ + CompiledRulesHash: "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad", + Success: []string{"MyYARARule"}, + } + + jsonReport, _ := json.Marshal(report) + iom.EXPECT().ReadFile("/opt/so/state/detections_yara_compilation-total.log").Return(jsonReport, nil) + + iom.EXPECT().ReadFile("/opt/so/saltstack/local/salt/strelka/rules/compiled/rules.compiled").Return([]byte("abc"), nil) + + detStore.EXPECT().GetAllDetections(gomock.Any(), gomock.Any()).Return(map[string]*model.Detection{}, nil) + }, + DbnE: []string{"MyYARARule"}, + EbnD: []string{}, + ExpError: detections.ErrIntCheckFailed, + }, + { + Name: "0 Deployed, 1 Enabled", + InitMock: func(iom *mock.MockIOManager, detStore *servermock.MockDetectionstore) { + report := model.CompilationReport{ + CompiledRulesHash: "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad", + } + + jsonReport, _ := json.Marshal(report) + iom.EXPECT().ReadFile("/opt/so/state/detections_yara_compilation-total.log").Return(jsonReport, nil) + + iom.EXPECT().ReadFile("/opt/so/saltstack/local/salt/strelka/rules/compiled/rules.compiled").Return([]byte("abc"), nil) + + detStore.EXPECT().GetAllDetections(gomock.Any(), gomock.Any()).Return(map[string]*model.Detection{ + "MyYARARule": {}, + }, nil) + }, + DbnE: []string{}, + EbnD: []string{"MyYARARule"}, + ExpError: detections.ErrIntCheckFailed, + }, + { + Name: "Mix and Match Fail", + InitMock: func(iom *mock.MockIOManager, detStore *servermock.MockDetectionstore) { + report := model.CompilationReport{ + CompiledRulesHash: "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad", + Success: []string{"MyYARARule", "MyOtherYARARule"}, + } + + jsonReport, _ := json.Marshal(report) + iom.EXPECT().ReadFile("/opt/so/state/detections_yara_compilation-total.log").Return(jsonReport, nil) + + iom.EXPECT().ReadFile("/opt/so/saltstack/local/salt/strelka/rules/compiled/rules.compiled").Return([]byte("abc"), nil) + + detStore.EXPECT().GetAllDetections(gomock.Any(), gomock.Any()).Return(map[string]*model.Detection{ + "MyYARARule": {}, + "AThirdYARARule": {}, + }, nil) + }, + DbnE: []string{"MyOtherYARARule"}, + EbnD: []string{"AThirdYARARule"}, + ExpError: detections.ErrIntCheckFailed, + }, + { + Name: "Mix and Match Fail", + InitMock: func(iom *mock.MockIOManager, detStore *servermock.MockDetectionstore) { + report := model.CompilationReport{ + CompiledRulesHash: "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad", + Success: []string{"MyYARARule", "MyOtherYARARule"}, + } + + jsonReport, _ := json.Marshal(report) + iom.EXPECT().ReadFile("/opt/so/state/detections_yara_compilation-total.log").Return(jsonReport, nil) + + iom.EXPECT().ReadFile("/opt/so/saltstack/local/salt/strelka/rules/compiled/rules.compiled").Return([]byte("abc"), nil) + + detStore.EXPECT().GetAllDetections(gomock.Any(), gomock.Any()).Return(map[string]*model.Detection{ + "MyYARARule": {}, + "MyOtherYARARule": {}, + }, nil) + }, + DbnE: []string{}, + EbnD: []string{}, + }, + } + + for _, test := range tests { + test := test + t.Run(test.Name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + detStore := servermock.NewMockDetectionstore(ctrl) + iom := mock.NewMockIOManager(ctrl) + test.InitMock(iom, detStore) + + e := &StrelkaEngine{ + srv: &server.Server{ + Detectionstore: detStore, + }, + IOManager: iom, + } + + DbnE, EbnD, err := e.IntegrityCheck(false) + + if test.ExpError != nil { + assert.Error(t, err) + assert.Equal(t, err, test.ExpError) + } else { + assert.NoError(t, err) + } + + assert.Equal(t, test.DbnE, DbnE) + assert.Equal(t, test.EbnD, EbnD) + }) + } +} diff --git a/server/modules/suricata/suricata.go b/server/modules/suricata/suricata.go index b5f3e442..a6886e6d 100644 --- a/server/modules/suricata/suricata.go +++ b/server/modules/suricata/suricata.go @@ -468,7 +468,7 @@ func (e *SuricataEngine) watchCommunityRules() { checkMigrationsOnce() - err = e.IntegrityCheck(false) + _, _, err = e.IntegrityCheck(false) e.EngineState.IntegrityFailure = err != nil lastSyncSuccess = util.Ptr(err == nil) @@ -595,7 +595,7 @@ func (e *SuricataEngine) watchCommunityRules() { }) } - err = e.IntegrityCheck(false) + _, _, err = e.IntegrityCheck(false) e.EngineState.IntegrityFailure = err != nil lastSyncSuccess = util.Ptr(err == nil) @@ -1662,10 +1662,10 @@ func (e *SuricataEngine) ReadCustomRulesets() (detects []*model.Detection, err e return detects, nil } -func (e *SuricataEngine) IntegrityCheck(canInterrupt bool) error { +func (e *SuricataEngine) IntegrityCheck(canInterrupt bool) (deployedButNotEnabled []string, enabledButNotDeployed []string, err error) { // escape if canInterrupt && !e.IntegrityCheckerData.IsRunning { - return detections.ErrIntCheckerStopped + return nil, nil, detections.ErrIntCheckerStopped } logger := log.WithFields(log.Fields{ @@ -1675,34 +1675,34 @@ func (e *SuricataEngine) IntegrityCheck(canInterrupt bool) error { allSettings, err := e.srv.Configstore.GetSettings(e.srv.Context) if err != nil { - return err + return nil, nil, err } // escape if canInterrupt && !e.IntegrityCheckerData.IsRunning { logger.Info("integrity checker stopped") - return detections.ErrIntCheckerStopped + return nil, nil, detections.ErrIntCheckerStopped } allRules, err := e.ReadFile(e.allRulesFile) if err != nil { logger.WithError(err).WithField("path", e.allRulesFile).Error("unable to read all.rules file") - return err + return nil, nil, err } disabled := settingByID(allSettings, "idstools.sids.disabled") if disabled == nil { - return fmt.Errorf("unable to find disabled setting") + return nil, nil, fmt.Errorf("unable to find disabled setting") } modify := settingByID(allSettings, "idstools.sids.modify") if modify == nil { - return fmt.Errorf("unable to find modify setting") + return nil, nil, fmt.Errorf("unable to find modify setting") } // escape if canInterrupt && !e.IntegrityCheckerData.IsRunning { - return detections.ErrIntCheckerStopped + return nil, nil, detections.ErrIntCheckerStopped } // unpack settings into lines/indices @@ -1722,7 +1722,7 @@ func (e *SuricataEngine) IntegrityCheck(canInterrupt bool) error { // escape if canInterrupt && !e.IntegrityCheckerData.IsRunning { - return detections.ErrIntCheckerStopped + return nil, nil, detections.ErrIntCheckerStopped } deployed := consolidateEnabled(rulesIndex, disabledIndex) @@ -1731,23 +1731,23 @@ func (e *SuricataEngine) IntegrityCheck(canInterrupt bool) error { // escape if canInterrupt && !e.IntegrityCheckerData.IsRunning { - return detections.ErrIntCheckerStopped + return nil, nil, detections.ErrIntCheckerStopped } ret, err := e.srv.Detectionstore.GetAllDetections(e.srv.Context, model.WithEngine(model.EngineNameSuricata), model.WithEnabled(true)) if err != nil { logger.WithError(err).Error("unable to query for enabled detections") - return detections.ErrIntCheckFailed + return nil, nil, detections.ErrIntCheckFailed } // escape if canInterrupt && !e.IntegrityCheckerData.IsRunning { - return detections.ErrIntCheckerStopped + return nil, nil, detections.ErrIntCheckerStopped } enabled := make([]string, 0, len(ret)) - for _, d := range ret { - enabled = append(enabled, d.PublicID) + for pid := range ret { + enabled = append(enabled, pid) } logger.WithField("enabledDetectionsCount", len(enabled)).Debug("enabled detections") @@ -1755,10 +1755,10 @@ func (e *SuricataEngine) IntegrityCheck(canInterrupt bool) error { // escape if canInterrupt && !e.IntegrityCheckerData.IsRunning { logger.Info("integrity checker stopped") - return detections.ErrIntCheckerStopped + return nil, nil, detections.ErrIntCheckerStopped } - deployedButNotEnabled, enabledButNotDeployed, _ := detections.DiffLists(deployed, enabled) + deployedButNotEnabled, enabledButNotDeployed, _ = detections.DiffLists(deployed, enabled) logger.WithFields(log.Fields{ "deployedButNotEnabled": deployedButNotEnabled, @@ -1767,12 +1767,12 @@ func (e *SuricataEngine) IntegrityCheck(canInterrupt bool) error { if len(deployedButNotEnabled) > 0 || len(enabledButNotDeployed) > 0 { logger.Info("integrity check failed") - return detections.ErrIntCheckFailed + return deployedButNotEnabled, enabledButNotDeployed, detections.ErrIntCheckFailed } logger.Info("integrity check passed") - return nil + return deployedButNotEnabled, enabledButNotDeployed, nil } func consolidateEnabled(rulesIndex map[string]int, disabledIndex map[string]int) (pids []string) { diff --git a/server/modules/suricata/suricata_test.go b/server/modules/suricata/suricata_test.go index 90041ca8..e2c1ced1 100644 --- a/server/modules/suricata/suricata_test.go +++ b/server/modules/suricata/suricata_test.go @@ -17,6 +17,7 @@ import ( "github.com/security-onion-solutions/securityonion-soc/module" "github.com/security-onion-solutions/securityonion-soc/server" servermock "github.com/security-onion-solutions/securityonion-soc/server/mock" + "github.com/security-onion-solutions/securityonion-soc/server/modules/detections" "github.com/security-onion-solutions/securityonion-soc/server/modules/suricata/mock" "github.com/security-onion-solutions/securityonion-soc/util" "github.com/security-onion-solutions/securityonion-soc/web" @@ -1741,3 +1742,149 @@ func TestReadCustomRulesets(t *testing.T) { }) } } + +func TestIntegrityCheck(t *testing.T) { + // the configstore only needs to specify disabled and modify + tests := []struct { + Name string + InitMock func(*mock.MockIOManager, *servermock.MockDetectionstore) (cfgStore *server.MemConfigStore) + DbnE []string + EbnD []string + ExpError error + }{ + { + Name: "No Rules", + InitMock: func(iom *mock.MockIOManager, detStore *servermock.MockDetectionstore) (cfgStore *server.MemConfigStore) { + iom.EXPECT().ReadFile("allrules").Return([]byte{}, nil) + + detStore.EXPECT().GetAllDetections(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, opts ...model.GetAllOption) (map[string]*model.Detection, error) { + expected := []string{ + `query AND so_detection.engine:"suricata"`, + `query AND so_detection.isEnabled:"true"`, + } + + for i, opt := range opts { + value := opt("query", "so_") + assert.Equal(t, expected[i], value) + } + + return map[string]*model.Detection{}, nil + }) + + return server.NewMemConfigStore(emptySettings()) + }, + DbnE: []string{}, + EbnD: []string{}, + }, + { + Name: "1 Deployed, 0 Enabled", + InitMock: func(iom *mock.MockIOManager, detStore *servermock.MockDetectionstore) (cfgStore *server.MemConfigStore) { + iom.EXPECT().ReadFile("allrules").Return([]byte(SimpleRule), nil) + + detStore.EXPECT().GetAllDetections(gomock.Any(), gomock.Any()).Return(map[string]*model.Detection{}, nil) + + return server.NewMemConfigStore(emptySettings()) + }, + DbnE: []string{SimpleRuleSID}, + EbnD: []string{}, + ExpError: detections.ErrIntCheckFailed, + }, + { + Name: "0 Deployed, 1 Enabled", + InitMock: func(iom *mock.MockIOManager, detStore *servermock.MockDetectionstore) (cfgStore *server.MemConfigStore) { + iom.EXPECT().ReadFile("allrules").Return([]byte{}, nil) + + detStore.EXPECT().GetAllDetections(gomock.Any(), gomock.Any()).Return(map[string]*model.Detection{ + SimpleRuleSID: {}, + }, nil) + + return server.NewMemConfigStore(emptySettings()) + }, + DbnE: []string{}, + EbnD: []string{SimpleRuleSID}, + ExpError: detections.ErrIntCheckFailed, + }, + { + Name: "Deployed As Disabled", + InitMock: func(iom *mock.MockIOManager, detStore *servermock.MockDetectionstore) (cfgStore *server.MemConfigStore) { + iom.EXPECT().ReadFile("allrules").Return([]byte(SimpleRule+"\n"+FlowbitsRuleA), nil) + + detStore.EXPECT().GetAllDetections(gomock.Any(), gomock.Any()).Return(map[string]*model.Detection{}, nil) + + return server.NewMemConfigStore([]*model.Setting{ + {Id: "idstools.sids.disabled", Value: SimpleRuleSID}, + {Id: "idstools.sids.modify", Value: FlowbitsRuleASID + " " + modifyFromTo}, + }) + }, + DbnE: []string{}, + EbnD: []string{}, + }, + { + Name: "Mix and Match Fail", + InitMock: func(iom *mock.MockIOManager, detStore *servermock.MockDetectionstore) (cfgStore *server.MemConfigStore) { + iom.EXPECT().ReadFile("allrules").Return([]byte(SimpleRule+"\n"+FlowbitsRuleA), nil) + + detStore.EXPECT().GetAllDetections(gomock.Any(), gomock.Any()).Return(map[string]*model.Detection{ + SimpleRuleSID: {}, + FlowbitsRuleBSID: {}, + }, nil) + + return server.NewMemConfigStore(emptySettings()) + }, + DbnE: []string{FlowbitsRuleASID}, + EbnD: []string{FlowbitsRuleBSID}, + ExpError: detections.ErrIntCheckFailed, + }, + { + Name: "Mix and Match Success", + InitMock: func(iom *mock.MockIOManager, detStore *servermock.MockDetectionstore) (cfgStore *server.MemConfigStore) { + iom.EXPECT().ReadFile("allrules").Return([]byte(SimpleRule+"\n"+FlowbitsRuleA+"\n"+FlowbitsRuleB), nil) + + detStore.EXPECT().GetAllDetections(gomock.Any(), gomock.Any()).Return(map[string]*model.Detection{ + SimpleRuleSID: {}, + FlowbitsRuleASID: {}, + }, nil) + + return server.NewMemConfigStore([]*model.Setting{ + {Id: "idstools.sids.disabled"}, + {Id: "idstools.sids.modify", Value: FlowbitsRuleBSID + " " + modifyFromTo}, + }) + }, + DbnE: []string{}, + EbnD: []string{}, + }, + } + + for _, test := range tests { + test := test + t.Run(test.Name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + detStore := servermock.NewMockDetectionstore(ctrl) + iom := mock.NewMockIOManager(ctrl) + cfgStore := test.InitMock(iom, detStore) + + e := &SuricataEngine{ + srv: &server.Server{ + Configstore: cfgStore, + Detectionstore: detStore, + }, + allRulesFile: "allrules", + IOManager: iom, + } + + DbnE, EbnD, err := e.IntegrityCheck(false) + + if test.ExpError != nil { + assert.Error(t, err) + assert.Equal(t, err, test.ExpError) + } else { + assert.NoError(t, err) + } + + assert.Equal(t, test.DbnE, DbnE) + assert.Equal(t, test.EbnD, EbnD) + }) + } +}