Skip to content

Commit

Permalink
rpc: populate allowed/not allowed node ids when listing orders
Browse files Browse the repository at this point in the history
  • Loading branch information
positiveblue committed Jan 5, 2023
1 parent 627c9c4 commit 6665f98
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 17 deletions.
61 changes: 47 additions & 14 deletions order/rpc_parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,23 +127,19 @@ func ParseRPCOrder(version, leaseDuration uint32,
"at the same time")
}

kit.AllowedNodeIDs = make([][33]byte, len(details.AllowedNodeIds))
for idx, nodeID := range details.AllowedNodeIds {
if _, err := btcec.ParsePubKey(nodeID); err != nil {
return nil, fmt.Errorf("invalid allowed_node_id: %x",
nodeID)
}
copy(kit.AllowedNodeIDs[idx][:], nodeID)
allowedNodeIDs, err := UnmarshalNodeIDSlice(details.AllowedNodeIds)
if err != nil {
return nil, fmt.Errorf("invalid allowed_node_ids: %v", err)
}
kit.AllowedNodeIDs = allowedNodeIDs

kit.NotAllowedNodeIDs = make([][33]byte, len(details.NotAllowedNodeIds))
for idx, nodeID := range details.NotAllowedNodeIds {
if _, err := btcec.ParsePubKey(nodeID); err != nil {
return nil, fmt.Errorf("invalid not_allowed_node_id: "+
"%x", nodeID)
}
copy(kit.NotAllowedNodeIDs[idx][:], nodeID)
notAllowedNodeIDs, err := UnmarshalNodeIDSlice(
details.NotAllowedNodeIds,
)
if err != nil {
return nil, fmt.Errorf("invalid not_allowed_node_ids: %v", err)
}
kit.NotAllowedNodeIDs = notAllowedNodeIDs

kit.IsPublic = details.IsPublic

Expand Down Expand Up @@ -524,6 +520,43 @@ func ParseRPCSign(signMsg *auctioneerrpc.OrderMatchSignBegin) (AccountNonces,
return nonces, prevOutputs, nil
}

// MarshalNodeIDSlice returns a flatten version of an slice of node ids to be
// used in rpc serialization.
func MarshalNodeIDSlice(nodeIDs [][33]byte) [][]byte {
res := make([][]byte, 0, len(nodeIDs))

for i := range nodeIDs {
nodeID := make([]byte, 33)
copy(nodeID, nodeIDs[i][:])

res = append(res, nodeID)
}

return res
}

// UnmarshalNodeIDSlice returns a slice of node ids from a flatten version.
func UnmarshalNodeIDSlice(slice [][]byte) ([][33]byte, error) {
nodeIDs := make([][33]byte, len(slice))
for idx := range slice {
// Check that the node id pub key is in the correct format.
if len(slice[idx]) != 33 {
return nil, fmt.Errorf("invalid node_id length: %x",
slice[idx])
}

// Check that the node id pub key is a valid key.
if _, err := btcec.ParsePubKey(slice[idx]); err != nil {
return nil, fmt.Errorf("invalid node_id: %x",
slice[idx])
}

copy(nodeIDs[idx][:], slice[idx])
}

return nodeIDs, nil
}

// randomPreimage creates a new preimage from a random number generator.
func randomPreimage() ([]byte, error) {
var nonce Nonce
Expand Down
94 changes: 94 additions & 0 deletions order/rpc_parse_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package order

import (
"encoding/hex"
"testing"

"github.com/stretchr/testify/require"
)

var nodeIDSerializationTestCases = []struct {
name string
nodeIDs func() [][33]byte
invalidSerializedData func() [][]byte
expectedErr string
}{{
name: "empty slice",
nodeIDs: func() [][33]byte {
return [][33]byte{}
},
}, {
name: "single node id",
nodeIDs: func() [][33]byte {
return [][33]byte{
nodePubkey,
}
},
}, {
name: "multiple node ids",
nodeIDs: func() [][33]byte {
nodeID, _ := hex.DecodeString("036b51e0cc2d9e5988ee4967e0ba67" +
"ef3727bb633fea21a0af58e0c9395446ba09")
var nodePubKey2 [33]byte
copy(nodePubKey2[:], nodeID)

return [][33]byte{
nodePubkey,
nodePubKey2,
}
},
}, {
name: "invalid length",
invalidSerializedData: func() [][]byte {
return [][]byte{
{1, 2},
}
},
expectedErr: "invalid node_id length",
}, {
name: "invalid pub key",
invalidSerializedData: func() [][]byte {
return MarshalNodeIDSlice([][33]byte{
{1, 2},
})
},
expectedErr: "invalid node_id:",
}}

// TestNodeIDSliceSerialization tests that we can properly serialize and
// deserialize a slice of node ids.
func TestNodeIDSliceSerialization(t *testing.T) {
for _, tc := range nodeIDSerializationTestCases {
tc := tc

t.Run(tc.name, func(t *testing.T) {
t.Parallel()

switch {
// Marshal and Unmarshal valid node ids.
case tc.nodeIDs != nil:
nodeIDs := tc.nodeIDs()
marshaled := MarshalNodeIDSlice(nodeIDs)
require.Equal(t, len(nodeIDs), len(marshaled))

unmarshaled, err := UnmarshalNodeIDSlice(
marshaled,
)

require.NoError(t, err)
require.Equal(t, tc.nodeIDs(), unmarshaled)

// Unmarshal invalid marshaled node ids.
case tc.invalidSerializedData != nil:
marshaled := tc.invalidSerializedData()

_, err := UnmarshalNodeIDSlice(marshaled)
require.Error(t, err)
require.Contains(t, err.Error(), tc.expectedErr)

default:
require.Fail(t, "invalid test case")
}
})
}
}
15 changes: 12 additions & 3 deletions rpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -1698,6 +1698,14 @@ func (s *rpcServer) ListOrders(ctx context.Context,
}
}

allowedNodeIDs := order.MarshalNodeIDSlice(
dbOrder.Details().AllowedNodeIDs,
)

notAllowedNodeIDs := order.MarshalNodeIDSlice(
dbOrder.Details().NotAllowedNodeIDs,
)

details := &poolrpc.Order{
TraderKey: dbDetails.AcctKey[:],
RateFixed: dbDetails.FixedRate,
Expand All @@ -1721,7 +1729,9 @@ func (s *rpcServer) ListOrders(ctx context.Context,
AuctionType: auctioneerrpc.AuctionType(
dbOrder.Details().AuctionType,
),
IsPublic: dbOrder.Details().IsPublic,
AllowedNodeIds: allowedNodeIDs,
NotAllowedNodeIds: notAllowedNodeIDs,
IsPublic: dbOrder.Details().IsPublic,
}

switch o := dbOrder.(type) {
Expand Down Expand Up @@ -1778,8 +1788,7 @@ func (s *rpcServer) ListOrders(ctx context.Context,
bids = append(bids, rpcBid)

default:
return nil, fmt.Errorf("unknown order type: %v",
o)
return nil, fmt.Errorf("unknown order type: %v", o)
}
}
return &poolrpc.ListOrdersResponse{
Expand Down

0 comments on commit 6665f98

Please sign in to comment.