Skip to content

Commit

Permalink
implement two stage plugin selection
Browse files Browse the repository at this point in the history
  • Loading branch information
binary-sky committed Aug 29, 2023
1 parent f40d48b commit eb802ee
Showing 1 changed file with 29 additions and 11 deletions.
40 changes: 29 additions & 11 deletions crazy_functions/vt_fns/vt_call_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from crazy_functions.json_fns.pydantic_io import GptJsonIO
import copy, json, pickle, os, sys


def read_avail_plugin_enum():
from crazy_functional import get_crazy_functions
plugin_arr = get_crazy_functions()
Expand All @@ -16,6 +17,7 @@ def read_avail_plugin_enum():
prompt = "\n\nThe defination of PluginEnum:\nPluginEnum=" + prompt
return prompt, plugin_arr_dict


def execute_plugin(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_intention):
plugin_arr_enum_prompt, plugin_arr_dict = read_avail_plugin_enum()
class Plugin(BaseModel):
Expand All @@ -33,16 +35,32 @@ class Plugin(BaseModel):
inputs=inputs, llm_kwargs=llm_kwargs, history=[], sys_prompt=sys_prompt, observe_window=[])
plugin_sel = gpt_json_io.generate_output_auto_repair(run_gpt_fn(inputs, ""), run_gpt_fn)

if plugin_sel.plugin_selection in plugin_arr_dict:
# ⭐ ⭐ ⭐ 执行插件
plugin = plugin_arr_dict[plugin_sel.plugin_selection]
fn = plugin['Function']
fn_name = fn.__name__
msg = f'正在调用插件: {fn_name}\n\n插件说明:{plugin["Info"]}\n\n插件参数:{plugin_sel.plugin_arg}'
yield from update_ui_lastest_msg(lastmsg=msg, chatbot=chatbot, history=history, delay=2)
yield from fn(plugin_sel.plugin_arg, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, -1)
return
else:
if plugin_sel.plugin_selection not in plugin_arr_dict:
msg = f'找不到合适插件执行该任务'
yield from update_ui_lastest_msg(lastmsg=msg, chatbot=chatbot, history=history, delay=2)
return
return

# ⭐ ⭐ ⭐ 确认插件参数
plugin = plugin_arr_dict[plugin_sel.plugin_selection]
yield from update_ui_lastest_msg(lastmsg=f"正在执行任务: {txt}\n\n提取插件参数...", chatbot=chatbot, history=history, delay=0)
class PluginExplicit(BaseModel):
plugin_selection: str = plugin_sel.plugin_selection
plugin_arg: str = Field(description="The argument of the plugin.", default="")
gpt_json_io = GptJsonIO(PluginExplicit)
gpt_json_io.format_instructions += "The information about this plugin is:" + plugin["Info"]
inputs = f"A plugin named {plugin_sel.plugin_selection} is selected, " + \
"you should extract plugin_arg from the user requirement, the user requirement is: \n\n" + \
">> " + txt.rstrip('\n').replace('\n','\n>> ') + '\n\n' + \
gpt_json_io.format_instructions
run_gpt_fn = lambda inputs, sys_prompt: predict_no_ui_long_connection(
inputs=inputs, llm_kwargs=llm_kwargs, history=[], sys_prompt=sys_prompt, observe_window=[])
plugin_sel = gpt_json_io.generate_output_auto_repair(run_gpt_fn(inputs, ""), run_gpt_fn)


# ⭐ ⭐ ⭐ 执行插件
fn = plugin['Function']
fn_name = fn.__name__
msg = f'正在调用插件: {fn_name}\n\n插件说明:{plugin["Info"]}\n\n插件参数:{plugin_sel.plugin_arg}'
yield from update_ui_lastest_msg(lastmsg=msg, chatbot=chatbot, history=history, delay=2)
yield from fn(plugin_sel.plugin_arg, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, -1)
return

0 comments on commit eb802ee

Please sign in to comment.