Skip to content

Commit

Permalink
detector trigger detection types
Browse files Browse the repository at this point in the history
Signed-off-by: Surya Sashank Nistala <snistala@amazon.com>
  • Loading branch information
eirsep committed Oct 22, 2023
1 parent ae084e7 commit 28f3ba8
Show file tree
Hide file tree
Showing 15 changed files with 245 additions and 174 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,31 +49,53 @@ public class DetectorTrigger implements Writeable, ToXContentObject {

private List<Action> actions;

/**
* detection type is a list of values that tells us what queries is the trigger trying to match - rules-based or threat_intel-based or both
*/
private List<String> detectionTypes; // todo make it enum supports 'rules', 'threat_intel'

private static final String ID_FIELD = "id";

private static final String SEVERITY_FIELD = "severity";
private static final String RULE_TYPES_FIELD = "types";
private static final String RULE_IDS_FIELD = "ids";
private static final String RULE_SEV_LEVELS_FIELD = "sev_levels";
private static final String RULE_TAGS_FIELD = "tags";
private static final String ACTIONS_FIELD = "actions";
private static final String DETECTION_TYPES_FIELD = "detection_types";

public static final String RULES_DETECTION_TYPE = "rules";
public static final String THREAT_INTEL_DETECTION_TYPE = "threat_intel";

public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry(
DetectorTrigger.class,
new ParseField(ID_FIELD),
DetectorTrigger::parse
);

public DetectorTrigger(String id, String name, String severity, List<String> ruleTypes, List<String> ruleIds, List<String> ruleSeverityLevels, List<String> tags, List<Action> actions) {
this.id = id == null? UUIDs.base64UUID(): id;
public DetectorTrigger(String id,
String name,
String severity,
List<String> ruleTypes,
List<String> ruleIds,
List<String> ruleSeverityLevels,
List<String> tags,
List<Action> actions,
List<String> detectionTypes) {
this.id = id == null ? UUIDs.base64UUID() : id;
this.name = name;
this.severity = severity;
this.ruleTypes = ruleTypes.stream()
.map( e -> e.toLowerCase(Locale.ROOT))
.map(e -> e.toLowerCase(Locale.ROOT))
.collect(Collectors.toList());
this.ruleIds = ruleIds;
this.ruleSeverityLevels = ruleSeverityLevels;
this.tags = tags;
this.actions = actions;
this.detectionTypes = detectionTypes;
if(this.detectionTypes.isEmpty()) {
this.detectionTypes = Collections.singletonList(RULES_DETECTION_TYPE); // for backward compatibility
}
}

public DetectorTrigger(StreamInput sin) throws IOException {
Expand All @@ -85,7 +107,8 @@ public DetectorTrigger(StreamInput sin) throws IOException {
sin.readStringList(),
sin.readStringList(),
sin.readStringList(),
sin.readList(Action::readFrom)
sin.readList(Action::readFrom),
sin.readStringList()
);
}

Expand All @@ -95,7 +118,8 @@ public Map<String, Object> asTemplateArg() {
RULE_IDS_FIELD, ruleIds,
RULE_SEV_LEVELS_FIELD, ruleSeverityLevels,
RULE_TAGS_FIELD, tags,
ACTIONS_FIELD, actions.stream().map(Action::asTemplateArg)
ACTIONS_FIELD, actions.stream().map(Action::asTemplateArg),
DETECTION_TYPES_FIELD, detectionTypes
);
}

Expand All @@ -109,6 +133,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeStringCollection(ruleSeverityLevels);
out.writeStringCollection(tags);
out.writeCollection(actions);
out.writeStringCollection(detectionTypes);
}

@Override
Expand All @@ -128,6 +153,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
Action[] actionArray = new Action[]{};
actionArray = actions.toArray(actionArray);

String[] detectionTypesArray = new String[]{};
detectionTypesArray = detectionTypes.toArray(detectionTypesArray);

return builder.startObject()
.field(ID_FIELD, id)
.field(Detector.NAME_FIELD, name)
Expand All @@ -137,6 +165,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
.field(RULE_SEV_LEVELS_FIELD, ruleSevLevelArray)
.field(RULE_TAGS_FIELD, tagArray)
.field(ACTIONS_FIELD, actionArray)
.field(DETECTION_TYPES_FIELD, detectionTypesArray)
.endObject();
}

Expand All @@ -149,6 +178,7 @@ public static DetectorTrigger parse(XContentParser xcp) throws IOException {
List<String> ruleSeverityLevels = new ArrayList<>();
List<String> tags = new ArrayList<>();
List<Action> actions = new ArrayList<>();
List<String> detectionTypes = new ArrayList<>();

XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, xcp.currentToken(), xcp);
while (xcp.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -193,6 +223,13 @@ public static DetectorTrigger parse(XContentParser xcp) throws IOException {
tags.add(tag);
}
break;
case DETECTION_TYPES_FIELD:
XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, xcp.currentToken(), xcp);
while (xcp.nextToken() != XContentParser.Token.END_ARRAY) {
String dt = xcp.text();
detectionTypes.add(dt);
}
break;
case ACTIONS_FIELD:
XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, xcp.currentToken(), xcp);
while (xcp.nextToken() != XContentParser.Token.END_ARRAY) {
Expand All @@ -204,8 +241,10 @@ public static DetectorTrigger parse(XContentParser xcp) throws IOException {
xcp.skipChildren();
}
}

