diff --git a/dataherald/sql_generator/__init__.py b/dataherald/sql_generator/__init__.py index b270d5a7..94f8b7d5 100644 --- a/dataherald/sql_generator/__init__.py +++ b/dataherald/sql_generator/__init__.py @@ -62,6 +62,10 @@ def remove_markdown(self, query: str) -> str: return matches[0].strip() return query + def format_sql_query_intermediate_steps(self, step: str) -> str: + pattern = r"```sql(.*?)```" + return re.sub(pattern, self.format_sql_query, step) + @classmethod def get_upper_bound_limit(cls) -> int: top_k = os.getenv("UPPER_LIMIT_QUERY_RETURN_ROWS", None) @@ -170,12 +174,19 @@ def stream_agent_steps( # noqa: C901 ): if "actions" in chunk: for message in chunk["messages"]: - queue.put(message.content + "\n") + queue.put( + self.format_sql_query_intermediate_steps( + message.content + ) + + "\n" + ) elif "steps" in chunk: for step in chunk["steps"]: queue.put(f"\n**Observation:**\n {step.observation}\n") elif "output" in chunk: - queue.put(f'\n**Final Answer:**\n {chunk["output"]}') + queue.put( + f'\n**Final Answer:**\n {self.format_sql_query_intermediate_steps(chunk["output"])}' + ) if "```sql" in chunk["output"]: response.sql = replace_unprocessable_characters( self.remove_markdown(chunk["output"])