Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added callback feature to decision tree with test #41

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 39 additions & 13 deletions elm/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,18 +50,36 @@ def __init__(self, graph):
Directed acyclic graph where nodes are LLM prompts and edges are
logical transitions based on the response. Must have high-level
graph attribute "api" which is an ApiBase instance. Nodes should
have attribute "prompt" which can have {format} named arguments
that will be filled from the high-level graph attributes. Edges can
have attribute "condition" that is a callable to be executed on the
LLM response text. An edge from a node without a condition acts as
an "else" statement if no other edge conditions are satisfied. A
have attribute "prompt" which is a string that can have {format}
named arguments that will be filled from the high-level graph
attributes. Nodes can also have "callback" attributes that are
callables that act on the LLM response in an arbitrary way. The
function signature for a callback must be
``callback(llm_response, decision_tree, node_name)``.
Edges can have attribute "condition" that is a callable to be
executed on the LLM response text that determines the edge
transition. An edge from a node without a condition acts as an
"else" statement if no other edge conditions are satisfied. A
single edge from node to node does not need a condition.
"""
self._g = graph
self._history = []
assert isinstance(self.graph, nx.DiGraph)
assert 'api' in self.graph.graph

def __getitem__(self, key):
"""Retrieve a node by name (str) or edge by (node0, node1) tuple"""
out = None
if key in self.graph.nodes:
out = self.graph.nodes[key]
elif key in self.graph.edges:
out = self.graph.edges[key]
else:
msg = (f'Could not find "{key}" in graph')
logger.error(msg)
raise KeyError(msg)
return out

@property
def api(self):
"""Get the ApiBase object.
Expand Down Expand Up @@ -112,36 +130,44 @@ def graph(self):
"""
return self._g

def call_node(self, node0):
def call_node(self, node_name):
"""Call the LLM with the prompt from the input node and search the
successor edges for a valid transition condition

Parameters
----------
node0 : str
node_name : str
Name of node being executed.

Returns
-------
out : str
Next node or LLM response if at a leaf node.
"""
prompt = self._prepare_graph_call(node0)

node = self[node_name]
prompt = self._prepare_graph_call(node_name)
out = self.api.chat(prompt)
return self._parse_graph_output(node0, out)
node['response'] = out

if 'callback' in node:
callback = node['callback']
callback(out, self, node_name)

return self._parse_graph_output(node_name, out)

def _prepare_graph_call(self, node0):
def _prepare_graph_call(self, node_name):
"""Prepare a graph call for given node."""
prompt = self.graph.nodes[node0]['prompt']
prompt = self[node_name]['prompt']
txt_fmt = {k: v for k, v in self.graph.graph.items() if k != 'api'}
prompt = prompt.format(**txt_fmt)
self._history.append(node0)
self._history.append(node_name)
return prompt

def _parse_graph_output(self, node0, out):
"""Parse graph output for given node and LLM call output. """
successors = list(self.graph.successors(node0))
edges = [self.graph.edges[(node0, node1)] for node1 in successors]
edges = [self[(node0, node1)] for node1 in successors]
conditions = [edge.get('condition', None) for edge in edges]

if len(successors) == 0:
Expand Down
10 changes: 9 additions & 1 deletion tests/test_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,20 @@ def test_chunk_and_embed(mocker):
graph = nx.DiGraph(text='hello', name='Grant',
api=ApiBase(model='gpt-35-turbo'))

response_dict = {}

# pylint: disable=unused-argument
def callback(response, graph, node_name):
response_dict.update({node_name: response})

graph.add_node('init', prompt='Say {text} to {name}')
graph.add_edge('init', 'next', condition=lambda x: 'Grant' in x)
graph.add_node('next', prompt='How are you?')
graph.add_node('next', prompt='How are you?', callback=callback)

tree = DecisionTree(graph)
tree.run()

assert 'init' in tree.history
assert 'next' in tree.history
assert isinstance(tree['next']['response'], str)
assert isinstance(response_dict['next'], str)
Loading