return new DetectorTrigger(id, name, severity, ruleTypes, ruleNames, ruleSeverityLevels, tags, actions);
if(detectionTypes.isEmpty()) {
detectionTypes.add(RULES_DETECTION_TYPE); // for backward compatibility
}
return new DetectorTrigger(id, name, severity, ruleTypes, ruleNames, ruleSeverityLevels, tags, actions, detectionTypes);
}

public static DetectorTrigger readFrom(StreamInput sin) throws IOException {
Expand All @@ -227,71 +266,83 @@ public int hashCode() {

public Script convertToCondition() {
StringBuilder condition = new StringBuilder();

boolean triggerFlag = false;

StringBuilder ruleTypeBuilder = new StringBuilder();
int size = ruleTypes.size();
for (int idx = 0; idx < size; ++idx) {
ruleTypeBuilder.append(String.format(Locale.getDefault(), "query[tag=%s]", ruleTypes.get(idx)));
if (idx < size - 1) {
ruleTypeBuilder.append(" || ");
int size = 0;
if (detectionTypes.contains(RULES_DETECTION_TYPE)) { // trigger should match rules based queries based on conditions
StringBuilder ruleTypeBuilder = new StringBuilder();
size = ruleTypes.size();
for (int idx = 0; idx < size; ++idx) {
ruleTypeBuilder.append(String.format(Locale.getDefault(), "query[tag=%s]", ruleTypes.get(idx)));
if (idx < size - 1) {
ruleTypeBuilder.append(" || ");
}
}
if (size > 0) {
condition.append("(").append(ruleTypeBuilder).append(")");
triggerFlag = true;
}
}
if (size > 0) {
condition.append("(").append(ruleTypeBuilder).append(")");
triggerFlag = true;
}

StringBuilder ruleNameBuilder = new StringBuilder();
size = ruleIds.size();
for (int idx = 0; idx < size; ++idx) {
ruleNameBuilder.append(String.format(Locale.getDefault(), "query[name=%s]", ruleIds.get(idx)));
if (idx < size - 1) {
ruleNameBuilder.append(" || ");
StringBuilder ruleNameBuilder = new StringBuilder();
size = ruleIds.size();
for (int idx = 0; idx < size; ++idx) {
ruleNameBuilder.append(String.format(Locale.getDefault(), "query[name=%s]", ruleIds.get(idx)));
if (idx < size - 1) {
ruleNameBuilder.append(" || ");
}
}
}
if (size > 0) {
if (triggerFlag) {
condition.append(" && ").append("(").append(ruleNameBuilder).append(")");
} else {
condition.append("(").append(ruleNameBuilder).append(")");
triggerFlag = true;
if (size > 0) {
if (triggerFlag) {
condition.append(" && ").append("(").append(ruleNameBuilder).append(")");
} else {
condition.append("(").append(ruleNameBuilder).append(")");
triggerFlag = true;
}
}
}

StringBuilder ruleSevLevelBuilder = new StringBuilder();
size = ruleSeverityLevels.size();
for (int idx = 0; idx < size; ++idx) {
ruleSevLevelBuilder.append(String.format(Locale.getDefault(), "query[tag=%s]", ruleSeverityLevels.get(idx)));
if (idx < size - 1) {
ruleSevLevelBuilder.append(" || ");
StringBuilder ruleSevLevelBuilder = new StringBuilder();
size = ruleSeverityLevels.size();
for (int idx = 0; idx < size; ++idx) {
ruleSevLevelBuilder.append(String.format(Locale.getDefault(), "query[tag=%s]", ruleSeverityLevels.get(idx)));
if (idx < size - 1) {
ruleSevLevelBuilder.append(" || ");
}
}
}

if (size > 0) {
if (triggerFlag) {
condition.append(" && ").append("(").append(ruleSevLevelBuilder).append(")");
} else {
condition.append("(").append(ruleSevLevelBuilder).append(")");
triggerFlag = true;
if (size > 0) {
if (triggerFlag) {
condition.append(" && ").append("(").append(ruleSevLevelBuilder).append(")");
} else {
condition.append("(").append(ruleSevLevelBuilder).append(")");
triggerFlag = true;
}
}
}

StringBuilder tagBuilder = new StringBuilder();
size = tags.size();
for (int idx = 0; idx < size; ++idx) {
tagBuilder.append(String.format(Locale.getDefault(), "query[tag=%s]", tags.get(idx)));
if (idx < size - 1) {
ruleSevLevelBuilder.append(" || ");
StringBuilder tagBuilder = new StringBuilder();
size = tags.size();
for (int idx = 0; idx < size; ++idx) {
tagBuilder.append(String.format(Locale.getDefault(), "query[tag=%s]", tags.get(idx)));
if (idx < size - 1) {
ruleSevLevelBuilder.append(" || ");
}
}
}

if (size > 0) {
if (triggerFlag) {
condition.append(" && ").append("(").append(tagBuilder).append(")");
} else {
condition.append("(").append(tagBuilder).append(")");
if (size > 0) {
if (triggerFlag) {
condition.append(" && ").append("(").append(tagBuilder).append(")");
} else {
condition.append("(").append(tagBuilder).append(")");
}
}
}
if(detectionTypes.contains(THREAT_INTEL_DETECTION_TYPE)) {
StringBuilder threatIntelClauseBuilder = new StringBuilder();
threatIntelClauseBuilder.append(String.format(Locale.getDefault(), "query[tag=%s]", "threat_intel"));
if (condition.length() > 0) {
condition.append(" || ");
}
condition.append("(").append(threatIntelClauseBuilder).append(")");
}

return new Script(condition.toString());
Expand Down Expand Up @@ -321,6 +372,10 @@ public List<String> getRuleSeverityLevels() {
return ruleSeverityLevels;
}

public List<String > getDetectionTypes() {
return detectionTypes;
}

public List<String> getTags() {
return tags;
}
Expand All @@ -329,8 +384,8 @@ public List<Action> getActions() {
List<Action> transformedActions = new ArrayList<>();

if (actions != null) {
for (Action action: actions) {
String subjectTemplate = action.getSubjectTemplate() != null ? action.getSubjectTemplate().getIdOrCode(): "";
for (Action action : actions) {
String subjectTemplate = action.getSubjectTemplate() != null ? action.getSubjectTemplate().getIdOrCode() : "";
subjectTemplate = subjectTemplate.replace("{{ctx.detector", "{{ctx.monitor");

action.getMessageTemplate();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.opensearch.securityanalytics.action.IndexDetectorRequest;
import org.opensearch.securityanalytics.action.IndexDetectorResponse;
import org.opensearch.securityanalytics.model.Detector;
import org.opensearch.securityanalytics.model.DetectorTrigger;
import org.opensearch.securityanalytics.util.DetectorUtils;
import org.opensearch.securityanalytics.util.RestHandlerUtils;

Expand Down Expand Up @@ -67,11 +68,26 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli

Detector detector = Detector.parse(xcp, id, null);
detector.setLastUpdateTime(Instant.now());
validateDetectorTriggers(detector);

IndexDetectorRequest indexDetectorRequest = new IndexDetectorRequest(id, refreshPolicy, request.method(), detector);
return channel -> client.execute(IndexDetectorAction.INSTANCE, indexDetectorRequest, indexDetectorResponse(channel, request.method()));
}

private static void validateDetectorTriggers(Detector detector) {
if(detector.getTriggers() != null) {
for (DetectorTrigger trigger : detector.getTriggers()) {
if(trigger.getDetectionTypes().isEmpty())
throw new IllegalArgumentException(String.format(Locale.ROOT,"Trigger [%s] should mention at least one detection type but found none", trigger.getName()));
for (String detectionType : trigger.getDetectionTypes()) {
if(false == (DetectorTrigger.THREAT_INTEL_DETECTION_TYPE.equals(detectionType) || DetectorTrigger.RULES_DETECTION_TYPE.equals(detectionType))) {
throw new IllegalArgumentException(String.format(Locale.ROOT,"Trigger [%s] has unsupported detection type [%s]", trigger.getName(), detectionType));
}
}
}
}
}

private RestResponseListener<IndexDetectorResponse> indexDetectorResponse(RestChannel channel, RestRequest.Method restMethod) {
return new RestResponseListener<>(channel) {
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,12 @@ public List<DocLevelQuery> createDocLevelQueriesFromThreatIntelList(
constructId(detector, entry.getKey()), tifdList.get(0).getFeedId(),
Collections.emptyList(),
String.format(query, field),
List.of("threat_intel", entry.getKey() /*ioc_type*/)
List.of(
"threat_intel",
String.format("ioc_type:%s", entry.getKey()),
String.format("field:%s", field),
String.format("feed_name:%s", tifdList.get(0).getFeedId())
)
));
}
}
Expand Down Expand Up @@ -148,7 +153,7 @@ public void onFailure(Exception e) {
}

private static String constructId(Detector detector, String iocType) {
return detector.getName() + "_threat_intel_" + iocType + "_" + UUID.randomUUID();
return "threat_intel_" + UUID.randomUUID();
}

/** Updates all detectors having threat intel detection enabled with the latest threat intel feed data*/
Expand Down
3 changes: 0 additions & 3 deletions src/main/resources/feed/config/feeds.yml

This file was deleted.

12 changes: 0 additions & 12 deletions src/main/resources/feed/config/feeds/otx.yml

This file was deleted.

Loading

0 comments on commit 28f3ba8

Please sign in to comment.