From b230fb06f0a191da42891cda528fcf3a2cac8893 Mon Sep 17 00:00:00 2001
From: Tal Borenstein <tal@keephq.dev>
Date: Sun, 7 Apr 2024 14:57:11 +0300
Subject: [PATCH 1/2] fix(iohandler): memorize unescaped quotes and escape them

---
 keep/iohandler/iohandler.py          | 61 +++++++++++++++++++++-------
 keep/providers/base/base_provider.py | 12 ++++++
 tests/test_iohandler.py              | 27 +++++++++++-
 3 files changed, 83 insertions(+), 17 deletions(-)

diff --git a/keep/iohandler/iohandler.py b/keep/iohandler/iohandler.py
index 855bb81568..968f503c8e 100644
--- a/keep/iohandler/iohandler.py
+++ b/keep/iohandler/iohandler.py
@@ -70,15 +70,15 @@ def extract_keep_functions(self, text):
         while i < len(text):
             if text[i : i + 5] == "keep.":
                 start = i
-                func_end = text.find("(", start)
-                if func_end > -1:  # Opening '(' found after "keep."
-                    i = func_end + 1  # Move i to the character after '('
-                    paren_count = 1
+                func_start = text.find("(", start)
+                if func_start > -1:  # Opening '(' found after "keep."
+                    i = func_start + 1  # Move i to the character after '('
+                    parent_count = 1
                     in_string = False
                     escape_next = False
                     quote_char = ""
-
-                    while i < len(text) and (paren_count > 0 or in_string):
+                    escapes = {}
+                    while i < len(text) and (parent_count > 0 or in_string):
                         if text[i] == "\\" and in_string and not escape_next:
                             escape_next = True
                             i += 1
@@ -87,19 +87,28 @@ def extract_keep_functions(self, text):
                             if not in_string:
                                 in_string = True
                                 quote_char = text[i]
-                            elif text[i] == quote_char and not escape_next:
+                            elif (
+                                text[i] == quote_char
+                                and not escape_next
+                                and str(text[i + 1]).isalpha()
+                                == False  # end of statement, arg, etc. if it's a letter, we need to escape it
+                            ):
                                 in_string = False
                                 quote_char = ""
+                            elif text[i] == quote_char and not escape_next:
+                                escapes[i] = text[
+                                    i
+                                ]  # Save the quote character where we need to escape for valid ast parsing
                         elif text[i] == "(" and not in_string:
-                            paren_count += 1
+                            parent_count += 1
                         elif text[i] == ")" and not in_string:
-                            paren_count -= 1
+                            parent_count -= 1
 
                         escape_next = False
                         i += 1
 
-                    if paren_count == 0:
-                        matches.append(text[start:i])
+                    if parent_count == 0:
+                        matches.append((text[start:i], escapes))
                     continue  # Skip the increment at the end of the loop to continue from the current position
                 else:
                     # If no '(' found, increment i to move past "keep."
@@ -154,8 +163,12 @@ def parse(self, string, safe=False, default=""):
         if len(tokens) == 0:
             return parsed_string
         elif len(tokens) == 1:
-            token = "".join(tokens[0])
+            token, escapes = tokens[0]
+            token_to_replace = token
             try:
+                if escapes:
+                    for escape in escapes:
+                        token = token[:escape] + "\\" + token[escape:]
                 val = self._parse_token(token)
             except Exception as e:
                 # trim stacktrace since we have limitation on the error message
@@ -164,12 +177,16 @@ def parse(self, string, safe=False, default=""):
                 raise Exception(
                     f"Got {e.__class__.__name__} while parsing token '{trimmed_token}': {err_message}"
                 )
-            parsed_string = parsed_string.replace(token, str(val))
+            parsed_string = parsed_string.replace(token_to_replace, str(val))
             return parsed_string
         # this basically for complex expressions with functions and operators
         for token in tokens:
-            token = "".join(token)
+            token, escapes = token
+            token_to_replace = token
             try:
+                if escapes:
+                    for escape in escapes:
+                        token = token[:escape] + "\\" + token[escape:]
                 val = self._parse_token(token)
             except Exception as e:
                 trimmed_token = self._trim_token_error(token)
@@ -177,7 +194,7 @@ def parse(self, string, safe=False, default=""):
                 raise Exception(
                     f"Got {e.__class__.__name__} while parsing token '{trimmed_token}': {err_message}"
                 )
-            parsed_string = parsed_string.replace(token, str(val))
+            parsed_string = parsed_string.replace(token_to_replace, str(val))
 
         return parsed_string
 
@@ -394,3 +411,17 @@ def __get_short_urls(self, urls: list) -> dict:
                 )
         except Exception:
             self.logger.exception("Failed to request short URLs from API")
