diff --git a/elm/tree.py b/elm/tree.py index d2ea49d..ecb28ab 100644 --- a/elm/tree.py +++ b/elm/tree.py @@ -50,11 +50,16 @@ def __init__(self, graph): Directed acyclic graph where nodes are LLM prompts and edges are logical transitions based on the response. Must have high-level graph attribute "api" which is an ApiBase instance. Nodes should - have attribute "prompt" which can have {format} named arguments - that will be filled from the high-level graph attributes. Edges can - have attribute "condition" that is a callable to be executed on the - LLM response text. An edge from a node without a condition acts as - an "else" statement if no other edge conditions are satisfied. A + have attribute "prompt" which is a string that can have {format} + named arguments that will be filled from the high-level graph + attributes. Nodes can also have "callback" attributes that are + callables that act on the LLM response in an arbitrary way. The + function signature for a callback must be + ``callback(llm_response, decision_tree, node_name)``. + Edges can have attribute "condition" that is a callable to be + executed on the LLM response text that determines the edge + transition. An edge from a node without a condition acts as an + "else" statement if no other edge conditions are satisfied. A single edge from node to node does not need a condition. """ self._g = graph @@ -62,6 +67,19 @@ def __init__(self, graph): assert isinstance(self.graph, nx.DiGraph) assert 'api' in self.graph.graph + def __getitem__(self, key): + """Retrieve a node by name (str) or edge by (node0, node1) tuple""" + out = None + if key in self.graph.nodes: + out = self.graph.nodes[key] + elif key in self.graph.edges: + out = self.graph.edges[key] + else: + msg = (f'Could not find "{key}" in graph') + logger.error(msg) + raise KeyError(msg) + return out + @property def api(self): """Get the ApiBase object. @@ -112,13 +130,13 @@ def graph(self): """ return self._g - def call_node(self, node0): + def call_node(self, node_name): """Call the LLM with the prompt from the input node and search the successor edges for a valid transition condition Parameters ---------- - node0 : str + node_name : str Name of node being executed. Returns @@ -126,22 +144,30 @@ def call_node(self, node0): out : str Next node or LLM response if at a leaf node. """ - prompt = self._prepare_graph_call(node0) + + node = self[node_name] + prompt = self._prepare_graph_call(node_name) out = self.api.chat(prompt) - return self._parse_graph_output(node0, out) + node['response'] = out + + if 'callback' in node: + callback = node['callback'] + callback(out, self, node_name) + + return self._parse_graph_output(node_name, out) - def _prepare_graph_call(self, node0): + def _prepare_graph_call(self, node_name): """Prepare a graph call for given node.""" - prompt = self.graph.nodes[node0]['prompt'] + prompt = self[node_name]['prompt'] txt_fmt = {k: v for k, v in self.graph.graph.items() if k != 'api'} prompt = prompt.format(**txt_fmt) - self._history.append(node0) + self._history.append(node_name) return prompt def _parse_graph_output(self, node0, out): """Parse graph output for given node and LLM call output. """ successors = list(self.graph.successors(node0)) - edges = [self.graph.edges[(node0, node1)] for node1 in successors] + edges = [self[(node0, node1)] for node1 in successors] conditions = [edge.get('condition', None) for edge in edges] if len(successors) == 0: diff --git a/tests/test_tree.py b/tests/test_tree.py index 87c7d56..c6e679a 100644 --- a/tests/test_tree.py +++ b/tests/test_tree.py @@ -39,12 +39,20 @@ def test_chunk_and_embed(mocker): graph = nx.DiGraph(text='hello', name='Grant', api=ApiBase(model='gpt-35-turbo')) + response_dict = {} + + # pylint: disable=unused-argument + def callback(response, graph, node_name): + response_dict.update({node_name: response}) + graph.add_node('init', prompt='Say {text} to {name}') graph.add_edge('init', 'next', condition=lambda x: 'Grant' in x) - graph.add_node('next', prompt='How are you?') + graph.add_node('next', prompt='How are you?', callback=callback) tree = DecisionTree(graph) tree.run() assert 'init' in tree.history assert 'next' in tree.history + assert isinstance(tree['next']['response'], str) + assert isinstance(response_dict['next'], str)