From 903aea7d76cb5a0f94db9af3e8da566b929fe46f Mon Sep 17 00:00:00 2001 From: Corey Ogburn Date: Tue, 24 Sep 2024 10:31:36 -0600 Subject: [PATCH 1/2] Airgap Check for AI Summaries If the server is configured with AirgapEnabled = true, then the call to RefreshAiSummaries will log a debug statement, do nothing, and return no error. Otherwise the repo will be updated as usual. --- server/modules/detections/ai_summary.go | 6 ++++++ server/modules/detections/ai_summary_test.go | 11 ++++++++--- server/modules/detections/mock/mock_ailoader.go | 14 ++++++++++++++ server/modules/elastalert/elastalert.go | 4 ++++ server/modules/strelka/strelka.go | 4 ++++ server/modules/suricata/suricata.go | 4 ++++ 6 files changed, 40 insertions(+), 3 deletions(-) diff --git a/server/modules/detections/ai_summary.go b/server/modules/detections/ai_summary.go index a9e24737..44d7db6a 100644 --- a/server/modules/detections/ai_summary.go +++ b/server/modules/detections/ai_summary.go @@ -20,11 +20,17 @@ var lastSuccessfulAiUpdate time.Time type AiLoader interface { LoadAuxiliaryData(summaries []*model.AiSummary) error + IsAirgapped() bool } //go:generate mockgen -destination mock/mock_ailoader.go -package mock . AiLoader func RefreshAiSummaries(eng AiLoader, lang model.SigLanguage, isRunning *bool, aiRepoPath string, aiRepoUrl string, aiRepoBranch string, logger *log.Entry, iom IOManager) error { + if eng.IsAirgapped() { + logger.Debug("skipping AI summary update because airgap is enabled") + return nil + } + err := updateAiRepo(isRunning, aiRepoPath, aiRepoUrl, aiRepoBranch, iom) if err != nil { if errors.Is(err, ErrModuleStopped) { diff --git a/server/modules/detections/ai_summary_test.go b/server/modules/detections/ai_summary_test.go index 1ba471e7..fab9e98e 100644 --- a/server/modules/detections/ai_summary_test.go +++ b/server/modules/detections/ai_summary_test.go @@ -25,7 +25,14 @@ func TestRefreshAiSummaries(t *testing.T) { iom := mock.NewMockIOManager(ctrl) loader := mock.NewMockAiLoader(ctrl) + logger := log.WithField("test", true) + + loader.EXPECT().IsAirgapped().Return(true) + + err := RefreshAiSummaries(loader, model.SigLangSigma, &isRunning, "baseRepoFolder", repo, branch, logger, iom) + assert.NoError(t, err) + loader.EXPECT().IsAirgapped().Return(false) iom.EXPECT().ReadDir("baseRepoFolder").Return([]fs.DirEntry{}, nil) iom.EXPECT().CloneRepo(gomock.Any(), "baseRepoFolder/repo1", repo, &branch).Return(nil) iom.EXPECT().ReadFile("baseRepoFolder/repo1/detections-ai/sigma_summaries.yaml").Return([]byte(summaries), nil) @@ -54,10 +61,8 @@ func TestRefreshAiSummaries(t *testing.T) { return nil }) - logger := log.WithField("test", true) - lastSuccessfulAiUpdate = time.Time{} - err := RefreshAiSummaries(loader, model.SigLangSigma, &isRunning, "baseRepoFolder", repo, branch, logger, iom) + err = RefreshAiSummaries(loader, model.SigLangSigma, &isRunning, "baseRepoFolder", repo, branch, logger, iom) assert.NoError(t, err) } diff --git a/server/modules/detections/mock/mock_ailoader.go b/server/modules/detections/mock/mock_ailoader.go index a93acbf2..ed0bfc4e 100644 --- a/server/modules/detections/mock/mock_ailoader.go +++ b/server/modules/detections/mock/mock_ailoader.go @@ -38,6 +38,20 @@ func (m *MockAiLoader) EXPECT() *MockAiLoaderMockRecorder { return m.recorder } +// IsAirgapped mocks base method. +func (m *MockAiLoader) IsAirgapped() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsAirgapped") + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsAirgapped indicates an expected call of IsAirgapped. +func (mr *MockAiLoaderMockRecorder) IsAirgapped() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsAirgapped", reflect.TypeOf((*MockAiLoader)(nil).IsAirgapped)) +} + // LoadAuxiliaryData mocks base method. func (m *MockAiLoader) LoadAuxiliaryData(arg0 []*model.AiSummary) error { m.ctrl.T.Helper() diff --git a/server/modules/elastalert/elastalert.go b/server/modules/elastalert/elastalert.go index 6b543b75..4185c6c6 100644 --- a/server/modules/elastalert/elastalert.go +++ b/server/modules/elastalert/elastalert.go @@ -1556,6 +1556,10 @@ func (e *ElastAlertEngine) DuplicateDetection(ctx context.Context, detection *mo return det, nil } +func (e *ElastAlertEngine) IsAirgapped() bool { + return e.srv.Config.AirgapEnabled +} + func (e *ElastAlertEngine) LoadAuxiliaryData(summaries []*model.AiSummary) error { sum := &sync.Map{} for _, summary := range summaries { diff --git a/server/modules/strelka/strelka.go b/server/modules/strelka/strelka.go index 4315b5e2..111a1124 100644 --- a/server/modules/strelka/strelka.go +++ b/server/modules/strelka/strelka.go @@ -1136,6 +1136,10 @@ func (e *StrelkaEngine) DuplicateDetection(ctx context.Context, detection *model return det, nil } +func (e *StrelkaEngine) IsAirgapped() bool { + return e.srv.Config.AirgapEnabled +} + func (e *StrelkaEngine) LoadAuxiliaryData(summaries []*model.AiSummary) error { sum := &sync.Map{} for _, summary := range summaries { diff --git a/server/modules/suricata/suricata.go b/server/modules/suricata/suricata.go index a44aaa11..689ecf02 100644 --- a/server/modules/suricata/suricata.go +++ b/server/modules/suricata/suricata.go @@ -1746,6 +1746,10 @@ func (e *SuricataEngine) DuplicateDetection(ctx context.Context, detection *mode return det, nil } +func (e *SuricataEngine) IsAirgapped() bool { + return e.srv.Config.AirgapEnabled +} + func (e *SuricataEngine) LoadAuxiliaryData(summaries []*model.AiSummary) error { sum := &sync.Map{} for _, summary := range summaries { From 9bc8efc407547fbca543eb89fd7588e17175298d Mon Sep 17 00:00:00 2001 From: Corey Ogburn Date: Tue, 24 Sep 2024 11:16:32 -0600 Subject: [PATCH 2/2] Stronger assertions that RefreshAiSummaries does nothing if AirgapEnabled Removed all unnecessary parameters and added assertions around the log statement to provide better evidence that having `AirgapEnabled = true` will not try to clone, pull, or otherwise update the AI summaries repo. Discovered that we were importing github.com/tj/assert, which is a fork of github.com/stretchr/testify/assert with a few minor changes. Updating all references to use stretchr's original library for consistency. All tests passing. --- go.mod | 1 - model/custom_ruleset_test.go | 2 +- model/detection_test.go | 3 ++- model/rulerepo_test.go | 2 +- server/modules/detections/ai_summary_test.go | 18 ++++++++++++++---- .../detections/detengine_helpers_test.go | 2 +- server/modules/detections/errortracker_test.go | 2 +- .../modules/detections/integrity_check_test.go | 2 +- server/modules/detections/io_manager_test.go | 3 ++- server/modules/strelka/strelka_test.go | 2 +- .../modules/suricata/migration-2.4.70_test.go | 2 +- util/strings_test.go | 2 +- 12 files changed, 26 insertions(+), 15 deletions(-) diff --git a/go.mod b/go.mod index f2b823c6..8a579fb8 100644 --- a/go.mod +++ b/go.mod @@ -25,7 +25,6 @@ require ( github.com/pierrec/lz4/v4 v4.1.21 github.com/pkg/errors v0.9.1 github.com/samber/lo v1.47.0 - github.com/tj/assert v0.0.3 go.uber.org/mock v0.4.0 golang.org/x/mod v0.20.0 ) diff --git a/model/custom_ruleset_test.go b/model/custom_ruleset_test.go index 787b71cf..c3c367ac 100644 --- a/model/custom_ruleset_test.go +++ b/model/custom_ruleset_test.go @@ -8,7 +8,7 @@ package model import ( "testing" - "github.com/tj/assert" + "github.com/stretchr/testify/assert" ) func TestGetCustomRulesetsDefault(t *testing.T) { diff --git a/model/detection_test.go b/model/detection_test.go index 7a546783..ca3d6d2b 100644 --- a/model/detection_test.go +++ b/model/detection_test.go @@ -9,7 +9,8 @@ import ( "testing" "github.com/security-onion-solutions/securityonion-soc/util" - "github.com/tj/assert" + + "github.com/stretchr/testify/assert" ) func TestDetectionOverrideValidate(t *testing.T) { diff --git a/model/rulerepo_test.go b/model/rulerepo_test.go index 5313ba35..5d52cc21 100644 --- a/model/rulerepo_test.go +++ b/model/rulerepo_test.go @@ -10,7 +10,7 @@ import ( "github.com/security-onion-solutions/securityonion-soc/util" - "github.com/tj/assert" + "github.com/stretchr/testify/assert" ) func TestGetRepos(t *testing.T) { diff --git a/server/modules/detections/ai_summary_test.go b/server/modules/detections/ai_summary_test.go index fab9e98e..c48442e8 100644 --- a/server/modules/detections/ai_summary_test.go +++ b/server/modules/detections/ai_summary_test.go @@ -6,11 +6,12 @@ import ( "testing" "time" - "github.com/apex/log" "github.com/security-onion-solutions/securityonion-soc/model" "github.com/security-onion-solutions/securityonion-soc/server/modules/detections/mock" - "github.com/tj/assert" + "github.com/apex/log" + "github.com/apex/log/handlers/memory" + "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" ) @@ -25,13 +26,22 @@ func TestRefreshAiSummaries(t *testing.T) { iom := mock.NewMockIOManager(ctrl) loader := mock.NewMockAiLoader(ctrl) - logger := log.WithField("test", true) + + h := memory.New() + lg := &log.Logger{Handler: h, Level: log.DebugLevel} + logger := lg.WithField("test", true) loader.EXPECT().IsAirgapped().Return(true) - err := RefreshAiSummaries(loader, model.SigLangSigma, &isRunning, "baseRepoFolder", repo, branch, logger, iom) + err := RefreshAiSummaries(loader, model.SigLanguage(""), nil, "", "", "", logger, nil) assert.NoError(t, err) + assert.Equal(t, len(h.Entries), 1) + + msg := h.Entries[0] + assert.Equal(t, msg.Message, "skipping AI summary update because airgap is enabled") + assert.Equal(t, msg.Level, log.DebugLevel) + loader.EXPECT().IsAirgapped().Return(false) iom.EXPECT().ReadDir("baseRepoFolder").Return([]fs.DirEntry{}, nil) iom.EXPECT().CloneRepo(gomock.Any(), "baseRepoFolder/repo1", repo, &branch).Return(nil) diff --git a/server/modules/detections/detengine_helpers_test.go b/server/modules/detections/detengine_helpers_test.go index ab25ff59..c4dd230e 100644 --- a/server/modules/detections/detengine_helpers_test.go +++ b/server/modules/detections/detengine_helpers_test.go @@ -20,7 +20,7 @@ import ( "github.com/security-onion-solutions/securityonion-soc/util" "github.com/go-git/go-git/v5/plumbing/transport" - "github.com/tj/assert" + "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" ) diff --git a/server/modules/detections/errortracker_test.go b/server/modules/detections/errortracker_test.go index b86aa9a3..42871ef5 100644 --- a/server/modules/detections/errortracker_test.go +++ b/server/modules/detections/errortracker_test.go @@ -9,7 +9,7 @@ import ( "errors" "testing" - "github.com/tj/assert" + "github.com/stretchr/testify/assert" ) func TestErrorTracker(t *testing.T) { diff --git a/server/modules/detections/integrity_check_test.go b/server/modules/detections/integrity_check_test.go index cf1e62d3..d94270d2 100644 --- a/server/modules/detections/integrity_check_test.go +++ b/server/modules/detections/integrity_check_test.go @@ -9,7 +9,7 @@ import ( "sort" "testing" - "github.com/tj/assert" + "github.com/stretchr/testify/assert" ) func TestDiffLists(t *testing.T) { diff --git a/server/modules/detections/io_manager_test.go b/server/modules/detections/io_manager_test.go index 30a1f4e1..2d46980d 100644 --- a/server/modules/detections/io_manager_test.go +++ b/server/modules/detections/io_manager_test.go @@ -10,7 +10,8 @@ import ( "testing" "github.com/security-onion-solutions/securityonion-soc/config" - "github.com/tj/assert" + + "github.com/stretchr/testify/assert" ) func TestBuildHttpClient(t *testing.T) { diff --git a/server/modules/strelka/strelka_test.go b/server/modules/strelka/strelka_test.go index 30668426..2633c1dc 100644 --- a/server/modules/strelka/strelka_test.go +++ b/server/modules/strelka/strelka_test.go @@ -30,7 +30,7 @@ import ( "github.com/apex/log" "github.com/elastic/go-elasticsearch/v8/esutil" "github.com/samber/lo" - "github.com/tj/assert" + "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" ) diff --git a/server/modules/suricata/migration-2.4.70_test.go b/server/modules/suricata/migration-2.4.70_test.go index 44a93414..f3687827 100644 --- a/server/modules/suricata/migration-2.4.70_test.go +++ b/server/modules/suricata/migration-2.4.70_test.go @@ -16,7 +16,7 @@ import ( "github.com/security-onion-solutions/securityonion-soc/server/modules/detections/mock" "github.com/security-onion-solutions/securityonion-soc/util" - "github.com/tj/assert" + "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" ) diff --git a/util/strings_test.go b/util/strings_test.go index 7437fc7d..09b1bad0 100644 --- a/util/strings_test.go +++ b/util/strings_test.go @@ -8,7 +8,7 @@ package util import ( "testing" - "github.com/tj/assert" + "github.com/stretchr/testify/assert" ) func TestUnquote(t *testing.T) {