-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Sourcery refactored main branch #2
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -89,11 +89,7 @@ def _is_termination_msg_mathchat(message): | |
if message is None: | ||
return False | ||
cb = extract_code(message) | ||
contain_code = False | ||
for c in cb: | ||
if c[0] == "python" or c[0] == "wolfram": | ||
contain_code = True | ||
break | ||
contain_code = any(c[0] in ["python", "wolfram"] for c in cb) | ||
return not contain_code and get_answer(message) is not None and get_answer(message) != "" | ||
|
||
|
||
|
@@ -107,11 +103,7 @@ def _add_print_to_last_line(code): | |
last_line = lines[-1] | ||
if "\t" in last_line or "=" in last_line: | ||
return code | ||
if "=" in last_line: | ||
last_line = "print(" + last_line.split(" = ")[0] + ")" | ||
lines.append(last_line) | ||
else: | ||
lines[-1] = "print(" + last_line + ")" | ||
lines[-1] = f"print({last_line})" | ||
Comment on lines
-110
to
+106
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Function
|
||
# 3. join the lines back together | ||
return "\n".join(lines) | ||
|
||
|
@@ -224,11 +216,11 @@ def execute_one_python_code(self, pycode): | |
is_success = return_code == 0 | ||
|
||
if not is_success: | ||
# Remove the file information from the error string | ||
pattern = r'File "/[^"]+\.py", line \d+, in .+\n' | ||
if isinstance(output, str): | ||
# Remove the file information from the error string | ||
pattern = r'File "/[^"]+\.py", line \d+, in .+\n' | ||
output = re.sub(pattern, "", output) | ||
output = "Error: " + output | ||
output = f"Error: {output}" | ||
Comment on lines
-227
to
+223
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Function
|
||
elif output == "": | ||
# Check if there is any print statement | ||
if "print" not in pycode: | ||
|
@@ -245,15 +237,13 @@ def execute_one_python_code(self, pycode): | |
if is_success: | ||
# remove print and check if it still works | ||
tmp = self._previous_code + "\n" + _remove_print(pycode) + "\n" | ||
rcode, _, _ = execute_code(tmp, **self._code_execution_config) | ||
else: | ||
# only add imports and check if it works | ||
tmp = self._previous_code + "\n" | ||
for line in pycode.split("\n"): | ||
if "import" in line: | ||
tmp += line + "\n" | ||
rcode, _, _ = execute_code(tmp, **self._code_execution_config) | ||
|
||
rcode, _, _ = execute_code(tmp, **self._code_execution_config) | ||
if rcode == 0: | ||
self._previous_code = tmp | ||
return output, is_success | ||
|
@@ -436,7 +426,7 @@ def run(self, query: str) -> str: | |
for result in res["pod"]: | ||
if result["@title"] == "Solution": | ||
answer = result["subpod"]["plaintext"] | ||
if result["@title"] == "Results" or result["@title"] == "Solutions": | ||
if result["@title"] in ["Results", "Solutions"]: | ||
Comment on lines
-439
to
+429
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Function
|
||
for i, sub in enumerate(result["subpod"]): | ||
answer += f"ans {i}: " + sub["plaintext"] + "\n" | ||
break | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -191,7 +191,12 @@ def create_qdrant_from_dir( | |
# Upsert in batch of 100 or less if the total number of chunks is less than 100 | ||
for i in range(0, len(chunks), min(100, len(chunks))): | ||
end_idx = i + min(100, len(chunks) - i) | ||
client.add(collection_name, documents=chunks[i:end_idx], ids=[j for j in range(i, end_idx)], parallel=parallel) | ||
client.add( | ||
collection_name, | ||
documents=chunks[i:end_idx], | ||
ids=list(range(i, end_idx)), | ||
parallel=parallel, | ||
) | ||
Comment on lines
-194
to
+199
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
# Create a payload index for the document field | ||
# Enables highly efficient payload filtering. Reference: https://qdrant.tech/documentation/concepts/indexing/#indexing | ||
|
@@ -259,8 +264,9 @@ class QueryResponse(BaseModel, extra="forbid"): # type: ignore | |
else None, | ||
) | ||
|
||
data = { | ||
return { | ||
"ids": [[result.id for result in sublist] for sublist in results], | ||
"documents": [[result.document for result in sublist] for sublist in results], | ||
"documents": [ | ||
[result.document for result in sublist] for sublist in results | ||
], | ||
} | ||
return data | ||
Comment on lines
-262
to
-266
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Function
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -184,7 +184,7 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = | |
self.custom_token_count_function = self._retrieve_config.get("custom_token_count_function", count_token) | ||
self.custom_text_split_function = self._retrieve_config.get("custom_text_split_function", None) | ||
self._context_max_tokens = self._max_tokens * 0.8 | ||
self._collection = True if self._docs_path is None else False # whether the collection is created | ||
self._collection = self._docs_path is None | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Function
This removes the following comments (why?): `# whether the collection is created`
|
||
self._ipython = get_ipython() | ||
self._doc_idx = -1 # the index of the current used doc | ||
self._results = {} # the results of the current query | ||
|
@@ -207,12 +207,7 @@ def _is_termination_msg_retrievechat(self, message): | |
if message is None: | ||
return False | ||
cb = extract_code(message) | ||
contain_code = False | ||
for c in cb: | ||
# todo: support more languages | ||
if c[0] == "python": | ||
contain_code = True | ||
break | ||
contain_code = any(c[0] == "python" for c in cb) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
update_context_case1, update_context_case2 = self._check_update_context(message) | ||
return not (contain_code or update_context_case1 or update_context_case2) | ||
|
||
|
@@ -305,43 +300,46 @@ def _generate_retrieve_user_reply( | |
messages = self._oai_messages[sender] | ||
message = messages[-1] | ||
update_context_case1, update_context_case2 = self._check_update_context(message) | ||
if (update_context_case1 or update_context_case2) and self.update_context: | ||
print(colored("Updating context and resetting conversation.", "green"), flush=True) | ||
# extract the first sentence in the response as the intermediate answer | ||
_message = message.get("content", "").split("\n")[0].strip() | ||
_intermediate_info = re.split(r"(?<=[.!?])\s+", _message) | ||
self._intermediate_answers.add(_intermediate_info[0]) | ||
|
||
if update_context_case1: | ||
# try to get more context from the current retrieved doc results because the results may be too long to fit | ||
# in the LLM context. | ||
doc_contents = self._get_context(self._results) | ||
|
||
# Always use self.problem as the query text to retrieve docs, but each time we replace the context with the | ||
# next similar docs in the retrieved doc results. | ||
if not doc_contents: | ||
for _tmp_retrieve_count in range(1, 5): | ||
self._reset(intermediate=True) | ||
self.retrieve_docs(self.problem, self.n_results * (2 * _tmp_retrieve_count + 1)) | ||
doc_contents = self._get_context(self._results) | ||
if doc_contents: | ||
break | ||
elif update_context_case2: | ||
# Use the current intermediate info as the query text to retrieve docs, and each time we append the top similar | ||
# docs in the retrieved doc results to the context. | ||
for _tmp_retrieve_count in range(5): | ||
if ( | ||
not update_context_case1 | ||
and not update_context_case2 | ||
or not self.update_context | ||
): | ||
return False, None | ||
print(colored("Updating context and resetting conversation.", "green"), flush=True) | ||
# extract the first sentence in the response as the intermediate answer | ||
_message = message.get("content", "").split("\n")[0].strip() | ||
_intermediate_info = re.split(r"(?<=[.!?])\s+", _message) | ||
self._intermediate_answers.add(_intermediate_info[0]) | ||
|
||
if update_context_case1: | ||
# try to get more context from the current retrieved doc results because the results may be too long to fit | ||
# in the LLM context. | ||
doc_contents = self._get_context(self._results) | ||
|
||
# Always use self.problem as the query text to retrieve docs, but each time we replace the context with the | ||
# next similar docs in the retrieved doc results. | ||
if not doc_contents: | ||
for _tmp_retrieve_count in range(1, 5): | ||
self._reset(intermediate=True) | ||
self.retrieve_docs(_intermediate_info[0], self.n_results * (2 * _tmp_retrieve_count + 1)) | ||
self._get_context(self._results) | ||
doc_contents = "\n".join(self._doc_contents) # + "\n" + "\n".join(self._intermediate_answers) | ||
self.retrieve_docs(self.problem, self.n_results * (2 * _tmp_retrieve_count + 1)) | ||
doc_contents = self._get_context(self._results) | ||
if doc_contents: | ||
break | ||
|
||
self.clear_history() | ||
sender.clear_history() | ||
return True, self._generate_message(doc_contents, task=self._task) | ||
else: | ||
return False, None | ||
elif update_context_case2: | ||
# Use the current intermediate info as the query text to retrieve docs, and each time we append the top similar | ||
# docs in the retrieved doc results to the context. | ||
for _tmp_retrieve_count in range(5): | ||
self._reset(intermediate=True) | ||
self.retrieve_docs(_intermediate_info[0], self.n_results * (2 * _tmp_retrieve_count + 1)) | ||
self._get_context(self._results) | ||
doc_contents = "\n".join(self._doc_contents) # + "\n" + "\n".join(self._intermediate_answers) | ||
if doc_contents: | ||
break | ||
|
||
self.clear_history() | ||
sender.clear_history() | ||
return True, self._generate_message(doc_contents, task=self._task) | ||
Comment on lines
-308
to
+342
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""): | ||
"""Retrieve docs based on the given problem and assign the results to the class property `_results`. | ||
|
@@ -405,8 +403,7 @@ def generate_init_message(self, problem: str, n_results: int = 20, search_string | |
self.problem = problem | ||
self.n_results = n_results | ||
doc_contents = self._get_context(self._results) | ||
message = self._generate_message(doc_contents, self._task) | ||
return message | ||
return self._generate_message(doc_contents, self._task) | ||
Comment on lines
-408
to
+406
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
def run_code(self, code, **kwargs): | ||
lang = kwargs.get("lang", None) | ||
|
@@ -418,14 +415,13 @@ def run_code(self, code, **kwargs): | |
) | ||
if self._ipython is None or lang != "python": | ||
return super().run_code(code, **kwargs) | ||
else: | ||
result = self._ipython.run_cell(code) | ||
log = str(result.result) | ||
exitcode = 0 if result.success else 1 | ||
if result.error_before_exec is not None: | ||
log += f"\n{result.error_before_exec}" | ||
exitcode = 1 | ||
if result.error_in_exec is not None: | ||
log += f"\n{result.error_in_exec}" | ||
exitcode = 1 | ||
return exitcode, log, None | ||
result = self._ipython.run_cell(code) | ||
log = str(result.result) | ||
exitcode = 0 if result.success else 1 | ||
if result.error_before_exec is not None: | ||
log += f"\n{result.error_before_exec}" | ||
exitcode = 1 | ||
if result.error_in_exec is not None: | ||
log += f"\n{result.error_in_exec}" | ||
exitcode = 1 | ||
return exitcode, log, None | ||
Comment on lines
-421
to
+427
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Function
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
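Function `_is_termination_msg_mathchat` refactored with the following changes:
- Use `any()` instead of a for loop (`use-any`)
- Replace multiple comparisons of the same variable with the `in` operator (`merge-comparisons`)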