Skip to content

Commit

Permalink
pygossmap: adds get_nodes flodding method
Browse files Browse the repository at this point in the history
  • Loading branch information
m-schmoock committed Feb 17, 2023
1 parent 5a2e92a commit 74b8395
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 1 deletion.
19 changes: 19 additions & 0 deletions contrib/pyln-client/pyln/client/gossmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,25 @@ def get_node(self, node_id: Union[GossmapNodeId, str]):
node_id = GossmapNodeId.from_str(node_id)
return self.nodes.get(node_id)

def get_nodes(self,
source: Union[GossmapNodeId, str],
depth: int = 0, excludes: set = None):
""" Returns a set of nodes within a given depth from a source node """
node = self.get_node(source)
assert node is not None, f"Unknown source: {source}"
if excludes is None:
excludes = set()
excludes.add(node)
result = set()
result.add(node)
if depth > 0:
for channel in node.channels:
other = channel.node1 if channel.node1 != node else channel.node2
if other in excludes:
continue
result.update(self.get_nodes(other.node_id, depth - 1, excludes))
return result

def _update_channel(self, rec: bytes, off: int):
fields = channel_update.read(io.BytesIO(rec[2:]), {})
direction = fields['channel_flags'] & 1
Expand Down
22 changes: 21 additions & 1 deletion contrib/pyln-client/tests/test_gossmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ def test_mesh(tmp_path):
scids = [scid12, scid14, scid23, scid25, scid36, scid45, scid47, scid56,
scid58, scid69, scid78, scid89]

nodes = [*map(lambda nodeid: g.get_node(nodeid), nodeids)]

# check all nodes are there
for nodeid in nodeids:
node = g.get_node(nodeid)
Expand All @@ -176,7 +178,25 @@ def test_mesh(tmp_path):
assert channel.half_channels[1]

# check basic relations
# l5 in the middle has 4 channels to l2, l4, l6 and l8
# get_nodes l5 in the middle
result = g.get_nodes(source=nodeids[4])
assert len(result) == 1
assert str(next(iter(result)).node_id) == nodeids[4]
result = g.get_nodes(source=nodeids[4], depth=1)
assert len(result) == 5
# on depth=1 the cross l2, l4, l5, l6, l8 must be returned
assert nodes[1] in result
assert nodes[3] in result
assert nodes[4] in result
assert nodes[5] in result
assert nodes[7] in result
# on depth=2 all nodes must be returned
result = g.get_nodes(source=nodeids[4], depth=2)
assert len(result) == 9
for node in nodes:
assert node in result

# get_halfchannels l5 in the middle has 4 channels to l2, l4, l6 and l8
result = g.get_halfchannels(source=nodeids[4])
exp_ids = [nodeids[1], nodeids[3], nodeids[5], nodeids[7]]
exp_scidds = [scid25 + 'x1', scid45 + 'x0', scid56 + 'x1', scid58 + 'x0']
Expand Down

0 comments on commit 74b8395

Please sign in to comment.