diff --git a/server/modules/detections/ai_summary.go b/server/modules/detections/ai_summary.go index 44d7db6a..9aa2f38b 100644 --- a/server/modules/detections/ai_summary.go +++ b/server/modules/detections/ai_summary.go @@ -26,23 +26,22 @@ type AiLoader interface { //go:generate mockgen -destination mock/mock_ailoader.go -package mock . AiLoader func RefreshAiSummaries(eng AiLoader, lang model.SigLanguage, isRunning *bool, aiRepoPath string, aiRepoUrl string, aiRepoBranch string, logger *log.Entry, iom IOManager) error { - if eng.IsAirgapped() { - logger.Debug("skipping AI summary update because airgap is enabled") - return nil - } + if !eng.IsAirgapped() { + err := updateAiRepo(isRunning, aiRepoPath, aiRepoUrl, aiRepoBranch, iom) + if err != nil { + if errors.Is(err, ErrModuleStopped) { + return err + } + + logger.WithError(err).WithFields(log.Fields{ + "aiRepoUrl": aiRepoUrl, + "aiRepoPath": aiRepoPath, + }).Error("unable to update AI repo") - err := updateAiRepo(isRunning, aiRepoPath, aiRepoUrl, aiRepoBranch, iom) - if err != nil { - if errors.Is(err, ErrModuleStopped) { return err } - - logger.WithError(err).WithFields(log.Fields{ - "aiRepoUrl": aiRepoUrl, - "aiRepoPath": aiRepoPath, - }).Error("unable to update AI repo") - - return err + } else { + logger.Debug("skipping AI summary update because airgap is enabled") } parser, err := url.Parse(aiRepoUrl) diff --git a/server/modules/detections/ai_summary_test.go b/server/modules/detections/ai_summary_test.go index c2c23cd1..4ffa4d9a 100644 --- a/server/modules/detections/ai_summary_test.go +++ b/server/modules/detections/ai_summary_test.go @@ -20,32 +20,77 @@ func TestRefreshAiSummaries(t *testing.T) { defer ctrl.Finish() isRunning := true - repo := "http://github.com/user/repo1" + localRepo := "file:///tmp/repo1" + repo := "http://github.com/user/repo2" branch := "generated-summaries-published" summaries := `{"87e55c67-46f0-4a7b-a3c6-d473ab7e8392": { "Reviewed": false, "Summary": "ai text goes here"}, "a23077fc-a5ef-427f-92ab-d3de7f56834d": { "Reviewed": true, "Summary": "ai text goes here" } }` iom := mock.NewMockIOManager(ctrl) loader := mock.NewMockAiLoader(ctrl) + // Airgapped test h := memory.New() lg := &log.Logger{Handler: h, Level: log.DebugLevel} logger := lg.WithField("test", true) + // No calls to iom.PullRepo/iom.CloneRepo should be made loader.EXPECT().IsAirgapped().Return(true) + iom.EXPECT().ReadFile("baseRepoFolder/repo1/detections-ai/sigma_summaries.yaml").Return([]byte(summaries), nil) + loader.EXPECT().LoadAuxiliaryData(gomock.Any()).DoAndReturn(func(sums []*model.AiSummary) error { + expected := []*model.AiSummary{ + { + PublicId: "87e55c67-46f0-4a7b-a3c6-d473ab7e8392", + Summary: "ai text goes here", + }, + { + PublicId: "a23077fc-a5ef-427f-92ab-d3de7f56834d", + Reviewed: true, + Summary: "ai text goes here", + }, + } + + sort.Slice(sums, func(i, j int) bool { + return sums[i].PublicId < sums[j].PublicId + }) + + assert.Equal(t, len(expected), len(sums)) + for i := range sums { + assert.Equal(t, *expected[i], *sums[i]) + } - err := RefreshAiSummaries(loader, model.SigLanguage(""), nil, "", "", "", logger, nil) + return nil + }) + + err := RefreshAiSummaries(loader, model.SigLangSigma, &isRunning, "baseRepoFolder", localRepo, "", logger, iom) assert.NoError(t, err) - assert.Equal(t, len(h.Entries), 1) + assert.Equal(t, len(h.Entries), 5) msg := h.Entries[0] assert.Equal(t, msg.Message, "skipping AI summary update because airgap is enabled") assert.Equal(t, msg.Level, log.DebugLevel) + msg = h.Entries[1] + assert.Equal(t, msg.Message, "reading AI summaries") + assert.Equal(t, msg.Level, log.InfoLevel) + + msg = h.Entries[2] + assert.Equal(t, msg.Message, "successfully unmarshalled AI summaries, parsing...") + assert.Equal(t, msg.Level, log.InfoLevel) + + msg = h.Entries[3] + assert.Equal(t, msg.Message, "successfully parsed AI summaries") + assert.Equal(t, msg.Level, log.InfoLevel) + + msg = h.Entries[4] + assert.Equal(t, msg.Message, "successfully loaded AI summaries") + assert.Equal(t, msg.Level, log.InfoLevel) + + // non-Airgapped test loader.EXPECT().IsAirgapped().Return(false) iom.EXPECT().ReadDir("baseRepoFolder").Return([]fs.DirEntry{}, nil) - iom.EXPECT().CloneRepo(gomock.Any(), "baseRepoFolder/repo1", repo, &branch).Return(nil) - iom.EXPECT().ReadFile("baseRepoFolder/repo1/detections-ai/sigma_summaries.yaml").Return([]byte(summaries), nil) + iom.EXPECT().CloneRepo(gomock.Any(), "baseRepoFolder/repo2", repo, &branch).Return(nil) + iom.EXPECT().ReadFile("baseRepoFolder/repo2/detections-ai/sigma_summaries.yaml").Return([]byte(summaries), nil) loader.EXPECT().LoadAuxiliaryData(gomock.Any()).DoAndReturn(func(sums []*model.AiSummary) error { expected := []*model.AiSummary{ {