From 2d0042d646205561a3896fff4f382860a5709e38 Mon Sep 17 00:00:00 2001 From: Corey Ogburn Date: Thu, 26 Sep 2024 13:40:04 -0600 Subject: [PATCH] Fix Summaries on Airgap Accidentally prevented the reading and loading of summaries on Airgap in a previous commit. While we do want to skip the call to update the repo, we still want to load the summaries from disk. Updated test to account for newly executed logic. --- server/modules/detections/ai_summary.go | 27 +++++----- server/modules/detections/ai_summary_test.go | 55 ++++++++++++++++++-- 2 files changed, 63 insertions(+), 19 deletions(-) diff --git a/server/modules/detections/ai_summary.go b/server/modules/detections/ai_summary.go index 44d7db6a..9aa2f38b 100644 --- a/server/modules/detections/ai_summary.go +++ b/server/modules/detections/ai_summary.go @@ -26,23 +26,22 @@ type AiLoader interface { //go:generate mockgen -destination mock/mock_ailoader.go -package mock . AiLoader func RefreshAiSummaries(eng AiLoader, lang model.SigLanguage, isRunning *bool, aiRepoPath string, aiRepoUrl string, aiRepoBranch string, logger *log.Entry, iom IOManager) error { - if eng.IsAirgapped() { - logger.Debug("skipping AI summary update because airgap is enabled") - return nil - } + if !eng.IsAirgapped() { + err := updateAiRepo(isRunning, aiRepoPath, aiRepoUrl, aiRepoBranch, iom) + if err != nil { + if errors.Is(err, ErrModuleStopped) { + return err + } + + logger.WithError(err).WithFields(log.Fields{ + "aiRepoUrl": aiRepoUrl, + "aiRepoPath": aiRepoPath, + }).Error("unable to update AI repo") - err := updateAiRepo(isRunning, aiRepoPath, aiRepoUrl, aiRepoBranch, iom) - if err != nil { - if errors.Is(err, ErrModuleStopped) { return err } - - logger.WithError(err).WithFields(log.Fields{ - "aiRepoUrl": aiRepoUrl, - "aiRepoPath": aiRepoPath, - }).Error("unable to update AI repo") - - return err + } else { + logger.Debug("skipping AI summary update because airgap is enabled") } parser, err := url.Parse(aiRepoUrl) diff --git a/server/modules/detections/ai_summary_test.go b/server/modules/detections/ai_summary_test.go index c2c23cd1..4ffa4d9a 100644 --- a/server/modules/detections/ai_summary_test.go +++ b/server/modules/detections/ai_summary_test.go @@ -20,32 +20,77 @@ func TestRefreshAiSummaries(t *testing.T) { defer ctrl.Finish() isRunning := true - repo := "http://github.com/user/repo1" + localRepo := "file:///tmp/repo1" + repo := "http://github.com/user/repo2" branch := "generated-summaries-published" summaries := `{"87e55c67-46f0-4a7b-a3c6-d473ab7e8392": { "Reviewed": false, "Summary": "ai text goes here"}, "a23077fc-a5ef-427f-92ab-d3de7f56834d": { "Reviewed": true, "Summary": "ai text goes here" } }` iom := mock.NewMockIOManager(ctrl) loader := mock.NewMockAiLoader(ctrl) + // Airgapped test h := memory.New() lg := &log.Logger{Handler: h, Level: log.DebugLevel} logger := lg.WithField("test", true) + // No calls to iom.PullRepo/iom.CloneRepo should be made loader.EXPECT().IsAirgapped().Return(true) + iom.EXPECT().ReadFile("baseRepoFolder/repo1/detections-ai/sigma_summaries.yaml").Return([]byte(summaries), nil) + loader.EXPECT().LoadAuxiliaryData(gomock.Any()).DoAndReturn(func(sums []*model.AiSummary) error { + expected := []*model.AiSummary{ + { + PublicId: "87e55c67-46f0-4a7b-a3c6-d473ab7e8392", + Summary: "ai text goes here", + }, + { + PublicId: "a23077fc-a5ef-427f-92ab-d3de7f56834d", + Reviewed: true, + Summary: "ai text goes here", + }, + } + + sort.Slice(sums, func(i, j int) bool { + return sums[i].PublicId < sums[j].PublicId + }) + + assert.Equal(t, len(expected), len(sums)) + for i := range sums { + assert.Equal(t, *expected[i], *sums[i]) + } - err := RefreshAiSummaries(loader, model.SigLanguage(""), nil, "", "", "", logger, nil) + return nil + }) + + err := RefreshAiSummaries(loader, model.SigLangSigma, &isRunning, "baseRepoFolder", localRepo, "", logger, iom) assert.NoError(t, err) - assert.Equal(t, len(h.Entries), 1) + assert.Equal(t, len(h.Entries), 5) msg := h.Entries[0] assert.Equal(t, msg.Message, "skipping AI summary update because airgap is enabled") assert.Equal(t, msg.Level, log.DebugLevel) + msg = h.Entries[1] + assert.Equal(t, msg.Message, "reading AI summaries") + assert.Equal(t, msg.Level, log.InfoLevel) + + msg = h.Entries[2] + assert.Equal(t, msg.Message, "successfully unmarshalled AI summaries, parsing...") + assert.Equal(t, msg.Level, log.InfoLevel) + + msg = h.Entries[3] + assert.Equal(t, msg.Message, "successfully parsed AI summaries") + assert.Equal(t, msg.Level, log.InfoLevel) + + msg = h.Entries[4] + assert.Equal(t, msg.Message, "successfully loaded AI summaries") + assert.Equal(t, msg.Level, log.InfoLevel) + + // non-Airgapped test loader.EXPECT().IsAirgapped().Return(false) iom.EXPECT().ReadDir("baseRepoFolder").Return([]fs.DirEntry{}, nil) - iom.EXPECT().CloneRepo(gomock.Any(), "baseRepoFolder/repo1", repo, &branch).Return(nil) - iom.EXPECT().ReadFile("baseRepoFolder/repo1/detections-ai/sigma_summaries.yaml").Return([]byte(summaries), nil) + iom.EXPECT().CloneRepo(gomock.Any(), "baseRepoFolder/repo2", repo, &branch).Return(nil) + iom.EXPECT().ReadFile("baseRepoFolder/repo2/detections-ai/sigma_summaries.yaml").Return([]byte(summaries), nil) loader.EXPECT().LoadAuxiliaryData(gomock.Any()).DoAndReturn(func(sums []*model.AiSummary) error { expected := []*model.AiSummary{ {