+
+
+if __name__ == "__main__":
+    # debug & test
+    context_manager = ContextManager("keep")
+    context_manager.event_context = {
+        "notexist": "it actually exists",
+        "name": "this is a test",
+    }
+    iohandler = IOHandler(context_manager)
+    iohandler.parse(
+        "{{#alert.notexist}}{{.}}{{/alert.notexist}}{{^alert.notexist}}{{alert.name}}{{/alert.notexist}}",
+        safe=True,
+    )
diff --git a/keep/providers/base/base_provider.py b/keep/providers/base/base_provider.py
index 50a0096f22..bcd334f7a5 100644
--- a/keep/providers/base/base_provider.py
+++ b/keep/providers/base/base_provider.py
@@ -185,6 +185,18 @@ def _enrich_alert(self, enrichments, results):
         self.logger.info("Enriching alert", extra={"fingerprint": fingerprint})
         try:
             enrich_alert(self.context_manager.tenant_id, fingerprint, _enrichments)
+            if self.context_manager.event_context:
+                for enrichment in _enrichments:
+                    if isinstance(self.context_manager.event_context, dict):
+                        self.context_manager.event_context[enrichment] = _enrichments[
+                            enrichment
+                        ]
+                    else:
+                        setattr(
+                            self.context_manager.event_context,
+                            enrichment,
+                            _enrichments[enrichment],
+                        )
         except Exception as e:
             self.logger.error(
                 "Failed to enrich alert in db",
diff --git a/tests/test_iohandler.py b/tests/test_iohandler.py
index 83012115bb..0ebe356cd4 100644
--- a/tests/test_iohandler.py
+++ b/tests/test_iohandler.py
@@ -1,6 +1,7 @@
 """
 Test the io handler
 """
+
 import datetime
 
 import pytest
@@ -679,7 +680,6 @@ def test_complex_mixture(context_manager):
     ), "Should correctly handle a complex mixture of text and function calls."
 
 
-"""
 def test_escaped_quotes_inside_function_arguments(context_manager):
     iohandler = IOHandler(context_manager)
     template = "keep.split('some,string,with,escaped\\\\'quotes', ',')"
@@ -688,4 +688,27 @@ def test_escaped_quotes_inside_function_arguments(context_manager):
     assert (
         len(extracted_functions) == 1
     ), "Expected one function to be extracted with escaped quotes inside arguments."
-"""
+
+
+def test_if_else_in_template_existing(mocked_context_manager):
+    mocked_context_manager.get_full_context.return_value = {
+        "alert": {"notexist": "it actually exists", "name": "this is a test"}
+    }
+    iohandler = IOHandler(mocked_context_manager)
+    rendered = iohandler.render(
+        "{{#alert.notexist}}{{.}}{{/alert.notexist}}{{^alert.notexist}}{{alert.name}}{{/alert.notexist}}",
+        safe=True,
+    )
+    assert rendered == "it actually exists"
+
+
+def test_if_else_in_template_not_existing(mocked_context_manager):
+    mocked_context_manager.get_full_context.return_value = {
+        "alert": {"name": "this is a test"}
+    }
+    iohandler = IOHandler(mocked_context_manager)
+    rendered = iohandler.render(
+        "{{#alert.notexist}}{{.}}{{/alert.notexist}}{{^alert.notexist}}{{alert.name}}{{/alert.notexist}}",
+        safe=True,
+    )
+    assert rendered == "this is a test"

From 313ff263cbdd14a27071cc560f3a467ba07384a5 Mon Sep 17 00:00:00 2001
From: Tal Borenstein <tal@keephq.dev>
Date: Sun, 7 Apr 2024 15:09:07 +0300
Subject: [PATCH 2/2] fix: remove changes in base provider

---
 keep/providers/base/base_provider.py | 12 ------------
 1 file changed, 12 deletions(-)

diff --git a/keep/providers/base/base_provider.py b/keep/providers/base/base_provider.py
index bcd334f7a5..50a0096f22 100644
--- a/keep/providers/base/base_provider.py
+++ b/keep/providers/base/base_provider.py
@@ -185,18 +185,6 @@ def _enrich_alert(self, enrichments, results):
         self.logger.info("Enriching alert", extra={"fingerprint": fingerprint})
         try:
             enrich_alert(self.context_manager.tenant_id, fingerprint, _enrichments)
-            if self.context_manager.event_context:
-                for enrichment in _enrichments:
-                    if isinstance(self.context_manager.event_context, dict):
-                        self.context_manager.event_context[enrichment] = _enrichments[
-                            enrichment
-                        ]
-                    else:
-                        setattr(
-                            self.context_manager.event_context,
-                            enrichment,
-                            _enrichments[enrichment],
-                        )
         except Exception as e:
             self.logger.error(
                 "Failed to enrich alert in db",