Skip to content

Commit

Permalink
Adding IN operator support to CB connector (#835)
Browse files Browse the repository at this point in the history
  • Loading branch information
mdazam1942 authored Mar 11, 2022
1 parent c24bd39 commit c233a80
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"ComparisonExpressionOperators.And": "and",
"ComparisonExpressionOperators.Or": "or",
"ComparisonComparators.Equal": ":",
"ComparisonComparators.In": ":",
"ComparisonComparators.NotEqual": ":",
"ComparisonComparators.GreaterThan": ":",
"ComparisonComparators.GreaterThanOrEqual": ":",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import json
import re
from stix_shifter_utils.utils import logger
from stix_shifter_utils.stix_translation.src.json_to_stix import observable

from stix_shifter_utils.stix_translation.src.patterns.pattern_objects import ObservationExpression, ComparisonExpression, \
ComparisonExpressionOperators, ComparisonComparators, Pattern, StartStopQualifier, \
CombinedComparisonExpression, CombinedObservationExpression, ObservationOperators
CombinedComparisonExpression, CombinedObservationExpression, ObservationOperators, SetValue
from stix_shifter_utils.stix_translation.src.patterns.errors import SearchFeatureNotSupportedError


Expand Down Expand Up @@ -47,7 +48,7 @@ def _format_gt(value) -> str:
if isinstance(value, int):
value = value + 1
return CbQueryStringPatternTranslator._format_gte(value)

@staticmethod
def _escape_value(value, comparator=None) -> str:
if isinstance(value, str):
Expand All @@ -57,7 +58,7 @@ def _escape_value(value, comparator=None) -> str:

@staticmethod
def _negate_comparison(comparison_string) -> str:
return "-({})".format(comparison_string)
return "-{}".format(comparison_string)

@staticmethod
def _to_cb_timestamp(ts: str) -> str:
Expand All @@ -70,10 +71,59 @@ def _format_start_stop_qualifier(self, expression, qualifier: StartStopQualifier
start = self._to_cb_timestamp(qualifier.start)
stop = self._to_cb_timestamp(qualifier.stop)

return "({}) and last_update:[{} TO {}]".format(expression, start, stop)
return "{} and last_update:[{} TO {}]".format(expression, start, stop)

@staticmethod
def _check_value_type(value):
"""
Determine the type (ipv4, ipv6, mac, date, etc) of the provided value.
See: https://github.com/opencybersecurityalliance/stix-shifter/blob/develop/stix_shifter_utils/stix_translation/src/json_to_stix/observable.py#L1
:param value: query value
:type value: int/str
:return: type of value
:rtype: str
"""
value = str(value)
for key, pattern in observable.REGEX.items():
if bool(re.search(pattern, value)):
return key
return None

def _parse_mapped_fields(self, value, comparator, mapped_fields_array) -> str:
"""Convert a list of mapped fields into a query string."""
comparison_strings = []
value_type = None
str_ = None

if isinstance(value, str):
value = [value]

for val in value:
value_type = self._check_value_type(val)

for mapped_field in mapped_fields_array:
# Only use the ipv4 fields when the value is an actual ipv4 address or range
skip = ('ipv4' in mapped_field and value_type not in ['ipv4', 'ipv4_cidr'])
# Only use the ipv6 fields when the value is an actual ipv6 address or range
skip = skip or ('ipv6' in mapped_field and value_type not in ['ipv6', 'ipv6_cidr'])

if not skip:
comparison_strings.append(f'{mapped_field}{comparator}{val}')

# Only wrap in () if there's more than one comparison string
if len(comparison_strings) == 1:
str_ = comparison_strings[0]
elif len(comparison_strings) > 1:
str_ = f"({' or '.join(comparison_strings)})"
else:
raise RuntimeError((f'Failed to convert {mapped_fields_array} mapped fields into query string'))

return str_

def _parse_expression(self, expression, qualifier=None):
if isinstance(expression, ComparisonExpression):
comparison_string = ""
# Base Case
# Resolve STIX Object Path to a field in the target Data Model
stix_object, stix_field = expression.object_path.split(':')
Expand All @@ -96,10 +146,17 @@ def _parse_expression(self, expression, qualifier=None):
value = self._format_lte(expression.value)
elif expression.comparator == ComparisonComparators.GreaterThan:
value = self._format_gt(expression.value)
elif (expression.comparator == ComparisonComparators.In and
isinstance(expression.value, SetValue)):
value = list(map(self._escape_value, expression.value.element_iterator()))
else:
value = self._escape_value(expression.value)

comparison_string = "{mapped_field}{comparator}{value}".format(mapped_field=mapped_field, comparator=comparator, value=value)
comparison_string = self._parse_mapped_fields(
value=value,
comparator=comparator,
mapped_fields_array=mapped_fields_array
)

# translate != to NOT equals
if expression.comparator == ComparisonComparators.NotEqual and not expression.negated:
Expand All @@ -123,9 +180,10 @@ def _parse_expression(self, expression, qualifier=None):

# Note: it seems the ordering of the expressions is reversed at a lower level
# so we reverse it here so that it is as expected.
query_string = (f1 + " {} " + f2).format(self._parse_expression(expression.expr2),
query_string = "({} {} {})".format(self._parse_expression(expression.expr2),
self.comparator_lookup[str(expression.operator)],
self._parse_expression(expression.expr1))

if qualifier is not None:
if isinstance(qualifier, StartStopQualifier):
return self._format_start_stop_qualifier(query_string, qualifier)
Expand All @@ -140,7 +198,7 @@ def _parse_expression(self, expression, qualifier=None):
operator = self.comparator_lookup[str(expression.operator)]
expr1 = self._parse_expression(expression.expr1, qualifier=qualifier)
expr2 = self._parse_expression(expression.expr2, qualifier=qualifier)
return f'({expr1}) {operator} ({expr2})'
return f'{expr1} {operator} {expr2}'
elif isinstance(expression, Pattern):
return self._parse_expression(expression.expression)
elif hasattr(expression, 'qualifier') and hasattr(expression, 'observation_expression'):
Expand All @@ -151,7 +209,7 @@ def _parse_expression(self, expression, qualifier=None):

def _add_default_timerange(self, query):
if self.time_range and 'last_update' not in query:
query = "(({}) and last_update:-{}m)".format(query, self.time_range)
query = "{} and last_update:-{}m".format(query, self.time_range)

return query

Expand Down
Loading

0 comments on commit c233a80

Please sign in to comment.