From 2a4165d525883f78b4c719fa135a4f21ceb9d1b7 Mon Sep 17 00:00:00 2001
From: Sourcery AI <>
Date: Thu, 2 Nov 2023 08:24:15 +0000
Subject: [PATCH] 'Refactored by Sourcery'

---
 .../contrib/math_user_proxy_agent.py          |  24 +---
 .../qdrant_retrieve_user_proxy_agent.py       |  14 +-
 .../contrib/retrieve_user_proxy_agent.py      | 104 +++++++--------
 autogen/agentchat/contrib/teachable_agent.py  |  93 +++++++------
 .../agentchat/contrib/text_analyzer_agent.py  |   3 +-
 autogen/agentchat/conversable_agent.py        | 123 ++++++++----------
 autogen/agentchat/groupchat.py                |   9 +-
 autogen/code_utils.py                         |  13 +-
 autogen/math_utils.py                         |  59 ++++-----
 autogen/oai/completion.py                     |   9 +-
 autogen/oai/openai_utils.py                   |   6 +-
 autogen/retrieve_utils.py                     |  10 +-
 autogen/token_count_utils.py                  |  15 ++-
 test/agentchat/chat_with_teachable_agent.py   |   9 +-
 test/agentchat/extensions/tsp.py              |  13 +-
 test/agentchat/test_async.py                  |   3 +-
 test/agentchat/test_teachable_agent.py        |   9 +-
 test/oai/test_completion.py                   |   4 +-
 18 files changed, 245 insertions(+), 275 deletions(-)

diff --git a/autogen/agentchat/contrib/math_user_proxy_agent.py b/autogen/agentchat/contrib/math_user_proxy_agent.py
index f7557517da81..4f035c56e45c 100644
--- a/autogen/agentchat/contrib/math_user_proxy_agent.py
+++ b/autogen/agentchat/contrib/math_user_proxy_agent.py
@@ -89,11 +89,7 @@ def _is_termination_msg_mathchat(message):
     if message is None:
         return False
     cb = extract_code(message)
-    contain_code = False
-    for c in cb:
-        if c[0] == "python" or c[0] == "wolfram":
-            contain_code = True
-            break
+    contain_code = any(c[0] in ["python", "wolfram"] for c in cb)
     return not contain_code and get_answer(message) is not None and get_answer(message) != ""
 
 
@@ -107,11 +103,7 @@ def _add_print_to_last_line(code):
     last_line = lines[-1]
     if "\t" in last_line or "=" in last_line:
         return code
-    if "=" in last_line:
-        last_line = "print(" + last_line.split(" = ")[0] + ")"
-        lines.append(last_line)
-    else:
-        lines[-1] = "print(" + last_line + ")"
+    lines[-1] = f"print({last_line})"
     # 3. join the lines back together
     return "\n".join(lines)
 
@@ -224,11 +216,11 @@ def execute_one_python_code(self, pycode):
         is_success = return_code == 0
 
         if not is_success:
-            # Remove the file information from the error string
-            pattern = r'File "/[^"]+\.py", line \d+, in .+\n'
             if isinstance(output, str):
+                # Remove the file information from the error string
+                pattern = r'File "/[^"]+\.py", line \d+, in .+\n'
                 output = re.sub(pattern, "", output)
-            output = "Error: " + output
+            output = f"Error: {output}"
         elif output == "":
             # Check if there is any print statement
             if "print" not in pycode:
@@ -245,15 +237,13 @@ def execute_one_python_code(self, pycode):
         if is_success:
             # remove print and check if it still works
             tmp = self._previous_code + "\n" + _remove_print(pycode) + "\n"
-            rcode, _, _ = execute_code(tmp, **self._code_execution_config)
         else:
             # only add imports and check if it works
             tmp = self._previous_code + "\n"
             for line in pycode.split("\n"):
                 if "import" in line:
                     tmp += line + "\n"
-            rcode, _, _ = execute_code(tmp, **self._code_execution_config)
-
+        rcode, _, _ = execute_code(tmp, **self._code_execution_config)
         if rcode == 0:
             self._previous_code = tmp
         return output, is_success
@@ -436,7 +426,7 @@ def run(self, query: str) -> str:
             for result in res["pod"]:
                 if result["@title"] == "Solution":
                     answer = result["subpod"]["plaintext"]
-                if result["@title"] == "Results" or result["@title"] == "Solutions":
+                if result["@title"] in ["Results", "Solutions"]:
                     for i, sub in enumerate(result["subpod"]):
                         answer += f"ans {i}: " + sub["plaintext"] + "\n"
                     break
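Both `_is_termination_msg_mathchat` above and `_is_termination_msg_retrievechat` further down get the same treatment: a boolean flag plus an explicit break collapses into a single `any()` call. A minimal standalone sketch of the pattern, assuming `extract_code`-style `(language, code)` tuples (the function names below are illustrative, not from the patch):

    # Before: flag variable, manual loop, early break.
    def contains_code_loop(blocks):
        found = False
        for lang, _src in blocks:
            if lang in ["python", "wolfram"]:
                found = True
                break
        return found

    # After: any() short-circuits at the first match, so behavior is identical.
    def contains_code_any(blocks):
        return any(lang in ["python", "wolfram"] for lang, _src in blocks)

    assert contains_code_any([("text", "hi"), ("python", "x = 1")]) is True

Because `any()` stops consuming the generator at the first truthy element, the rewrite preserves the original short-circuiting, not just the final result.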
diff --git a/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py
index e0bb8d8216f0..24ef2a843e41 100644
--- a/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py
+++ b/autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py
@@ -191,7 +191,12 @@ def create_qdrant_from_dir(
     # Upsert in batch of 100 or less if the total number of chunks is less than 100
     for i in range(0, len(chunks), min(100, len(chunks))):
         end_idx = i + min(100, len(chunks) - i)
-        client.add(collection_name, documents=chunks[i:end_idx], ids=[j for j in range(i, end_idx)], parallel=parallel)
+        client.add(
+            collection_name,
+            documents=chunks[i:end_idx],
+            ids=list(range(i, end_idx)),
+            parallel=parallel,
+        )
 
     # Create a payload index for the document field
     # Enables highly efficient payload filtering. Reference: https://qdrant.tech/documentation/concepts/indexing/#indexing
@@ -259,8 +264,9 @@ class QueryResponse(BaseModel, extra="forbid"):  # type: ignore
         else None,
     )
 
-    data = {
+    return {
         "ids": [[result.id for result in sublist] for sublist in results],
-        "documents": [[result.document for result in sublist] for sublist in results],
+        "documents": [
+            [result.document for result in sublist] for sublist in results
+        ],
     }
-    return data
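The `ids=[j for j in range(i, end_idx)]` argument was an identity comprehension; `list(range(i, end_idx))` builds the same list without a per-element loop. A simplified sketch of the batching idiom it lives in, with a hypothetical stand-in for the Qdrant client call:

    def upsert_in_batches(documents, batch_size=100):
        # Walk the document list in fixed-size windows; the last window may be short.
        for i in range(0, len(documents), batch_size):
            end_idx = min(i + batch_size, len(documents))
            ids = list(range(i, end_idx))  # equivalent to [j for j in range(i, end_idx)]
            print(f"add(ids={ids[0]}..{ids[-1]}, n={end_idx - i})")

    upsert_in_batches(["doc"] * 250)  # batches of 100, 100, and 50

Note the patch itself keeps the original loop step of `min(100, len(chunks))`; the sketch above shows the plain common idiom under that simplifying assumption, not a drop-in replacement.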
diff --git a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py
index b24249bbe961..49fc021fbb44 100644
--- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py
+++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py
@@ -184,7 +184,7 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
         self.custom_token_count_function = self._retrieve_config.get("custom_token_count_function", count_token)
         self.custom_text_split_function = self._retrieve_config.get("custom_text_split_function", None)
         self._context_max_tokens = self._max_tokens * 0.8
-        self._collection = True if self._docs_path is None else False  # whether the collection is created
+        self._collection = self._docs_path is None
         self._ipython = get_ipython()
         self._doc_idx = -1  # the index of the current used doc
         self._results = {}  # the results of the current query
@@ -207,12 +207,7 @@ def _is_termination_msg_retrievechat(self, message):
         if message is None:
             return False
         cb = extract_code(message)
-        contain_code = False
-        for c in cb:
-            # todo: support more languages
-            if c[0] == "python":
-                contain_code = True
-                break
+        contain_code = any(c[0] == "python" for c in cb)
         update_context_case1, update_context_case2 = self._check_update_context(message)
         return not (contain_code or update_context_case1 or update_context_case2)
@@ -305,43 +300,46 @@ def _generate_retrieve_user_reply(
             messages = self._oai_messages[sender]
         message = messages[-1]
         update_context_case1, update_context_case2 = self._check_update_context(message)
-        if (update_context_case1 or update_context_case2) and self.update_context:
-            print(colored("Updating context and resetting conversation.", "green"), flush=True)
-            # extract the first sentence in the response as the intermediate answer
-            _message = message.get("content", "").split("\n")[0].strip()
-            _intermediate_info = re.split(r"(?<=[.!?])\s+", _message)
-            self._intermediate_answers.add(_intermediate_info[0])
-
-            if update_context_case1:
-                # try to get more context from the current retrieved doc results because the results may be too long to fit
-                # in the LLM context.
-                doc_contents = self._get_context(self._results)
-
-                # Always use self.problem as the query text to retrieve docs, but each time we replace the context with the
-                # next similar docs in the retrieved doc results.
-                if not doc_contents:
-                    for _tmp_retrieve_count in range(1, 5):
-                        self._reset(intermediate=True)
-                        self.retrieve_docs(self.problem, self.n_results * (2 * _tmp_retrieve_count + 1))
-                        doc_contents = self._get_context(self._results)
-                        if doc_contents:
-                            break
-            elif update_context_case2:
-                # Use the current intermediate info as the query text to retrieve docs, and each time we append the top similar
-                # docs in the retrieved doc results to the context.
-                for _tmp_retrieve_count in range(5):
+        if (
+            not update_context_case1
+            and not update_context_case2
+            or not self.update_context
+        ):
+            return False, None
+        print(colored("Updating context and resetting conversation.", "green"), flush=True)
+        # extract the first sentence in the response as the intermediate answer
+        _message = message.get("content", "").split("\n")[0].strip()
+        _intermediate_info = re.split(r"(?<=[.!?])\s+", _message)
+        self._intermediate_answers.add(_intermediate_info[0])
+
+        if update_context_case1:
+            # try to get more context from the current retrieved doc results because the results may be too long to fit
+            # in the LLM context.
+            doc_contents = self._get_context(self._results)
+
+            # Always use self.problem as the query text to retrieve docs, but each time we replace the context with the
+            # next similar docs in the retrieved doc results.
+            if not doc_contents:
+                for _tmp_retrieve_count in range(1, 5):
                     self._reset(intermediate=True)
-                    self.retrieve_docs(_intermediate_info[0], self.n_results * (2 * _tmp_retrieve_count + 1))
-                    self._get_context(self._results)
-                    doc_contents = "\n".join(self._doc_contents)  # + "\n" + "\n".join(self._intermediate_answers)
+                    self.retrieve_docs(self.problem, self.n_results * (2 * _tmp_retrieve_count + 1))
+                    doc_contents = self._get_context(self._results)
                     if doc_contents:
                         break
-
-            self.clear_history()
-            sender.clear_history()
-            return True, self._generate_message(doc_contents, task=self._task)
-        else:
-            return False, None
+        elif update_context_case2:
+            # Use the current intermediate info as the query text to retrieve docs, and each time we append the top similar
+            # docs in the retrieved doc results to the context.
+            for _tmp_retrieve_count in range(5):
+                self._reset(intermediate=True)
+                self.retrieve_docs(_intermediate_info[0], self.n_results * (2 * _tmp_retrieve_count + 1))
+                self._get_context(self._results)
+                doc_contents = "\n".join(self._doc_contents)  # + "\n" + "\n".join(self._intermediate_answers)
+                if doc_contents:
+                    break
+
+        self.clear_history()
+        sender.clear_history()
+        return True, self._generate_message(doc_contents, task=self._task)
 
     def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""):
         """Retrieve docs based on the given problem and assign the results to the class property `_results`.
@@ -405,8 +403,7 @@ def generate_init_message(self, problem: str, n_results: int = 20, search_string
         self.problem = problem
         self.n_results = n_results
         doc_contents = self._get_context(self._results)
-        message = self._generate_message(doc_contents, self._task)
-        return message
+        return self._generate_message(doc_contents, self._task)
 
     def run_code(self, code, **kwargs):
         lang = kwargs.get("lang", None)
@@ -418,14 +415,13 @@ def run_code(self, code, **kwargs):
         )
         if self._ipython is None or lang != "python":
             return super().run_code(code, **kwargs)
-        else:
-            result = self._ipython.run_cell(code)
-            log = str(result.result)
-            exitcode = 0 if result.success else 1
-            if result.error_before_exec is not None:
-                log += f"\n{result.error_before_exec}"
-                exitcode = 1
-            if result.error_in_exec is not None:
-                log += f"\n{result.error_in_exec}"
-                exitcode = 1
-            return exitcode, log, None
+        result = self._ipython.run_cell(code)
+        log = str(result.result)
+        exitcode = 0 if result.success else 1
+        if result.error_before_exec is not None:
+            log += f"\n{result.error_before_exec}"
+            exitcode = 1
+        if result.error_in_exec is not None:
+            log += f"\n{result.error_in_exec}"
+            exitcode = 1
+        return exitcode, log, None
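The big `_generate_retrieve_user_reply` hunk is a guard-clause inversion: instead of wrapping the whole body in `if (case1 or case2) and update:` with a trailing `else: return False, None`, the rewrite returns early and dedents the body. One thing worth checking is the inverted condition itself: `and` binds tighter than `or`, so `not a and not b or not c` parses as `(not a and not b) or not c`, which by De Morgan is exactly `not ((a or b) and c)`. A standalone truth-table check:

    from itertools import product

    for a, b, c in product([False, True], repeat=3):
        original = (a or b) and c           # old: execute the body when this holds
        guard = not a and not b or not c    # new: return early when this holds
        assert guard == (not original)      # equivalent on all eight inputs

So the early return fires in precisely the cases where the old code fell into the `else` branch; the observable behavior is unchanged.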
diff --git a/autogen/agentchat/contrib/teachable_agent.py b/autogen/agentchat/contrib/teachable_agent.py
index 8db5b699ea2f..e185ce93ab6e 100644
--- a/autogen/agentchat/contrib/teachable_agent.py
+++ b/autogen/agentchat/contrib/teachable_agent.py
@@ -248,7 +248,7 @@ def concatenate_memo_texts(self, memo_list):
         if len(memo_list) > 0:
             info = "\n# Memories that might help\n"
             for memo in memo_list:
-                info = info + "- " + memo + "\n"
+                info = f"{info}- {memo}" + "\n"
             if self.verbosity >= 1:
                 print(colored("\nMEMOS APPENDED TO LAST USER MESSAGE...\n" + info + "\n", "light_yellow"))
             memo_texts = memo_texts + "\n" + info
@@ -256,17 +256,16 @@ def analyze(self, text_to_analyze, analysis_instructions):
         """Asks TextAnalyzerAgent to analyze the given text according to specific instructions."""
-        if self.verbosity >= 2:
-            # Use the messaging mechanism so that the analyzer's messages are included in the printed chat.
-            self.analyzer.reset()  # Clear the analyzer's list of messages.
-            self.send(
-                recipient=self.analyzer, message=text_to_analyze, request_reply=False
-            )  # Put the message in the analyzer's list.
-            self.send(recipient=self.analyzer, message=analysis_instructions, request_reply=True)  # Request the reply.
-            return self.last_message(self.analyzer)["content"]
-        else:
+        if self.verbosity < 2:
             # Use the analyzer's method directly, to leave analyzer message out of the printed chat.
             return self.analyzer.analyze_text(text_to_analyze, analysis_instructions)
+        # Use the messaging mechanism so that the analyzer's messages are included in the printed chat.
+        self.analyzer.reset()  # Clear the analyzer's list of messages.
+        self.send(
+            recipient=self.analyzer, message=text_to_analyze, request_reply=False
+        )  # Put the message in the analyzer's list.
+        self.send(recipient=self.analyzer, message=analysis_instructions, request_reply=True)  # Request the reply.
+        return self.last_message(self.analyzer)["content"]
 
 
 class MemoStore:
@@ -303,7 +302,7 @@ def __init__(self, verbosity, reset, path_to_db_dir):
         self.last_memo_id = 0
         if (not reset) and os.path.exists(self.path_to_dict):
            print(colored("\nLOADING MEMORY FROM DISK", "light_green"))
-            print(colored("    Location = {}".format(self.path_to_dict), "light_green"))
+            print(colored(f"    Location = {self.path_to_dict}", "light_green"))
             with open(self.path_to_dict, "rb") as f:
                 self.uid_text_dict = pickle.load(f)
                 self.last_memo_id = len(self.uid_text_dict)
@@ -317,7 +316,7 @@ def list_memos(self):
             input_text, output_text = text
             print(
                 colored(
-                    "  ID: {}\n    INPUT TEXT: {}\n    OUTPUT TEXT: {}".format(uid, input_text, output_text),
+                    f"  ID: {uid}\n    INPUT TEXT: {input_text}\n    OUTPUT TEXT: {output_text}",
                     "light_green",
                 )
             )
@@ -325,7 +324,7 @@ def list_memos(self):
     def close(self):
         """Saves self.uid_text_dict to disk."""
         print(colored("\nSAVING MEMORY TO DISK", "light_green"))
-        print(colored("    Location = {}".format(self.path_to_dict), "light_green"))
+        print(colored(f"    Location = {self.path_to_dict}", "light_green"))
         with open(self.path_to_dict, "wb") as file:
             pickle.dump(self.uid_text_dict, file)
@@ -344,9 +343,7 @@ def add_input_output_pair(self, input_text, output_text):
         if self.verbosity >= 1:
             print(
                 colored(
-                    "\nINPUT-OUTPUT PAIR ADDED TO VECTOR DATABASE:\n  ID\n    {}\n  INPUT\n    {}\n  OUTPUT\n    {}".format(
-                        self.last_memo_id, input_text, output_text
-                    ),
+                    f"\nINPUT-OUTPUT PAIR ADDED TO VECTOR DATABASE:\n  ID\n    {self.last_memo_id}\n  INPUT\n    {input_text}\n  OUTPUT\n    {output_text}",
                     "light_green",
                 )
             )
@@ -362,9 +359,7 @@ def get_nearest_memo(self, query_text):
         if self.verbosity >= 1:
             print(
                 colored(
-                    "\nINPUT-OUTPUT PAIR RETRIEVED FROM VECTOR DATABASE:\n  INPUT1\n    {}\n  OUTPUT\n    {}\n  DISTANCE\n    {}".format(
-                        input_text, output_text, distance
-                    ),
+                    f"\nINPUT-OUTPUT PAIR RETRIEVED FROM VECTOR DATABASE:\n  INPUT1\n    {input_text}\n  OUTPUT\n    {output_text}\n  DISTANCE\n    {distance}",
                     "light_green",
                 )
             )
@@ -372,8 +367,7 @@ def get_nearest_memo(self, query_text):
 
     def get_related_memos(self, query_text, n_results, threshold):
         """Retrieves memos that are related to the given query text within the specified distance threshold."""
-        if n_results > len(self.uid_text_dict):
-            n_results = len(self.uid_text_dict)
+        n_results = min(n_results, len(self.uid_text_dict))
         results = self.vec_db.query(query_texts=[query_text], n_results=n_results)
         memos = []
         num_results = len(results["ids"][0])
@@ -385,9 +379,7 @@ def get_related_memos(self, query_text, n_results, threshold):
             if self.verbosity >= 1:
                 print(
                     colored(
-                        "\nINPUT-OUTPUT PAIR RETRIEVED FROM VECTOR DATABASE:\n  INPUT1\n    {}\n  OUTPUT\n    {}\n  DISTANCE\n    {}".format(
-                            input_text, output_text, distance
-                        ),
+                        f"\nINPUT-OUTPUT PAIR RETRIEVED FROM VECTOR DATABASE:\n  INPUT1\n    {input_text}\n  OUTPUT\n    {output_text}\n  DISTANCE\n    {distance}",
                         "light_green",
                     )
                 )
@@ -398,28 +390,47 @@ def prepopulate(self):
         """Adds a few arbitrary examples to the vector DB, just to make retrieval less trivial."""
         if self.verbosity >= 1:
             print(colored("\nPREPOPULATING MEMORY", "light_green"))
-        examples = []
-        examples.append({"text": "When I say papers I mean research papers, which are typically pdfs.", "label": "yes"})
-        examples.append({"text": "Please verify that each paper you listed actually uses langchain.", "label": "no"})
-        examples.append({"text": "Tell gpt the output should still be latex code.", "label": "no"})
-        examples.append({"text": "Hint: convert pdfs to text and then answer questions based on them.", "label": "yes"})
-        examples.append(
-            {"text": "To create a good PPT, include enough content to make it interesting.", "label": "yes"}
-        )
-        examples.append(
+        examples = [
+            {
+                "text": "When I say papers I mean research papers, which are typically pdfs.",
+                "label": "yes",
+            },
+            {
+                "text": "Please verify that each paper you listed actually uses langchain.",
+                "label": "no",
+            },
+            {
+                "text": "Tell gpt the output should still be latex code.",
+                "label": "no",
+            },
+            {
+                "text": "Hint: convert pdfs to text and then answer questions based on them.",
+                "label": "yes",
+            },
+            {
+                "text": "To create a good PPT, include enough content to make it interesting.",
+                "label": "yes",
+            },
             {
                 "text": "No, for this case the columns should be aspects and the rows should be frameworks.",
                 "label": "no",
-            }
-        )
-        examples.append({"text": "When writing code, remember to include any libraries that are used.", "label": "yes"})
-        examples.append({"text": "Please summarize the papers by Eric Horvitz on bounded rationality.", "label": "no"})
-        examples.append({"text": "Compare the h-index of Daniel Weld and Oren Etzioni.", "label": "no"})
-        examples.append(
+            },
+            {
+                "text": "When writing code, remember to include any libraries that are used.",
+                "label": "yes",
+            },
+            {
+                "text": "Please summarize the papers by Eric Horvitz on bounded rationality.",
+                "label": "no",
+            },
+            {
+                "text": "Compare the h-index of Daniel Weld and Oren Etzioni.",
+                "label": "no",
+            },
             {
                 "text": "Double check to be sure that the columns in a table correspond to what was asked for.",
                 "label": "yes",
-            }
-        )
+            },
+        ]
         for example in examples:
             self.add_input_output_pair(example["text"], example["label"])
diff --git a/autogen/agentchat/contrib/text_analyzer_agent.py b/autogen/agentchat/contrib/text_analyzer_agent.py
index 8cf88eba6aed..cc8e12455a66 100644
--- a/autogen/agentchat/contrib/text_analyzer_agent.py
+++ b/autogen/agentchat/contrib/text_analyzer_agent.py
@@ -78,5 +78,4 @@ def analyze_text(self, text_to_analyze, analysis_instructions):
 
         # Generate and return the analysis string.
         response = oai.ChatCompletion.create(context=None, messages=messages, **self.llm_config)
-        output_text = oai.ChatCompletion.extract_text_or_function_call(response)[0]
-        return output_text
+        return oai.ChatCompletion.extract_text_or_function_call(response)[0]
diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py
index 017ba4e848ac..2927ff3ab417 100644
--- a/autogen/agentchat/conversable_agent.py
+++ b/autogen/agentchat/conversable_agent.py
@@ -253,10 +253,7 @@ def _message_to_dict(message: Union[Dict, str]):
         The message can be a string or a dictionary. The string will be put in the "content" field of the new dictionary.
         """
-        if isinstance(message, str):
-            return {"content": message}
-        else:
-            return message
+        return {"content": message} if isinstance(message, str) else message
 
     def _append_oai_message(self, message: Union[Dict, str], role, conversation_id: Agent) -> bool:
         """Append a message to the ChatCompletion conversation.
@@ -328,10 +325,7 @@ def send(
         Raises:
             ValueError: if the message can't be converted into a valid ChatCompletion message.
         """
-        # When the agent composes and sends the message, the role of the message is "assistant"
-        # unless it's "function".
-        valid = self._append_oai_message(message, "assistant", recipient)
-        if valid:
+        if valid := self._append_oai_message(message, "assistant", recipient):
             recipient.receive(message, self, request_reply, silent)
         else:
             raise ValueError(
@@ -377,10 +371,7 @@ async def a_send(
         Raises:
             ValueError: if the message can't be converted into a valid ChatCompletion message.
         """
-        # When the agent composes and sends the message, the role of the message is "assistant"
-        # unless it's "function".
-        valid = self._append_oai_message(message, "assistant", recipient)
-        if valid:
+        if valid := self._append_oai_message(message, "assistant", recipient):
             await recipient.a_receive(message, self, request_reply, silent)
         else:
             raise ValueError(
@@ -705,32 +696,31 @@ def check_termination_and_human_reply(
                 no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else ""
                 # if the human input is empty, and the message is a termination message, then we will terminate the conversation
                 reply = reply if reply or not self._is_termination_msg(message) else "exit"
-        else:
-            if self._consecutive_auto_reply_counter[sender] >= self._max_consecutive_auto_reply_dict[sender]:
-                if self.human_input_mode == "NEVER":
-                    reply = "exit"
-                else:
-                    # self.human_input_mode == "TERMINATE":
-                    terminate = self._is_termination_msg(message)
-                    reply = self.get_human_input(
-                        f"Please give feedback to {sender.name}. Press enter or type 'exit' to stop the conversation: "
-                        if terminate
-                        else f"Please give feedback to {sender.name}. Press enter to skip and use auto-reply, or type 'exit' to stop the conversation: "
-                    )
-                    no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else ""
-                    # if the human input is empty, and the message is a termination message, then we will terminate the conversation
-                    reply = reply if reply or not terminate else "exit"
-            elif self._is_termination_msg(message):
-                if self.human_input_mode == "NEVER":
-                    reply = "exit"
-                else:
-                    # self.human_input_mode == "TERMINATE":
-                    reply = self.get_human_input(
-                        f"Please give feedback to {sender.name}. Press enter or type 'exit' to stop the conversation: "
-                    )
-                    no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else ""
-                    # if the human input is empty, and the message is a termination message, then we will terminate the conversation
-                    reply = reply or "exit"
+        elif self._consecutive_auto_reply_counter[sender] >= self._max_consecutive_auto_reply_dict[sender]:
+            if self.human_input_mode == "NEVER":
+                reply = "exit"
+            else:
+                # self.human_input_mode == "TERMINATE":
+                terminate = self._is_termination_msg(message)
+                reply = self.get_human_input(
+                    f"Please give feedback to {sender.name}. Press enter or type 'exit' to stop the conversation: "
+                    if terminate
+                    else f"Please give feedback to {sender.name}. Press enter to skip and use auto-reply, or type 'exit' to stop the conversation: "
+                )
+                no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else ""
+                # if the human input is empty, and the message is a termination message, then we will terminate the conversation
+                reply = reply if reply or not terminate else "exit"
+        elif self._is_termination_msg(message):
+            if self.human_input_mode == "NEVER":
+                reply = "exit"
+            else:
+                # self.human_input_mode == "TERMINATE":
+                reply = self.get_human_input(
+                    f"Please give feedback to {sender.name}. Press enter or type 'exit' to stop the conversation: "
+                )
+                no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else ""
+                # if the human input is empty, and the message is a termination message, then we will terminate the conversation
+                reply = reply or "exit"
 
         # print the no_human_input_msg
         if no_human_input_msg:
@@ -776,32 +766,31 @@ async def a_check_termination_and_human_reply(
                 no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else ""
                 # if the human input is empty, and the message is a termination message, then we will terminate the conversation
                 reply = reply if reply or not self._is_termination_msg(message) else "exit"
-        else:
-            if self._consecutive_auto_reply_counter[sender] >= self._max_consecutive_auto_reply_dict[sender]:
-                if self.human_input_mode == "NEVER":
-                    reply = "exit"
-                else:
-                    # self.human_input_mode == "TERMINATE":
-                    terminate = self._is_termination_msg(message)
-                    reply = await self.a_get_human_input(
-                        f"Please give feedback to {sender.name}. Press enter or type 'exit' to stop the conversation: "
-                        if terminate
-                        else f"Please give feedback to {sender.name}. Press enter to skip and use auto-reply, or type 'exit' to stop the conversation: "
-                    )
-                    no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else ""
-                    # if the human input is empty, and the message is a termination message, then we will terminate the conversation
-                    reply = reply if reply or not terminate else "exit"
-            elif self._is_termination_msg(message):
-                if self.human_input_mode == "NEVER":
-                    reply = "exit"
-                else:
-                    # self.human_input_mode == "TERMINATE":
-                    reply = await self.a_get_human_input(
-                        f"Please give feedback to {sender.name}. Press enter or type 'exit' to stop the conversation: "
-                    )
-                    no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else ""
-                    # if the human input is empty, and the message is a termination message, then we will terminate the conversation
-                    reply = reply or "exit"
+        elif self._consecutive_auto_reply_counter[sender] >= self._max_consecutive_auto_reply_dict[sender]:
+            if self.human_input_mode == "NEVER":
+                reply = "exit"
+            else:
+                # self.human_input_mode == "TERMINATE":
+                terminate = self._is_termination_msg(message)
+                reply = await self.a_get_human_input(
+                    f"Please give feedback to {sender.name}. Press enter or type 'exit' to stop the conversation: "
+                    if terminate
+                    else f"Please give feedback to {sender.name}. Press enter to skip and use auto-reply, or type 'exit' to stop the conversation: "
+                )
+                no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else ""
+                # if the human input is empty, and the message is a termination message, then we will terminate the conversation
+                reply = reply if reply or not terminate else "exit"
+        elif self._is_termination_msg(message):
+            if self.human_input_mode == "NEVER":
+                reply = "exit"
+            else:
+                # self.human_input_mode == "TERMINATE":
+                reply = await self.a_get_human_input(
+                    f"Please give feedback to {sender.name}. Press enter or type 'exit' to stop the conversation: "
+                )
+                no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else ""
+                # if the human input is empty, and the message is a termination message, then we will terminate the conversation
+                reply = reply or "exit"
 
         # print the no_human_input_msg
         if no_human_input_msg:
@@ -959,8 +948,7 @@ def get_human_input(self, prompt: str) -> str:
         Returns:
             str: human input.
         """
-        reply = input(prompt)
-        return reply
+        return input(prompt)
 
     async def a_get_human_input(self, prompt: str) -> str:
         """(Async) Get human input.
@@ -973,8 +961,7 @@ async def a_get_human_input(self, prompt: str) -> str:
         Returns:
             str: human input.
         """
-        reply = input(prompt)
-        return reply
+        return input(prompt)
 
     def run_code(self, code, **kwargs):
         """Run the code and return the result.
diff --git a/autogen/agentchat/groupchat.py b/autogen/agentchat/groupchat.py
index f1c549bc18b2..e6420fd0ca8c 100644
--- a/autogen/agentchat/groupchat.py
+++ b/autogen/agentchat/groupchat.py
@@ -45,11 +45,10 @@ def next_agent(self, agent: Agent, agents: List[Agent]) -> Agent:
         """Return the next agent in the list."""
         if agents == self.agents:
             return agents[(self.agent_names.index(agent.name) + 1) % len(agents)]
-        else:
-            offset = self.agent_names.index(agent.name) + 1
-            for i in range(len(self.agents)):
-                if self.agents[(offset + i) % len(self.agents)] in agents:
-                    return self.agents[(offset + i) % len(self.agents)]
+        offset = self.agent_names.index(agent.name) + 1
+        for i in range(len(self.agents)):
+            if self.agents[(offset + i) % len(self.agents)] in agents:
+                return self.agents[(offset + i) % len(self.agents)]
 
     def select_speaker_msg(self, agents: List[Agent]):
         """Return the message for selecting the next speaker."""
diff --git a/autogen/code_utils.py b/autogen/code_utils.py
index caaf09072850..bb96d9692cc1 100644
--- a/autogen/code_utils.py
+++ b/autogen/code_utils.py
@@ -24,7 +24,7 @@
 TIMEOUT_MSG = "Timeout"
 DEFAULT_TIMEOUT = 600
 WIN32 = sys.platform == "win32"
-PATH_SEPARATOR = WIN32 and "\\" or "/"
+PATH_SEPARATOR = "\\" if WIN32 else "/"
 logger = logging.getLogger(__name__)
 
 
@@ -261,9 +261,6 @@ def execute_code(
             logger.error(error_msg)
             raise AssertionError(error_msg)
 
-    # Warn if use_docker was unspecified (or None), and cannot be provided (the default).
-    # In this case the current behavior is to fall back to run natively, but this behavior
-    # is subject to change.
     if use_docker is None:
         if docker is None:
             use_docker = False
@@ -327,7 +324,7 @@ def execute_code(
             logs = result.stderr
             if original_filename is None:
                 abs_path = str(pathlib.Path(filepath).absolute())
-                logs = logs.replace(str(abs_path), "").replace(filename, "")
+                logs = logs.replace(abs_path, "").replace(filename, "")
             else:
                 abs_path = str(pathlib.Path(work_dir).absolute()) + PATH_SEPARATOR
                 logs = logs.replace(str(abs_path), "")
@@ -445,9 +442,7 @@ def _remove_check(response):
     """Remove the check function from the response."""
     # find the position of the check function
     pos = response.find("def check(")
-    if pos == -1:
-        return response
-    return response[:pos]
+    return response if pos == -1 else response[:pos]
 
 
 def eval_function_completions(
@@ -488,7 +483,7 @@ def eval_function_completions(
             success_list.append(success)
         return {
             "expected_success": 1 - pow(1 - sum(success_list) / n, n),
-            "success": any(s for s in success_list),
+            "success": any(success_list),
         }
     if callable(assertions) and n > 1:
         # assertion generator
diff --git a/autogen/math_utils.py b/autogen/math_utils.py
index 7f35470fa6b9..74f3faac6c42 100644
--- a/autogen/math_utils.py
+++ b/autogen/math_utils.py
@@ -50,8 +50,8 @@ def last_boxed_only_string(string: str) -> Optional[str]:
     idx = string.rfind("\\boxed")
     if idx < 0:
         idx = string.rfind("\\fbox")
-        if idx < 0:
-            return None
+    if idx < 0:
+        return None
 
     i = idx
     right_brace_idx = None
@@ -66,12 +66,7 @@ def last_boxed_only_string(string: str) -> Optional[str]:
             break
         i += 1
 
-    if right_brace_idx is None:
-        retval = None
-    else:
-        retval = string[idx : right_brace_idx + 1]
-
-    return retval
+    return None if right_brace_idx is None else string[idx : right_brace_idx + 1]
 
 
 def _fix_fracs(string: str) -> str:
@@ -95,26 +90,24 @@ def _fix_fracs(string: str) -> str:
                 new_str += substr
             else:
                 try:
-                    if not len(substr) >= 2:
+                    if len(substr) < 2:
                         raise AssertionError
                 except Exception:
                     return string
                 a = substr[0]
                 b = substr[1]
                 if b != "{":
-                    if len(substr) > 2:
-                        post_substr = substr[2:]
-                        new_str += "{" + a + "}{" + b + "}" + post_substr
-                    else:
-                        new_str += "{" + a + "}{" + b + "}"
+                    new_str += (
+                        "{" + a + "}{" + b + "}" + substr[2:]
+                        if len(substr) > 2
+                        else "{" + a + "}{" + b + "}"
+                    )
+                elif len(substr) > 2:
+                    post_substr = substr[2:]
+                    new_str += "{" + a + "}" + b + post_substr
                 else:
-                    if len(substr) > 2:
-                        post_substr = substr[2:]
-                        new_str += "{" + a + "}" + b + post_substr
-                    else:
-                        new_str += "{" + a + "}" + b
-    string = new_str
-    return string
+                    new_str += "{" + a + "}" + b
+    return new_str
 
 
 def _fix_a_slash_b(string: str) -> str:
@@ -131,10 +124,9 @@ def _fix_a_slash_b(string: str) -> str:
     try:
         a = int(a_str)
         b = int(b_str)
-        if not string == "{}/{}".format(a, b):
+        if string != f"{a}/{b}":
             raise AssertionError
-        new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
-        return new_string
+        return "\\frac{" + str(a) + "}{" + str(b) + "}"
     except Exception:
         return string
 
@@ -144,13 +136,12 @@ def _remove_right_units(string: str) -> str:
     Remove units (on the right).
     "\\text{ " only ever occurs (at least in the val set) when describing units.
     """
-    if "\\text{ " in string:
-        splits = string.split("\\text{ ")
-        if not len(splits) == 2:
-            raise AssertionError
-        return splits[0]
-    else:
+    if "\\text{ " not in string:
         return string
+    splits = string.split("\\text{ ")
+    if len(splits) != 2:
+        raise AssertionError
+    return splits[0]
 
 
 def _fix_sqrt(string: str) -> str:
@@ -221,7 +212,7 @@ def _strip_string(string: str) -> str:
     if len(string) == 0:
         return string
     if string[0] == ".":
-        string = "0" + string
+        string = f"0{string}"
 
     # to consider: get rid of e.g. "k = " or "q = " at beginning
     if len(string.split("=")) == 2:
@@ -256,9 +247,7 @@ def get_answer(solution: Optional[str]) -> Optional[str]:
     if last_boxed is None:
         return None
     answer = remove_boxed(last_boxed)
-    if answer is None:
-        return None
-    return answer
+    return None if answer is None else answer
 
 
 def is_equiv(str1: Optional[str], str2: Optional[str]) -> float:
@@ -342,7 +331,7 @@ def eval_math_responses(responses, solution=None, **args):
     success_vote = is_equiv_chain_of_thought(responses[answer], solution)
     return {
         "expected_success": 1 - pow(1 - sum(success_list) / n, n),
-        "success": any(s for s in success_list),
+        "success": any(success_list),
         "success_vote": success_vote,
         "voted_answer": responses[answer],
         "votes": votes,
""" - json_str = os.environ.get(env_or_file) - if json_str: + if json_str := os.environ.get(env_or_file): config_list = json.loads(json_str) else: config_list_path = os.path.join(file_location, env_or_file) diff --git a/autogen/retrieve_utils.py b/autogen/retrieve_utils.py index bc4fdfb75976..cfdf7e7ae88a 100644 --- a/autogen/retrieve_utils.py +++ b/autogen/retrieve_utils.py @@ -143,7 +143,7 @@ def split_files_to_chunks( def get_files_from_dir(dir_path: Union[str, List[str]], types: list = TEXT_FORMATS, recursive: bool = True): """Return a list of all the files in a given directory.""" - if len(types) == 0: + if not types: raise ValueError("types cannot be empty.") types = [t[1:].lower() if t.startswith(".") else t.lower() for t in set(types)] types += [t.upper() for t in types] @@ -323,10 +323,10 @@ class QueryResult(TypedDict): ef.SentenceTransformerEmbeddingFunction(embedding_model) if embedding_function is None else embedding_function ) query_embeddings = embedding_function(query_texts) - # Query/search n most similar results. You can also .get by id - results = collection.query( + return collection.query( query_embeddings=query_embeddings, n_results=n_results, - where_document={"$contains": search_string} if search_string else None, # optional filter + where_document={"$contains": search_string} + if search_string + else None, # optional filter ) - return results diff --git a/autogen/token_count_utils.py b/autogen/token_count_utils.py index fd9d61a10a16..86a226bd1eb3 100644 --- a/autogen/token_count_utils.py +++ b/autogen/token_count_utils.py @@ -55,7 +55,7 @@ def count_token(input: Union[str, List, Dict], model: str = "gpt-3.5-turbo-0613" """ if isinstance(input, str): return _num_token_from_text(input, model=model) - elif isinstance(input, list) or isinstance(input, dict): + elif isinstance(input, (list, dict)): return _num_token_from_messages(input, model=model) else: raise ValueError("input must be str, list or dict") @@ -149,8 +149,9 @@ def num_tokens_from_functions(functions, model="gpt-3.5-turbo-0613") -> int: num_tokens = 0 for function in functions: - function_tokens = len(encoding.encode(function["name"])) - function_tokens += len(encoding.encode(function["description"])) + function_tokens = len(encoding.encode(function["name"])) + len( + encoding.encode(function["description"]) + ) function_tokens -= 2 if "parameters" in function: parameters = function["parameters"] @@ -159,10 +160,7 @@ def num_tokens_from_functions(functions, model="gpt-3.5-turbo-0613") -> int: function_tokens += len(encoding.encode(propertiesKey)) v = parameters["properties"][propertiesKey] for field in v: - if field == "type": - function_tokens += 2 - function_tokens += len(encoding.encode(v["type"])) - elif field == "description": + if field == "description": function_tokens += 2 function_tokens += len(encoding.encode(v["description"])) elif field == "enum": @@ -170,6 +168,9 @@ def num_tokens_from_functions(functions, model="gpt-3.5-turbo-0613") -> int: for o in v["enum"]: function_tokens += 3 function_tokens += len(encoding.encode(o)) + elif field == "type": + function_tokens += 2 + function_tokens += len(encoding.encode(v["type"])) else: print(f"Warning: not supported field {field}") function_tokens += 11 diff --git a/test/agentchat/chat_with_teachable_agent.py b/test/agentchat/chat_with_teachable_agent.py index 211ebe590975..b76c081de6c9 100644 --- a/test/agentchat/chat_with_teachable_agent.py +++ b/test/agentchat/chat_with_teachable_agent.py @@ -24,9 +24,13 @@ def 
diff --git a/autogen/retrieve_utils.py b/autogen/retrieve_utils.py
index bc4fdfb75976..cfdf7e7ae88a 100644
--- a/autogen/retrieve_utils.py
+++ b/autogen/retrieve_utils.py
@@ -143,7 +143,7 @@ def split_files_to_chunks(
 
 def get_files_from_dir(dir_path: Union[str, List[str]], types: list = TEXT_FORMATS, recursive: bool = True):
     """Return a list of all the files in a given directory."""
-    if len(types) == 0:
+    if not types:
         raise ValueError("types cannot be empty.")
     types = [t[1:].lower() if t.startswith(".") else t.lower() for t in set(types)]
     types += [t.upper() for t in types]
@@ -323,10 +323,10 @@ class QueryResult(TypedDict):
         ef.SentenceTransformerEmbeddingFunction(embedding_model) if embedding_function is None else embedding_function
     )
     query_embeddings = embedding_function(query_texts)
-    # Query/search n most similar results. You can also .get by id
-    results = collection.query(
+    return collection.query(
         query_embeddings=query_embeddings,
         n_results=n_results,
-        where_document={"$contains": search_string} if search_string else None,  # optional filter
+        where_document={"$contains": search_string}
+        if search_string
+        else None,  # optional filter
     )
-    return results
diff --git a/autogen/token_count_utils.py b/autogen/token_count_utils.py
index fd9d61a10a16..86a226bd1eb3 100644
--- a/autogen/token_count_utils.py
+++ b/autogen/token_count_utils.py
@@ -55,7 +55,7 @@ def count_token(input: Union[str, List, Dict], model: str = "gpt-3.5-turbo-0613"
     """
     if isinstance(input, str):
         return _num_token_from_text(input, model=model)
-    elif isinstance(input, list) or isinstance(input, dict):
+    elif isinstance(input, (list, dict)):
         return _num_token_from_messages(input, model=model)
     else:
         raise ValueError("input must be str, list or dict")
@@ -149,8 +149,9 @@ def num_tokens_from_functions(functions, model="gpt-3.5-turbo-0613") -> int:
 
     num_tokens = 0
     for function in functions:
-        function_tokens = len(encoding.encode(function["name"]))
-        function_tokens += len(encoding.encode(function["description"]))
+        function_tokens = len(encoding.encode(function["name"])) + len(
+            encoding.encode(function["description"])
+        )
         function_tokens -= 2
         if "parameters" in function:
             parameters = function["parameters"]
@@ -159,10 +160,7 @@ def num_tokens_from_functions(functions, model="gpt-3.5-turbo-0613") -> int:
                     function_tokens += len(encoding.encode(propertiesKey))
                     v = parameters["properties"][propertiesKey]
                     for field in v:
-                        if field == "type":
-                            function_tokens += 2
-                            function_tokens += len(encoding.encode(v["type"]))
-                        elif field == "description":
+                        if field == "description":
                             function_tokens += 2
                             function_tokens += len(encoding.encode(v["description"]))
                         elif field == "enum":
@@ -170,6 +168,9 @@ def num_tokens_from_functions(functions, model="gpt-3.5-turbo-0613") -> int:
                             for o in v["enum"]:
                                 function_tokens += 3
                                 function_tokens += len(encoding.encode(o))
+                        elif field == "type":
+                            function_tokens += 2
+                            function_tokens += len(encoding.encode(v["type"]))
                         else:
                             print(f"Warning: not supported field {field}")
         function_tokens += 11
diff --git a/test/agentchat/chat_with_teachable_agent.py b/test/agentchat/chat_with_teachable_agent.py
index 211ebe590975..b76c081de6c9 100644
--- a/test/agentchat/chat_with_teachable_agent.py
+++ b/test/agentchat/chat_with_teachable_agent.py
@@ -24,9 +24,13 @@ def create_teachable_agent(reset_db=False):
     # See https://microsoft.github.io/autogen/docs/FAQ#set-your-api-endpoints
     # and OAI_CONFIG_LIST_sample
     config_list = config_list_from_json(env_or_file="OAI_CONFIG_LIST", filter_dict=filter_dict)
-    teachable_agent = TeachableAgent(
+    return TeachableAgent(
         name="teachableagent",
-        llm_config={"config_list": config_list, "request_timeout": 120, "use_cache": use_cache},
+        llm_config={
+            "config_list": config_list,
+            "request_timeout": 120,
+            "use_cache": use_cache,
+        },
         teach_config={
             "verbosity": verbosity,
             "reset_db": reset_db,
@@ -34,7 +38,6 @@ def create_teachable_agent(reset_db=False):
             "recall_threshold": recall_threshold,
         },
     )
-    return teachable_agent
 
 
 def interact_freely_with_user():
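token_count_utils shows `isinstance` with a tuple of types, which replaces an `or` of two calls with one membership test; the reordering of the `type`/`description` branches above it is purely cosmetic, since a string can equal at most one of the compared literals. A toy version of the dispatch, loosely modeled on `count_token` (not the real implementation):

    def count(x):
        if isinstance(x, str):
            return len(x)
        elif isinstance(x, (list, dict)):  # same as isinstance(x, list) or isinstance(x, dict)
            return sum(len(str(v)) for v in x)
        raise ValueError("input must be str, list or dict")

    assert count("abc") == 3
    assert count([1, 22]) == 3  # len("1") + len("22")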
diff --git a/test/agentchat/extensions/tsp.py b/test/agentchat/extensions/tsp.py
index b979d407e35f..a6e7e03ab1ed 100644
--- a/test/agentchat/extensions/tsp.py
+++ b/test/agentchat/extensions/tsp.py
@@ -21,7 +21,7 @@ def solve_tsp(dists: dict) -> float:
     """
     # Get the unique nodes from the distance matrix
     nodes = set()
-    for pair in dists.keys():
+    for pair in dists:
         nodes.add(pair[0])
         nodes.add(pair[1])
 
@@ -34,13 +34,10 @@ def solve_tsp(dists: dict) -> float:
     # Iterate through all possible routes
     for route in routes:
-        cost = 0
-
-        # Calculate the cost of the current route
-        for i in range(len(route)):
-            current_node = route[i]
-            next_node = route[(i + 1) % len(route)]
-            cost += dists[(current_node, next_node)]
-
+        cost = sum(
+            dists[route[i], route[(i + 1) % len(route)]]
+            for i in range(len(route))
+        )
         # Update the optimal cost if the current cost is smaller
         if cost < optimal_cost:
             optimal_cost = cost
diff --git a/test/agentchat/test_async.py b/test/agentchat/test_async.py
index 9a806e6af40f..e040c87c47a5 100644
--- a/test/agentchat/test_async.py
+++ b/test/agentchat/test_async.py
@@ -35,13 +35,12 @@ def get_market_news(ind, ind_upper):
         ]
     }
     feeds = data["feed"][ind:ind_upper]
-    feeds_summary = "\n".join(
+    return "\n".join(
         [
             f"News summary: {f['title']}. {f['summary']} overall_sentiment_score: {f['overall_sentiment_score']}"
             for f in feeds
         ]
     )
-    return feeds_summary
 
 
 async def test_stream():
diff --git a/test/agentchat/test_teachable_agent.py b/test/agentchat/test_teachable_agent.py
index 7a3367dbd72c..e863da000263 100644
--- a/test/agentchat/test_teachable_agent.py
+++ b/test/agentchat/test_teachable_agent.py
@@ -39,9 +39,13 @@ def create_teachable_agent(reset_db=False, verbosity=0):
     # See https://microsoft.github.io/autogen/docs/FAQ#set-your-api-endpoints
     # and OAI_CONFIG_LIST_sample
     config_list = config_list_from_json(env_or_file="OAI_CONFIG_LIST", filter_dict=filter_dict)
-    teachable_agent = TeachableAgent(
+    return TeachableAgent(
         name="teachableagent",
-        llm_config={"config_list": config_list, "request_timeout": 120, "use_cache": use_cache},
+        llm_config={
+            "config_list": config_list,
+            "request_timeout": 120,
+            "use_cache": use_cache,
+        },
         teach_config={
             "verbosity": verbosity,
             "reset_db": reset_db,
@@ -49,7 +53,6 @@ def create_teachable_agent(reset_db=False, verbosity=0):
             "recall_threshold": recall_threshold,
         },
    )
-    return teachable_agent
 
 
 def check_agent_response(teachable_agent, user, correct_answer):
diff --git a/test/oai/test_completion.py b/test/oai/test_completion.py
index b6cb5c31b1c2..65a375a858aa 100644
--- a/test/oai/test_completion.py
+++ b/test/oai/test_completion.py
@@ -366,7 +366,7 @@ def test_math(num_samples=-1):
     ]
     print(
         "max tokens in tuning data's canonical solutions",
-        max([len(x["solution"].split()) for x in tune_data]),
+        max(len(x["solution"].split()) for x in tune_data),
     )
     print(len(tune_data), len(test_data))
     # prompt template
@@ -385,7 +385,7 @@ def test_math(num_samples=-1):
         "prompt": prompts[0],
         "stop": "###",
     }
-    test_data_sample = test_data[0:3]
+    test_data_sample = test_data[:3]
     result = autogen.Completion.test(test_data_sample, eval_math_responses, config_list=config_list, **vanilla_config)
     result = autogen.Completion.test(
         test_data_sample,
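The test-file hunks are of a piece with the rest: `sum()` and `max()` lose their square brackets so they consume a generator instead of materializing a list, `dists.keys()` becomes bare `dists` (iterating a dict already yields its keys), and `test_data[0:3]` drops the redundant `0`. A standalone sketch with toy records:

    tune_data = [{"solution": "a b c"}, {"solution": "a b"}]
    # Before: max([len(x["solution"].split()) for x in tune_data]) builds a throwaway list.
    longest = max(len(x["solution"].split()) for x in tune_data)
    assert longest == 3
    assert tune_data[0:2] == tune_data[:2]  # the explicit 0 adds nothing

For small test fixtures the difference is noise, but the generator form also works on inputs too large to hold as a list.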