New term optimisation #26

Merged: 12 commits, Sep 4, 2023
Binary file added .DS_Store
Binary file not shown.
21 changes: 14 additions & 7 deletions elastalert/elastalert.py
@@ -455,6 +455,7 @@ def get_hits(self, rule, starttime, endtime, index, scroll=False):

def get_new_terms_data(self, rule, starttime, endtime, field):
new_terms = []
new_counts = []

rule_inst = rule["type"]
try:
@@ -467,17 +468,22 @@ def get_new_terms_data(self, rule, starttime, endtime, field):
buckets = res['aggregations']['values']['buckets']
if type(field) == list:
for bucket in buckets:
new_terms += rule_inst.flatten_aggregation_hierarchy(bucket)
else:
new_terms = [bucket['key'] for bucket in buckets]
keys, counts = rule_inst.flatten_aggregation_hierarchy(bucket)
new_terms += keys
new_counts += counts

else:
for bucket in buckets:
new_terms.append(bucket['key'])
new_counts.append(bucket['doc_count'])

except ElasticsearchException as e:
if len(str(e)) > 1024:
e = str(e)[:1024] + '... (%d characters removed)' % (len(str(e)) - 1024)
self.handle_error('Error running new terms query: %s' % (e), {'rule': rule['name'], 'query': query})
return [], []

return new_terms
return new_terms, new_counts



@@ -486,12 +492,13 @@ def get_new_terms(self, rule, starttime, endtime):
data = {}

for field in rule['fields']:
new_terms = self.get_new_terms_data(rule,starttime,endtime,field)
new_terms, counts = self.get_new_terms_data(rule,starttime,endtime,field)
self.thread_data.num_hits += len(new_terms)
field_data = ( new_terms, counts )
if type(field) == list:
data[tuple(field)] = new_terms
data[tuple(field)] = field_data
else:
data[field] = new_terms
data[field] = field_data

lt = rule.get('use_local_time')
status_log = "Queried rule %s from %s to %s: %s / %s hits" % (
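For context: with this change, get_new_terms maps each field (or tuple of fields for composite keys) to a pair of parallel lists, the terms and their document counts, instead of a flat list of terms. A minimal sketch of that payload shape, with invented field names and values (not part of the diff):

# Hypothetical illustration of the payload built by get_new_terms after this change.
# Keys are a field name, or a tuple of field names for composite keys;
# values are (terms, counts) pairs of parallel lists.
data = {
    'user.name': (['alice', 'bob'], [12, 3]),
    ('src.ip', 'dst.ip'): ([('10.0.0.1', '10.0.0.2')], [7]),
}

for lookup_key, (terms, counts) in data.items():
    for term, count in zip(terms, counts):
        print(lookup_key, term, count)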
160 changes: 123 additions & 37 deletions elastalert/ruletypes.py
@@ -3,6 +3,8 @@
import datetime
import sys
import time
import itertools


from sortedcontainers import SortedKeyList as sortedlist

@@ -410,6 +412,68 @@ def append_middle(self, event):
self.running_count += event[1]
self.data.rotate(-rotation)

# One TermsWindow is maintained per field. It tracks the terms seen within the lookback
# window and keeps a per-term EventWindow (new_terms) used to apply the alert threshold
# to candidate new terms.
class TermsWindow:

def __init__(self, term_window_size, ts_field , threshold, threshold_window_size, get_ts):
self.term_window_size = term_window_size
self.ts_field = ts_field
self.threshold = threshold
self.threshold_window_size = threshold_window_size
self.get_ts = get_ts

self.data = sortedlist(key= lambda x: x[0])
self.values = set()
self.new_terms = {}
self.count_dict = {}

def append_keys(self, timestamp, keys, counts):
for i in range(len(keys)):
self.append_key(keys[i], counts[i])

self.data.add((timestamp, keys,counts))
self.adjust_window(till=timestamp - self.term_window_size)

def append_key(self,key,count):
if key not in self.count_dict:
self.count_dict[key] = 0
self.count_dict[key] += count
self.values.add(key)


def append_and_get_matches(self, timestamp, keys, counts):

    self.adjust_window(till=timestamp - self.term_window_size)

    matched_keys = []
    matched_counts = []
    seen_keys = []
    seen_counts = []

    # Do not pop from keys/counts while iterating over them by index: that skips
    # elements and can raise IndexError. Partition the terms instead.
    for key, count in zip(keys, counts):
        if key in self.values:
            seen_keys.append(key)
            seen_counts.append(count)
            continue
        # Candidate new term: accumulate its hits in a per-term threshold window.
        event = ({self.ts_field: timestamp}, count)
        window = self.new_terms.setdefault(key, EventWindow(self.threshold_window_size, getTimestamp=self.get_ts))
        window.append(event)
        if window.count() >= self.threshold:
            matched_keys.append(key)
            matched_counts.append(window.count())
            self.new_terms.pop(key)

    # Already-known terms and terms that just crossed the threshold are recorded in the
    # main window; sub-threshold candidates stay only in their new_terms windows.
    self.append_keys(timestamp, seen_keys + matched_keys, seen_counts + matched_counts)

    return matched_keys, matched_counts

def adjust_window(self,till):
while len(self.data)!=0 and self.data[0][0] < till:
timestamp, keys, counts = self.data.pop(0)
for i in range(len(keys)):
self.count_dict[keys[i]] -= counts[i]
if self.count_dict[keys[i]] <= 0:
self.count_dict.pop(keys[i])
self.values.discard(keys[i])
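
A short usage sketch of TermsWindow (not part of the diff): baseline terms are loaded with append_keys, and query-time terms go through append_and_get_matches, which only reports a new term once its hits inside threshold_window_size reach the threshold. The import path and all concrete values below are assumptions for illustration:

import datetime
from elastalert.ruletypes import TermsWindow  # assumed import path

get_ts = lambda event: event[0]['@timestamp']  # mirrors new_get_event_ts('@timestamp')
window = TermsWindow(
    term_window_size=datetime.timedelta(days=7),        # how long a term stays "known"
    ts_field='@timestamp',
    threshold=5,                                         # hits needed before a new term matches
    threshold_window_size=datetime.timedelta(hours=1),   # span over which hits are summed
    get_ts=get_ts,
)

now = datetime.datetime.now(datetime.timezone.utc)
# Baseline: terms seen in the lookback are recorded and never alerted on.
window.append_keys(now - datetime.timedelta(days=1), ['alice', 'bob'], [10, 2])

# Query time: 'alice' is already known; 'carol' is new and her 6 hits exceed the threshold of 5.
matched, counts = window.append_and_get_matches(now, ['alice', 'carol'], [4, 6])
print(matched, counts)  # expected: ['carol'] [6]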




class SpikeRule(RuleType):
""" A rule that uses two sliding windows to compare relative event frequency. """
@@ -676,9 +740,14 @@ class NewTermsRule(RuleType):

def __init__(self, rule, args=None):
super(NewTermsRule, self).__init__(rule, args)
self.seen_values = {}
self.term_windows = {}
self.last_updated_at = None
self.es = kibana_adapter_client(self.rules)
self.ts_field = self.rules.get('timestamp_field', '@timestamp')
self.get_ts = new_get_event_ts(self.ts_field)
self.new_terms = {}

self.threshold = rule.get('threshold',0)

# terms_window_size : Default & Upperbound - 7 Days
self.window_size = min(datetime.timedelta(**self.rules.get('terms_window_size', {'days': 7})), datetime.timedelta(**{'days': 7}))
@@ -692,6 +761,9 @@ def __init__(self, rule, args=None):
# terms_size : Default - 500, Upperbound: 1000
self.terms_size = min(self.rules.get('terms_size', 500),1000)

# threshold_window_size : Default - 1 Hour, Upperbound - 2 Days
self.threshold_window_size = min(datetime.timedelta(**self.rules.get('threshold_window_size', {'hours': 1})), datetime.timedelta(**{'days': 2}))
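
For reference, a hedged sketch of how the options read above might look once loaded into the rule dict; the keys mirror the lookups in this constructor, while the values are invented (not part of the diff):

# Illustrative-only values; a subset of the keys the constructor above reads from self.rules.
rule_options = {
    'fields': ['user.name', ['src.ip', 'dst.ip']],     # single and composite keys
    'threshold': 5,                                     # hits required before a new term alerts
    'terms_window_size': {'days': 7},                   # baseline lookback, capped at 7 days
    'terms_size': 500,                                  # capped at 1000
    'threshold_window_size': {'hours': 1},              # capped at 2 days
}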

# Allow the use of query_key or fields
if 'fields' not in self.rules:
if 'query_key' not in self.rules:
@@ -713,17 +785,12 @@ def __init__(self, rule, args=None):
if self.rules.get('use_keyword_postfix', False):  # default is False since we won't use the keyword suffix
elastalert_logger.warn('Warning: If query_key is a non-keyword field, you must set '
'use_keyword_postfix to false, or add .keyword/.raw to your query_key.')

def should_refresh_terms(self):
return self.last_updated_at is None or self.last_updated_at < ( ts_now() - self.refresh_interval)

def update_terms(self,args=None):
try:
self.get_all_terms(args=args)
except Exception as e:
# Refuse to start if we cannot get existing terms
raise EAException('Error searching for existing terms: %s' % (repr(e))).with_traceback(sys.exc_info()[2])



def get_new_term_query(self,starttime,endtime,field):
@@ -771,7 +838,7 @@ def get_new_term_query(self,starttime,endtime,field):

# For composite keys, we will need to perform sub-aggregations
if type(field) == list:
self.seen_values.setdefault(tuple(field), [])
self.term_windows.setdefault(tuple(field), TermsWindow(self.window_size, self.ts_field , self.threshold, self.threshold_window_size, self.get_ts))
level = query['aggs']
# Iterate on each part of the composite key and add a sub aggs clause to the elastic search query
for i, sub_field in enumerate(field):
@@ -784,7 +851,7 @@ def get_new_term_query(self,starttime,endtime,field):
level['values']['aggs'] = {'values': {'terms': copy.deepcopy(field_name)}}
level = level['values']['aggs']
else:
self.seen_values.setdefault(field, [])
self.term_windows.setdefault(field, TermsWindow(self.window_size, self.ts_field , self.threshold, self.threshold_window_size, self.get_ts))
# For non-composite keys, only a single agg is needed
if self.rules.get('use_keyword_postfix', False):  # default is False since we won't use the keyword suffix
field_name['field'] = add_raw_postfix(field, True)
@@ -820,32 +887,41 @@ def get_all_terms(self,args):

res = self.es.msearch(msearch_query,request_timeout=50)
res = res['responses'][0]

if 'aggregations' in res:
buckets = res['aggregations']['values']['buckets']

term_window = self.term_windows[self.get_lookup_key(field)]
keys = []
counts = []

if type(field) == list:
# For composite keys, make the lookup based on all fields
# Make it a tuple since it can be hashed and used in dictionary lookups
for bucket in buckets:
# We need to walk down the hierarchy and obtain the value at each level
self.seen_values[tuple(field)] += self.flatten_aggregation_hierarchy(bucket)
bucket_keys, bucket_counts = self.flatten_aggregation_hierarchy(bucket)
keys += bucket_keys
counts += bucket_counts
else:
keys = [bucket['key'] for bucket in buckets]
self.seen_values[field] += keys
for bucket in buckets:
keys.append(bucket['key'])
counts.append(bucket['doc_count'])

term_window.append_keys(tmp_end,keys,counts)

else:
if type(field) == list:
self.seen_values.setdefault(tuple(field), [])
self.term_windows.setdefault(tuple(field), TermsWindow(self.window_size, self.ts_field , self.threshold, self.threshold_window_size, self.get_ts))
else:
self.seen_values.setdefault(field, [])
self.term_windows.setdefault(field, TermsWindow(self.window_size, self.ts_field , self.threshold, self.threshold_window_size, self.get_ts))
if tmp_start == tmp_end:
break
tmp_start = tmp_end
tmp_end = min(tmp_start + self.step, end)
query = self.get_new_term_query(tmp_start,tmp_end,field)


for key, values in self.seen_values.items():
if not values:
for key, window in self.term_windows.items():
if not window.values:
if type(key) == tuple:
# If we don't have any results, it could either be because of the absence of any baseline data
# OR it may be because the composite key contained a non-primitive type. Either way, give the
@@ -857,9 +933,8 @@ def get_all_terms(self,args):
else:
elastalert_logger.info('Found no values for %s' % (field))
continue
self.seen_values[key] = list(set(values))
elastalert_logger.info('Found %s unique values for %s' % (len(set(values)), key))
self.last_updated_at = ts_now()
elastalert_logger.info('Found %s unique values for %s' % (len(window.values), key))
# self.last_updated_at = ts_now()

def flatten_aggregation_hierarchy(self, root, hierarchy_tuple=()):
""" For nested aggregations, the results come back in the following format:
@@ -950,36 +1025,44 @@ def flatten_aggregation_hierarchy(self, root, hierarchy_tuple=()):
A similar formatting will be performed in the add_data method and used as the basis for comparison

"""
results = []
keys = []
counts = []
# There are more aggregation hierarchies left. Traverse them.
if 'values' in root:
results += self.flatten_aggregation_hierarchy(root['values']['buckets'], hierarchy_tuple + (root['key'],))
new_terms, new_counts = self.flatten_aggregation_hierarchy(root['values']['buckets'], hierarchy_tuple + (root['key'],))
keys += new_terms
counts += new_counts
else:
# We've gotten to a sub-aggregation, which may have further sub-aggregations
# See if we need to traverse further
for node in root:
if 'values' in node:
results += self.flatten_aggregation_hierarchy(node, hierarchy_tuple)
new_terms, new_counts = self.flatten_aggregation_hierarchy(node, hierarchy_tuple)
keys += new_terms
counts += new_counts
else:
results.append(hierarchy_tuple + (node['key'],))
return results
keys.append(hierarchy_tuple + (node['key'],))
counts.append(node['doc_count'])
return keys, counts
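
A small worked example of the reshaped return value (not part of the diff): the bucket layout follows the nested terms-aggregation format described in the docstring, and the values are invented:

# Hypothetical two-level aggregation bucket for a composite key (field_a, field_b).
bucket = {
    'key': 'a1',
    'doc_count': 5,
    'values': {
        'buckets': [
            {'key': 'b1', 'doc_count': 3},
            {'key': 'b2', 'doc_count': 2},
        ],
    },
}

# flatten_aggregation_hierarchy(bucket) now returns two parallel lists:
#   keys   -> [('a1', 'b1'), ('a1', 'b2')]
#   counts -> [3, 2]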

def add_new_term_data(self, payload):
if self.should_refresh_terms():
self.update_terms()
# if self.should_refresh_terms():
# self.update_terms()
timestamp = list(payload.keys())[0]
data = payload[timestamp]
for field in self.fields:
lookup_key =tuple(field) if type(field) == list else field
for value in data[lookup_key]:
if value not in self.seen_values[lookup_key]:
match = {
"field": lookup_key,
self.rules['timestamp_field']: timestamp,
'new_value': tuple(value) if type(field) == list else value
}
self.add_match(copy.deepcopy(match))
self.seen_values[lookup_key].append(value)
lookup_key = self.get_lookup_key(field)
keys, counts = data[lookup_key]
# Record the incoming terms and collect any new terms that crossed the threshold
matched_keys, matched_counts = self.term_windows[lookup_key].append_and_get_matches(timestamp, keys, counts)
for (key, count) in zip(matched_keys, matched_counts):
match = {
"field": lookup_key,
self.rules['timestamp_field']: timestamp,
"new_value": tuple(key) if type(key) == list else key,
"hits" : count
}
self.add_match(copy.deepcopy(match))
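
As a concrete illustration (values invented, not part of the diff), a term that clears its threshold produces a match of roughly this shape, with the hit count carried alongside the new value:

# Hypothetical match emitted for a single-field rule on 'user.name'.
match = {
    'field': 'user.name',
    '@timestamp': '2023-09-04T10:00:00Z',  # keyed by rules['timestamp_field']
    'new_value': 'carol',
    'hits': 6,                             # count accumulated in the term's threshold window
}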

def add_data(self, data):
for document in data:
@@ -1020,6 +1103,9 @@ def add_terms_data(self, terms):
self.add_match(match)
self.seen_values[field].append(bucket['key'])

def get_lookup_key(self,field):
return tuple(field) if type(field) == list else field


class CardinalityRule(RuleType):
""" A rule that matches if cardinality of a field is above or below a threshold within a timeframe """
20 changes: 20 additions & 0 deletions examples/ex_flatline.yaml
@@ -0,0 +1,20 @@
name: freshemail debug rule
type: flatline
index: traces*
threshold: 3
# use_count_query: true
timestamp_field: timestamp
timeframe:
minutes: 1
filter:
- query:
query_string:
query: "*"
alert:
- "debug"
scan_entire_timeframe: true

realert:
minutes: 0
query_delay:
minutes: 3