Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cogburn/refactor changed by user #611

Merged
merged 2 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 20 additions & 17 deletions model/detection.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,23 +99,26 @@ type DetectionEngine struct {

type Detection struct {
Auditable
PublicID string `json:"publicId"`
Title string `json:"title"`
Severity Severity `json:"severity"`
Author string `json:"author"`
Category string `json:"category,omitempty"`
Description string `json:"description"`
Content string `json:"content"`
IsEnabled bool `json:"isEnabled"`
IsReporting bool `json:"isReporting"`
IsCommunity bool `json:"isCommunity"`
Engine EngineName `json:"engine"`
Language SigLanguage `json:"language"`
Overrides []*Override `json:"overrides"` // Tuning
Tags []string `json:"tags"`
Ruleset string `json:"ruleset"`
License string `json:"license"`
PendingDelete bool `json:"-"` // this is a transient field, not stored in the database
PublicID string `json:"publicId"`
Title string `json:"title"`
Severity Severity `json:"severity"`
Author string `json:"author"`
Category string `json:"category,omitempty"`
Description string `json:"description"`
Content string `json:"content"`
IsEnabled bool `json:"isEnabled"`
IsReporting bool `json:"isReporting"`
IsCommunity bool `json:"isCommunity"`
Engine EngineName `json:"engine"`
Language SigLanguage `json:"language"`
Overrides []*Override `json:"overrides"` // Tuning
Tags []string `json:"tags"`
Ruleset string `json:"ruleset"`
License string `json:"license"`

// these are transient fields, not stored in the database
PendingDelete bool `json:"-"`
PersistChange bool `json:"-"`

// elastalert - sigma only
Product string `json:"product,omitempty"`
Expand Down
12 changes: 4 additions & 8 deletions server/detectionhandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,6 @@ func (h *DetectionHandler) getByPublicId(w http.ResponseWriter, r *http.Request)

func (h *DetectionHandler) createDetection(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ctx = web.MarkChangedByUser(ctx, true)

detect := &model.Detection{}

Expand Down Expand Up @@ -270,7 +269,6 @@ func (h *DetectionHandler) getDetectionHistory(w http.ResponseWriter, r *http.Re

func (h *DetectionHandler) duplicateDetection(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ctx = web.MarkChangedByUser(ctx, true)

detectId := chi.URLParam(r, "id")

Expand Down Expand Up @@ -303,7 +301,6 @@ func (h *DetectionHandler) duplicateDetection(w http.ResponseWriter, r *http.Req

func (h *DetectionHandler) updateDetection(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ctx = web.MarkChangedByUser(ctx, true)

detect := &model.Detection{}

Expand Down Expand Up @@ -369,6 +366,8 @@ func (h *DetectionHandler) updateDetection(w http.ResponseWriter, r *http.Reques
return
}

detect.PersistChange = true

errMap, err := SyncLocalDetections(ctx, h.server, []*model.Detection{detect})
if err != nil {
fixed := false
Expand Down Expand Up @@ -412,7 +411,6 @@ func (h *DetectionHandler) updateDetection(w http.ResponseWriter, r *http.Reques

func (h *DetectionHandler) deleteDetection(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ctx = web.MarkChangedByUser(ctx, true)

id := chi.URLParam(r, "id")

Expand Down Expand Up @@ -529,7 +527,6 @@ func (h *DetectionHandler) bulkUpdateDetection(w http.ResponseWriter, r *http.Re

noTimeOutCtx := context.WithValue(context.Background(), web.ContextKeyRequestor, ctx.Value(web.ContextKeyRequestor).(*model.User))
noTimeOutCtx = context.WithValue(noTimeOutCtx, web.ContextKeyRequestorId, ctx.Value(web.ContextKeyRequestorId).(string))
noTimeOutCtx = web.MarkChangedByUser(noTimeOutCtx, true)

go h.bulkUpdateDetectionAsync(noTimeOutCtx, body, detects, logger)

Expand Down Expand Up @@ -728,6 +725,8 @@ func (h *DetectionHandler) bulkUpdateDetectionAsync(ctx context.Context, body *B
det.PendingDelete = true
}

det.PersistChange = true

dirty = append(dirty, det)
}

Expand Down Expand Up @@ -794,7 +793,6 @@ func SyncLocalDetections(ctx context.Context, srv *Server, detections []*model.D

func (h *DetectionHandler) createComment(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ctx = web.MarkChangedByUser(ctx, true)

detectId := chi.URLParam(r, "id")

Expand Down Expand Up @@ -833,7 +831,6 @@ func (h *DetectionHandler) getDetectionComment(w http.ResponseWriter, r *http.Re

func (h *DetectionHandler) updateComment(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ctx = web.MarkChangedByUser(ctx, true)

commentId := chi.URLParam(r, "id")

Expand All @@ -858,7 +855,6 @@ func (h *DetectionHandler) updateComment(w http.ResponseWriter, r *http.Request)

func (h *DetectionHandler) deleteComment(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ctx = web.MarkChangedByUser(ctx, true)

commentId := chi.URLParam(r, "id")

Expand Down
4 changes: 2 additions & 2 deletions server/modules/suricata/migration-2.4.70.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"time"

"github.com/security-onion-solutions/securityonion-soc/model"
"github.com/security-onion-solutions/securityonion-soc/web"

"github.com/apex/log"
"gopkg.in/yaml.v3"
Expand All @@ -37,7 +36,7 @@ func (e *SuricataEngine) Migration2470(statePath string) error {

log.Info("suricata is now migrating to 2.4.70") // for support

ctx := web.MarkChangedByUser(e.srv.Context, true)
ctx := e.srv.Context

// read in idstools.yaml
enabled, disabled, err := e.m2470LoadEnabledDisabled()
Expand Down Expand Up @@ -106,6 +105,7 @@ func (e *SuricataEngine) Migration2470(statePath string) error {
continue
}

det.PersistChange = true
det.Kind = ""

_, err := e.srv.Detectionstore.UpdateDetection(ctx, det)
Expand Down
4 changes: 1 addition & 3 deletions server/modules/suricata/suricata.go
Original file line number Diff line number Diff line change
Expand Up @@ -1124,8 +1124,6 @@ func (e *SuricataEngine) syncCommunityDetections(ctx context.Context, logger *lo
}()
errMap = map[string]string{}

changedByUser := web.IsChangedByUser(ctx)

if logger == nil {
logger = log.WithField("detectionEngine", model.EngineNameSuricata)
}
Expand Down Expand Up @@ -1241,7 +1239,7 @@ func (e *SuricataEngine) syncCommunityDetections(ctx context.Context, logger *lo
_, inEnabled := enabledIndex[sid]
_, inDisabled := disabledIndex[sid]

if changedByUser || inEnabled || inDisabled || modifiedByFilter {
if detect.PersistChange || inEnabled || inDisabled || modifiedByFilter {
// update enabled
enabledLines = updateEnabled(enabledLines, enabledIndex, sid, isFlowbits, detect)

Expand Down
9 changes: 7 additions & 2 deletions server/modules/suricata/suricata_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (
"github.com/security-onion-solutions/securityonion-soc/server/modules/detections/handmock"
"github.com/security-onion-solutions/securityonion-soc/server/modules/detections/mock"
"github.com/security-onion-solutions/securityonion-soc/util"
"github.com/security-onion-solutions/securityonion-soc/web"

"github.com/apex/log"
"github.com/elastic/go-elasticsearch/v8/esutil"
Expand Down Expand Up @@ -1078,6 +1077,8 @@ func TestSyncCommunitySuricata(t *testing.T) {
},
}

ctx := context.Background()

for _, test := range table {
test := test
t.Run(test.Name, func(t *testing.T) {
Expand All @@ -1096,7 +1097,11 @@ func TestSyncCommunitySuricata(t *testing.T) {

mod.isRunning = true

ctx := web.MarkChangedByUser(context.Background(), test.ChangedByUser)
if test.ChangedByUser {
for _, detect := range test.Detections {
detect.PersistChange = true
}
}

errMap, err := mod.syncCommunityDetections(ctx, nil, test.Detections, false, test.InitialSettings)

Expand Down
23 changes: 0 additions & 23 deletions web/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,6 @@ import (
"github.com/apex/log"
)

type contextKey string

const (
ContextKeyChangedByUser contextKey = "changedByUser"
)

func Middleware(host *Host, isWS bool) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -200,20 +194,3 @@ func isNil(i interface{}) bool {
}
return false
}

func MarkChangedByUser(ctx context.Context, value bool) context.Context {
return context.WithValue(ctx, ContextKeyChangedByUser, value)
}

func IsChangedByUser(ctx context.Context) bool {
if ctx == nil {
return false
}

v := ctx.Value(ContextKeyChangedByUser)
if v == nil {
return false
}

return v.(bool)
}
18 changes: 2 additions & 16 deletions web/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,22 +187,8 @@ func compareJSON(jsn1 []byte, jsn2 []byte) (success bool, err error) {
var two interface{}

// this is guarded by prettyPrint
json.Unmarshal(jsn1, &one)
json.Unmarshal(jsn2, &two)
_ = json.Unmarshal(jsn1, &one)
_ = json.Unmarshal(jsn2, &two)

return reflect.DeepEqual(one, two), nil
}

func TestChangedByUser(t *testing.T) {
t.Parallel()

assert.False(t, IsChangedByUser(nil))

assert.False(t, IsChangedByUser(context.Background()))

ctx := MarkChangedByUser(context.Background(), true)
assert.True(t, IsChangedByUser(ctx))

ctx = MarkChangedByUser(context.Background(), false)
assert.False(t, IsChangedByUser(ctx))
}
Loading