From 99ee3a4cbd68b3ad5c464ecb108bd5278d08aa48 Mon Sep 17 00:00:00 2001 From: Michael Schmoock Date: Tue, 14 Feb 2023 22:15:15 +0100 Subject: [PATCH] pygossmap: adds get_halfchannels method --- contrib/pyln-client/pyln/client/gossmap.py | 45 +++++++++++++++++ contrib/pyln-client/tests/test_gossmap.py | 56 ++++++++++++++++++++++ 2 files changed, 101 insertions(+) diff --git a/contrib/pyln-client/pyln/client/gossmap.py b/contrib/pyln-client/pyln/client/gossmap.py index d1aaa060bc2c..8d03f662b2a7 100755 --- a/contrib/pyln-client/pyln/client/gossmap.py +++ b/contrib/pyln-client/pyln/client/gossmap.py @@ -274,6 +274,51 @@ def get_halfchannel(self, channel = self.get_channel(short_channel_id) return channel.half_channels[direction] + def get_halfchannels(self, + source: Union[GossmapNodeId, str] = None, + destination: Union[GossmapNodeId, str] = None, + depth: int = 0, excludes: set = None): + + """ Returns a set[GossmapHalfchannel]` from `source` or towards + `destination` node ID. Using the optional `depth` greater than `0` + will result in a second, third, ... order list of connected + channels towards or from that node. + Note: only one of `source` or `destination` can be given. """ + assert (source is None) ^ (destination is None), "Only one of source or destination must be given" + node = self.get_node(source if source else destination) + assert node is not None, "Source or Destination unknown" + assert depth >= 0, "Depth cannot be smaller than 1" + result = set() + if excludes is None: + excludes = set() + for channel in node.channels: + other = channel.node1 if channel.node1 != node else channel.node2 + # recurse and merge results if depth not yet 0 + if depth > 0: + hc0 = channel.half_channels[0] + hc1 = channel.half_channels[1] + if hc0 is not None: + excludes.add(hc0) + if hc1 is not None: + excludes.add(hc1) + if source is not None: + result.update(self.get_halfchannels(other.node_id, None, + depth - 1, excludes)) + else: + result.update(self.get_halfchannels(None, other.node_id, + depth - 1, excludes)) + else: + direction = 0 + if source is not None and node > other: + direction = 1 + if destination is not None and node < other: + direction = 1 + hc = channel.half_channels[direction] + if hc is not None and hc not in excludes: + result.add(hc) + result.difference_update(excludes) # we may have added too early + return result + def get_node(self, node_id: Union[GossmapNodeId, str]): """ Resolves a node by its public key node_id """ if isinstance(node_id, str): diff --git a/contrib/pyln-client/tests/test_gossmap.py b/contrib/pyln-client/tests/test_gossmap.py index 0daa1b22a512..e40262d0574a 100644 --- a/contrib/pyln-client/tests/test_gossmap.py +++ b/contrib/pyln-client/tests/test_gossmap.py @@ -174,3 +174,59 @@ def test_mesh(tmp_path): assert str(channel.scid) == scid assert channel.half_channels[0] assert channel.half_channels[1] + + # check basic relations + # l5 in the middle has 4 channels to l2, l4, l6 and l8 + result = g.get_halfchannels(source=nodeids[4]) + exp_ids = [nodeids[1], nodeids[3], nodeids[5], nodeids[7]] + exp_scidds = [scid25 + 'x1', scid45 + 'x0', scid56 + 'x1', scid58 + 'x0'] + assert len(result) == len(exp_ids) + for halfchan in result: + assert str(halfchan.source.node_id) == nodeids[4] + assert str(halfchan.destination.node_id) in exp_ids + assert str(halfchan) in exp_scidds + + # same but other direction + result = g.get_halfchannels(destination=nodeids[4]) + exp_ids = [nodeids[1], nodeids[3], nodeids[5], nodeids[7]] + exp_scidds = [scid25 + 'x0', scid45 + 'x1', scid56 + 'x0', scid58 + 'x1'] + assert len(result) == len(exp_ids) + for halfchan in result: + assert str(halfchan.destination.node_id) == nodeids[4] + assert str(halfchan.source.node_id) in exp_ids + assert str(halfchan) in exp_scidds + + # get all channels which have l1 as destination + result = g.get_halfchannels(destination=nodeids[0]) + exp_ids = [nodeids[1], nodeids[3]] + exp_scidds = [scid12 + 'x0', scid14 + 'x1'] + assert len(result) == len(exp_ids) + for halfchan in result: + assert str(halfchan.destination.node_id) == nodeids[0] + assert str(halfchan.source.node_id) in exp_ids + assert str(halfchan) in exp_scidds + + # l5 as destination in the middle but depth 1, so the outer ring + # epxect 12, 14, 32, 36, 74, 78, 98, 96 + result = g.get_halfchannels(destination=nodeids[4], depth=1) + exp_scidds = [scid12 + 'x1', scid14 + 'x0', scid23 + 'x1', scid36 + 'x1', + scid47 + 'x0', scid69 + 'x1', scid78 + 'x0', scid89 + 'x0'] + assert len(result) == len(exp_scidds) + for halfchan in result: + assert str(halfchan) in exp_scidds + + # same but other direction + result = g.get_halfchannels(source=nodeids[4], depth=1) + exp_scidds = [scid12 + 'x0', scid14 + 'x1', scid23 + 'x0', scid36 + 'x0', + scid47 + 'x1', scid69 + 'x0', scid78 + 'x1', scid89 + 'x1'] + assert len(result) == len(exp_scidds) + for halfchan in result: + assert str(halfchan) in exp_scidds + + # l9 as destination and depth 2 + # expect 23 25 45 47 + result = g.get_halfchannels(destination=nodeids[8], depth=2) + exp_scidds = [scid23 + 'x0', scid25 + 'x0', scid45 + 'x1', scid47 + 'x1'] + assert len(result) == len(exp_scidds) + for halfchan in result: + assert str(halfchan) in exp_scidds