diff --git a/contrib/pyln-client/pyln/client/gossmap.py b/contrib/pyln-client/pyln/client/gossmap.py index 50f5109ee788..0d4f1ea93779 100755 --- a/contrib/pyln-client/pyln/client/gossmap.py +++ b/contrib/pyln-client/pyln/client/gossmap.py @@ -329,6 +329,25 @@ def get_node(self, node_id: Union[GossmapNodeId, str]): node_id = GossmapNodeId.from_str(node_id) return self.nodes.get(node_id) + def get_nodes(self, + source: Union[GossmapNodeId, str], + depth: int = 0, excludes: set = None): + """ Returns a set of nodes within a given depth from a source node """ + node = self.get_node(source) + assert node is not None, f"Unknown source: {source}" + if excludes is None: + excludes = set() + excludes.add(node) + result = set() + result.add(node) + if depth > 0: + for channel in node.channels: + other = channel.node1 if channel.node1 != node else channel.node2 + if other in excludes: + continue + result.update(self.get_nodes(other.node_id, depth - 1, excludes)) + return result + def _update_channel(self, rec: bytes, off: int): fields = channel_update.read(io.BytesIO(rec[2:]), {}) direction = fields['channel_flags'] & 1 diff --git a/contrib/pyln-client/tests/test_gossmap.py b/contrib/pyln-client/tests/test_gossmap.py index e40262d0574a..6d4291e3c731 100644 --- a/contrib/pyln-client/tests/test_gossmap.py +++ b/contrib/pyln-client/tests/test_gossmap.py @@ -159,6 +159,8 @@ def test_mesh(tmp_path): scids = [scid12, scid14, scid23, scid25, scid36, scid45, scid47, scid56, scid58, scid69, scid78, scid89] + nodes = [*map(lambda nodeid: g.get_node(nodeid), nodeids)] + # check all nodes are there for nodeid in nodeids: node = g.get_node(nodeid) @@ -176,7 +178,25 @@ def test_mesh(tmp_path): assert channel.half_channels[1] # check basic relations - # l5 in the middle has 4 channels to l2, l4, l6 and l8 + # get_nodes l5 in the middle + result = g.get_nodes(source=nodeids[4]) + assert len(result) == 1 + assert str(next(iter(result)).node_id) == nodeids[4] + result = g.get_nodes(source=nodeids[4], depth=1) + assert len(result) == 5 + # on depth=1 the cross l2, l4, l5, l6, l8 must be returned + assert nodes[1] in result + assert nodes[3] in result + assert nodes[4] in result + assert nodes[5] in result + assert nodes[7] in result + # on depth=2 all nodes must be returned + result = g.get_nodes(source=nodeids[4], depth=2) + assert len(result) == 9 + for node in nodes: + assert node in result + + # get_halfchannels l5 in the middle has 4 channels to l2, l4, l6 and l8 result = g.get_halfchannels(source=nodeids[4]) exp_ids = [nodeids[1], nodeids[3], nodeids[5], nodeids[7]] exp_scidds = [scid25 + 'x1', scid45 + 'x0', scid56 + 'x1', scid58 + 'x0']