diff --git a/go.mod b/go.mod index f2b823c6..8a579fb8 100644 --- a/go.mod +++ b/go.mod @@ -25,7 +25,6 @@ require ( github.com/pierrec/lz4/v4 v4.1.21 github.com/pkg/errors v0.9.1 github.com/samber/lo v1.47.0 - github.com/tj/assert v0.0.3 go.uber.org/mock v0.4.0 golang.org/x/mod v0.20.0 ) diff --git a/model/custom_ruleset_test.go b/model/custom_ruleset_test.go index 787b71cf..c3c367ac 100644 --- a/model/custom_ruleset_test.go +++ b/model/custom_ruleset_test.go @@ -8,7 +8,7 @@ package model import ( "testing" - "github.com/tj/assert" + "github.com/stretchr/testify/assert" ) func TestGetCustomRulesetsDefault(t *testing.T) { diff --git a/model/detection_test.go b/model/detection_test.go index 7a546783..ca3d6d2b 100644 --- a/model/detection_test.go +++ b/model/detection_test.go @@ -9,7 +9,8 @@ import ( "testing" "github.com/security-onion-solutions/securityonion-soc/util" - "github.com/tj/assert" + + "github.com/stretchr/testify/assert" ) func TestDetectionOverrideValidate(t *testing.T) { diff --git a/model/rulerepo_test.go b/model/rulerepo_test.go index 5313ba35..5d52cc21 100644 --- a/model/rulerepo_test.go +++ b/model/rulerepo_test.go @@ -10,7 +10,7 @@ import ( "github.com/security-onion-solutions/securityonion-soc/util" - "github.com/tj/assert" + "github.com/stretchr/testify/assert" ) func TestGetRepos(t *testing.T) { diff --git a/server/modules/detections/ai_summary.go b/server/modules/detections/ai_summary.go index a9e24737..44d7db6a 100644 --- a/server/modules/detections/ai_summary.go +++ b/server/modules/detections/ai_summary.go @@ -20,11 +20,17 @@ var lastSuccessfulAiUpdate time.Time type AiLoader interface { LoadAuxiliaryData(summaries []*model.AiSummary) error + IsAirgapped() bool } //go:generate mockgen -destination mock/mock_ailoader.go -package mock . AiLoader func RefreshAiSummaries(eng AiLoader, lang model.SigLanguage, isRunning *bool, aiRepoPath string, aiRepoUrl string, aiRepoBranch string, logger *log.Entry, iom IOManager) error { + if eng.IsAirgapped() { + logger.Debug("skipping AI summary update because airgap is enabled") + return nil + } + err := updateAiRepo(isRunning, aiRepoPath, aiRepoUrl, aiRepoBranch, iom) if err != nil { if errors.Is(err, ErrModuleStopped) { diff --git a/server/modules/detections/ai_summary_test.go b/server/modules/detections/ai_summary_test.go index 1ba471e7..c48442e8 100644 --- a/server/modules/detections/ai_summary_test.go +++ b/server/modules/detections/ai_summary_test.go @@ -6,11 +6,12 @@ import ( "testing" "time" - "github.com/apex/log" "github.com/security-onion-solutions/securityonion-soc/model" "github.com/security-onion-solutions/securityonion-soc/server/modules/detections/mock" - "github.com/tj/assert" + "github.com/apex/log" + "github.com/apex/log/handlers/memory" + "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" ) @@ -26,6 +27,22 @@ func TestRefreshAiSummaries(t *testing.T) { iom := mock.NewMockIOManager(ctrl) loader := mock.NewMockAiLoader(ctrl) + h := memory.New() + lg := &log.Logger{Handler: h, Level: log.DebugLevel} + logger := lg.WithField("test", true) + + loader.EXPECT().IsAirgapped().Return(true) + + err := RefreshAiSummaries(loader, model.SigLanguage(""), nil, "", "", "", logger, nil) + assert.NoError(t, err) + + assert.Equal(t, len(h.Entries), 1) + + msg := h.Entries[0] + assert.Equal(t, msg.Message, "skipping AI summary update because airgap is enabled") + assert.Equal(t, msg.Level, log.DebugLevel) + + loader.EXPECT().IsAirgapped().Return(false) iom.EXPECT().ReadDir("baseRepoFolder").Return([]fs.DirEntry{}, nil) iom.EXPECT().CloneRepo(gomock.Any(), "baseRepoFolder/repo1", repo, &branch).Return(nil) iom.EXPECT().ReadFile("baseRepoFolder/repo1/detections-ai/sigma_summaries.yaml").Return([]byte(summaries), nil) @@ -54,10 +71,8 @@ func TestRefreshAiSummaries(t *testing.T) { return nil }) - logger := log.WithField("test", true) - lastSuccessfulAiUpdate = time.Time{} - err := RefreshAiSummaries(loader, model.SigLangSigma, &isRunning, "baseRepoFolder", repo, branch, logger, iom) + err = RefreshAiSummaries(loader, model.SigLangSigma, &isRunning, "baseRepoFolder", repo, branch, logger, iom) assert.NoError(t, err) } diff --git a/server/modules/detections/detengine_helpers_test.go b/server/modules/detections/detengine_helpers_test.go index ab25ff59..c4dd230e 100644 --- a/server/modules/detections/detengine_helpers_test.go +++ b/server/modules/detections/detengine_helpers_test.go @@ -20,7 +20,7 @@ import ( "github.com/security-onion-solutions/securityonion-soc/util" "github.com/go-git/go-git/v5/plumbing/transport" - "github.com/tj/assert" + "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" ) diff --git a/server/modules/detections/errortracker_test.go b/server/modules/detections/errortracker_test.go index b86aa9a3..42871ef5 100644 --- a/server/modules/detections/errortracker_test.go +++ b/server/modules/detections/errortracker_test.go @@ -9,7 +9,7 @@ import ( "errors" "testing" - "github.com/tj/assert" + "github.com/stretchr/testify/assert" ) func TestErrorTracker(t *testing.T) { diff --git a/server/modules/detections/integrity_check_test.go b/server/modules/detections/integrity_check_test.go index cf1e62d3..d94270d2 100644 --- a/server/modules/detections/integrity_check_test.go +++ b/server/modules/detections/integrity_check_test.go @@ -9,7 +9,7 @@ import ( "sort" "testing" - "github.com/tj/assert" + "github.com/stretchr/testify/assert" ) func TestDiffLists(t *testing.T) { diff --git a/server/modules/detections/io_manager_test.go b/server/modules/detections/io_manager_test.go index 30a1f4e1..2d46980d 100644 --- a/server/modules/detections/io_manager_test.go +++ b/server/modules/detections/io_manager_test.go @@ -10,7 +10,8 @@ import ( "testing" "github.com/security-onion-solutions/securityonion-soc/config" - "github.com/tj/assert" + + "github.com/stretchr/testify/assert" ) func TestBuildHttpClient(t *testing.T) { diff --git a/server/modules/detections/mock/mock_ailoader.go b/server/modules/detections/mock/mock_ailoader.go index a93acbf2..ed0bfc4e 100644 --- a/server/modules/detections/mock/mock_ailoader.go +++ b/server/modules/detections/mock/mock_ailoader.go @@ -38,6 +38,20 @@ func (m *MockAiLoader) EXPECT() *MockAiLoaderMockRecorder { return m.recorder } +// IsAirgapped mocks base method. +func (m *MockAiLoader) IsAirgapped() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsAirgapped") + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsAirgapped indicates an expected call of IsAirgapped. +func (mr *MockAiLoaderMockRecorder) IsAirgapped() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsAirgapped", reflect.TypeOf((*MockAiLoader)(nil).IsAirgapped)) +} + // LoadAuxiliaryData mocks base method. func (m *MockAiLoader) LoadAuxiliaryData(arg0 []*model.AiSummary) error { m.ctrl.T.Helper() diff --git a/server/modules/elastalert/elastalert.go b/server/modules/elastalert/elastalert.go index 6b543b75..4185c6c6 100644 --- a/server/modules/elastalert/elastalert.go +++ b/server/modules/elastalert/elastalert.go @@ -1556,6 +1556,10 @@ func (e *ElastAlertEngine) DuplicateDetection(ctx context.Context, detection *mo return det, nil } +func (e *ElastAlertEngine) IsAirgapped() bool { + return e.srv.Config.AirgapEnabled +} + func (e *ElastAlertEngine) LoadAuxiliaryData(summaries []*model.AiSummary) error { sum := &sync.Map{} for _, summary := range summaries { diff --git a/server/modules/strelka/strelka.go b/server/modules/strelka/strelka.go index 4315b5e2..111a1124 100644 --- a/server/modules/strelka/strelka.go +++ b/server/modules/strelka/strelka.go @@ -1136,6 +1136,10 @@ func (e *StrelkaEngine) DuplicateDetection(ctx context.Context, detection *model return det, nil } +func (e *StrelkaEngine) IsAirgapped() bool { + return e.srv.Config.AirgapEnabled +} + func (e *StrelkaEngine) LoadAuxiliaryData(summaries []*model.AiSummary) error { sum := &sync.Map{} for _, summary := range summaries { diff --git a/server/modules/strelka/strelka_test.go b/server/modules/strelka/strelka_test.go index 30668426..2633c1dc 100644 --- a/server/modules/strelka/strelka_test.go +++ b/server/modules/strelka/strelka_test.go @@ -30,7 +30,7 @@ import ( "github.com/apex/log" "github.com/elastic/go-elasticsearch/v8/esutil" "github.com/samber/lo" - "github.com/tj/assert" + "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" ) diff --git a/server/modules/suricata/migration-2.4.70_test.go b/server/modules/suricata/migration-2.4.70_test.go index 44a93414..f3687827 100644 --- a/server/modules/suricata/migration-2.4.70_test.go +++ b/server/modules/suricata/migration-2.4.70_test.go @@ -16,7 +16,7 @@ import ( "github.com/security-onion-solutions/securityonion-soc/server/modules/detections/mock" "github.com/security-onion-solutions/securityonion-soc/util" - "github.com/tj/assert" + "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" ) diff --git a/server/modules/suricata/suricata.go b/server/modules/suricata/suricata.go index a44aaa11..689ecf02 100644 --- a/server/modules/suricata/suricata.go +++ b/server/modules/suricata/suricata.go @@ -1746,6 +1746,10 @@ func (e *SuricataEngine) DuplicateDetection(ctx context.Context, detection *mode return det, nil } +func (e *SuricataEngine) IsAirgapped() bool { + return e.srv.Config.AirgapEnabled +} + func (e *SuricataEngine) LoadAuxiliaryData(summaries []*model.AiSummary) error { sum := &sync.Map{} for _, summary := range summaries { diff --git a/util/strings_test.go b/util/strings_test.go index 7437fc7d..09b1bad0 100644 --- a/util/strings_test.go +++ b/util/strings_test.go @@ -8,7 +8,7 @@ package util import ( "testing" - "github.com/tj/assert" + "github.com/stretchr/testify/assert" ) func TestUnquote(t *testing.T) {