Skip to content

Commit

Permalink
pygossmap: adds get_halfchannels method
Browse files Browse the repository at this point in the history
  • Loading branch information
m-schmoock committed Feb 17, 2023
1 parent 0ade450 commit 5a2e92a
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 0 deletions.
44 changes: 44 additions & 0 deletions contrib/pyln-client/pyln/client/gossmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,50 @@ def get_halfchannel(self,
channel = self.get_channel(short_channel_id)
return channel.half_channels[direction]

def get_halfchannels(self,
source: Union[GossmapNodeId, str] = None,
destination: Union[GossmapNodeId, str] = None,
depth: int = 0, excludes: set = None):
""" Returns a set[GossmapHalfchannel]` from `source` or towards
`destination` node ID. Using the optional `depth` greater than `0`
will result in a second, third, ... order list of connected
channels towards or from that node.
Note: only one of `source` or `destination` can be given. """
assert (source is None) ^ (destination is None), "Only one of source or destination must be given"
node = self.get_node(source if source else destination)
assert node is not None, "Source or Destination unknown"
assert depth >= 0, "Depth cannot be smaller than 1"
result = set()
if excludes is None:
excludes = set()
for channel in node.channels:
other = channel.node1 if channel.node1 != node else channel.node2
# recurse and merge results if depth not yet 0
if depth > 0:
hc0 = channel.half_channels[0]
hc1 = channel.half_channels[1]
if hc0 is not None:
excludes.add(hc0)
if hc1 is not None:
excludes.add(hc1)
if source is not None:
result.update(self.get_halfchannels(other.node_id, None,
depth - 1, excludes))
else:
result.update(self.get_halfchannels(None, other.node_id,
depth - 1, excludes))
else:
direction = 0
if source is not None and node > other:
direction = 1
if destination is not None and node < other:
direction = 1
hc = channel.half_channels[direction]
if hc is not None and hc not in excludes:
result.add(hc)
result.difference_update(excludes) # we may have added too early
return result

def get_node(self, node_id: Union[GossmapNodeId, str]):
""" Resolves a node by its public key node_id """
if isinstance(node_id, str):
Expand Down
56 changes: 56 additions & 0 deletions contrib/pyln-client/tests/test_gossmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,59 @@ def test_mesh(tmp_path):
assert str(channel.scid) == scid
assert channel.half_channels[0]
assert channel.half_channels[1]

# check basic relations
# l5 in the middle has 4 channels to l2, l4, l6 and l8
result = g.get_halfchannels(source=nodeids[4])
exp_ids = [nodeids[1], nodeids[3], nodeids[5], nodeids[7]]
exp_scidds = [scid25 + 'x1', scid45 + 'x0', scid56 + 'x1', scid58 + 'x0']
assert len(result) == len(exp_ids)
for halfchan in result:
assert str(halfchan.source.node_id) == nodeids[4]
assert str(halfchan.destination.node_id) in exp_ids
assert str(halfchan) in exp_scidds

# same but other direction
result = g.get_halfchannels(destination=nodeids[4])
exp_ids = [nodeids[1], nodeids[3], nodeids[5], nodeids[7]]
exp_scidds = [scid25 + 'x0', scid45 + 'x1', scid56 + 'x0', scid58 + 'x1']
assert len(result) == len(exp_ids)
for halfchan in result:
assert str(halfchan.destination.node_id) == nodeids[4]
assert str(halfchan.source.node_id) in exp_ids
assert str(halfchan) in exp_scidds

# get all channels which have l1 as destination
result = g.get_halfchannels(destination=nodeids[0])
exp_ids = [nodeids[1], nodeids[3]]
exp_scidds = [scid12 + 'x0', scid14 + 'x1']
assert len(result) == len(exp_ids)
for halfchan in result:
assert str(halfchan.destination.node_id) == nodeids[0]
assert str(halfchan.source.node_id) in exp_ids
assert str(halfchan) in exp_scidds

# l5 as destination in the middle but depth 1, so the outer ring
# epxect 12, 14, 32, 36, 74, 78, 98, 96
result = g.get_halfchannels(destination=nodeids[4], depth=1)
exp_scidds = [scid12 + 'x1', scid14 + 'x0', scid23 + 'x1', scid36 + 'x1',
scid47 + 'x0', scid69 + 'x1', scid78 + 'x0', scid89 + 'x0']
assert len(result) == len(exp_scidds)
for halfchan in result:
assert str(halfchan) in exp_scidds

# same but other direction
result = g.get_halfchannels(source=nodeids[4], depth=1)
exp_scidds = [scid12 + 'x0', scid14 + 'x1', scid23 + 'x0', scid36 + 'x0',
scid47 + 'x1', scid69 + 'x0', scid78 + 'x1', scid89 + 'x1']
assert len(result) == len(exp_scidds)
for halfchan in result:
assert str(halfchan) in exp_scidds

# l9 as destination and depth 2
# expect 23 25 45 47
result = g.get_halfchannels(destination=nodeids[8], depth=2)
exp_scidds = [scid23 + 'x0', scid25 + 'x0', scid45 + 'x1', scid47 + 'x1']
assert len(result) == len(exp_scidds)
for halfchan in result:
assert str(halfchan) in exp_scidds

0 comments on commit 5a2e92a

Please sign in to comment.