Skip to content

Commit

Permalink
fix(iohandler): memorize unescaped quotes and escape them (#1058)
Browse files Browse the repository at this point in the history
  • Loading branch information
talboren authored Apr 7, 2024
1 parent 36f91b0 commit 6c278e4
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 17 deletions.
61 changes: 46 additions & 15 deletions keep/iohandler/iohandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,15 @@ def extract_keep_functions(self, text):
while i < len(text):
if text[i : i + 5] == "keep.":
start = i
func_end = text.find("(", start)
if func_end > -1: # Opening '(' found after "keep."
i = func_end + 1 # Move i to the character after '('
paren_count = 1
func_start = text.find("(", start)
if func_start > -1: # Opening '(' found after "keep."
i = func_start + 1 # Move i to the character after '('
parent_count = 1
in_string = False
escape_next = False
quote_char = ""

while i < len(text) and (paren_count > 0 or in_string):
escapes = {}
while i < len(text) and (parent_count > 0 or in_string):
if text[i] == "\\" and in_string and not escape_next:
escape_next = True
i += 1
Expand All @@ -87,19 +87,28 @@ def extract_keep_functions(self, text):
if not in_string:
in_string = True
quote_char = text[i]
elif text[i] == quote_char and not escape_next:
elif (
text[i] == quote_char
and not escape_next
and str(text[i + 1]).isalpha()
== False # end of statement, arg, etc. if it's a letter, we need to escape it
):
in_string = False
quote_char = ""
elif text[i] == quote_char and not escape_next:
escapes[i] = text[
i
] # Save the quote character where we need to escape for valid ast parsing
elif text[i] == "(" and not in_string:
paren_count += 1
parent_count += 1
elif text[i] == ")" and not in_string:
paren_count -= 1
parent_count -= 1

escape_next = False
i += 1

if paren_count == 0:
matches.append(text[start:i])
if parent_count == 0:
matches.append((text[start:i], escapes))
continue # Skip the increment at the end of the loop to continue from the current position
else:
# If no '(' found, increment i to move past "keep."
Expand Down Expand Up @@ -154,8 +163,12 @@ def parse(self, string, safe=False, default=""):
if len(tokens) == 0:
return parsed_string
elif len(tokens) == 1:
token = "".join(tokens[0])
token, escapes = tokens[0]
token_to_replace = token
try:
if escapes:
for escape in escapes:
token = token[:escape] + "\\" + token[escape:]
val = self._parse_token(token)
except Exception as e:
# trim stacktrace since we have limitation on the error message
Expand All @@ -164,20 +177,24 @@ def parse(self, string, safe=False, default=""):
raise Exception(
f"Got {e.__class__.__name__} while parsing token '{trimmed_token}': {err_message}"
)
parsed_string = parsed_string.replace(token, str(val))
parsed_string = parsed_string.replace(token_to_replace, str(val))
return parsed_string
# this basically for complex expressions with functions and operators
for token in tokens:
token = "".join(token)
token, escapes = token
token_to_replace = token
try:
if escapes:
for escape in escapes:
token = token[:escape] + "\\" + token[escape:]
val = self._parse_token(token)
except Exception as e:
trimmed_token = self._trim_token_error(token)
err_message = str(e).splitlines()[-1]
raise Exception(
f"Got {e.__class__.__name__} while parsing token '{trimmed_token}': {err_message}"
)
parsed_string = parsed_string.replace(token, str(val))
parsed_string = parsed_string.replace(token_to_replace, str(val))

return parsed_string

Expand Down Expand Up @@ -394,3 +411,17 @@ def __get_short_urls(self, urls: list) -> dict:
)
except Exception:
self.logger.exception("Failed to request short URLs from API")


if __name__ == "__main__":
# debug & test
context_manager = ContextManager("keep")
context_manager.event_context = {
"notexist": "it actually exists",
"name": "this is a test",
}
iohandler = IOHandler(context_manager)
iohandler.parse(
"{{#alert.notexist}}{{.}}{{/alert.notexist}}{{^alert.notexist}}{{alert.name}}{{/alert.notexist}}",
safe=True,
)
27 changes: 25 additions & 2 deletions tests/test_iohandler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Test the io handler
"""

import datetime

import pytest
Expand Down Expand Up @@ -679,7 +680,6 @@ def test_complex_mixture(context_manager):
), "Should correctly handle a complex mixture of text and function calls."


"""
def test_escaped_quotes_inside_function_arguments(context_manager):
iohandler = IOHandler(context_manager)
template = "keep.split('some,string,with,escaped\\\\'quotes', ',')"
Expand All @@ -688,4 +688,27 @@ def test_escaped_quotes_inside_function_arguments(context_manager):
assert (
len(extracted_functions) == 1
), "Expected one function to be extracted with escaped quotes inside arguments."
"""


def test_if_else_in_template_existing(mocked_context_manager):
mocked_context_manager.get_full_context.return_value = {
"alert": {"notexist": "it actually exists", "name": "this is a test"}
}
iohandler = IOHandler(mocked_context_manager)
rendered = iohandler.render(
"{{#alert.notexist}}{{.}}{{/alert.notexist}}{{^alert.notexist}}{{alert.name}}{{/alert.notexist}}",
safe=True,
)
assert rendered == "it actually exists"


def test_if_else_in_template_not_existing(mocked_context_manager):
mocked_context_manager.get_full_context.return_value = {
"alert": {"name": "this is a test"}
}
iohandler = IOHandler(mocked_context_manager)
rendered = iohandler.render(
"{{#alert.notexist}}{{.}}{{/alert.notexist}}{{^alert.notexist}}{{alert.name}}{{/alert.notexist}}",
safe=True,
)
assert rendered == "this is a test"

0 comments on commit 6c278e4

Please sign in to comment.