diff --git a/docs/release-notes/release-notes-0.17.0.md b/docs/release-notes/release-notes-0.17.0.md index 9040ae74ea..2c04c22030 100644 --- a/docs/release-notes/release-notes-0.17.0.md +++ b/docs/release-notes/release-notes-0.17.0.md @@ -71,10 +71,14 @@ fails](https://github.com/lightningnetwork/lnd/pull/7876). retried](https://github.com/lightningnetwork/lnd/pull/7927) with an exponential back off. + +* In the watchtower client, we [now explicitly + handle](https://github.com/lightningnetwork/lnd/pull/7981) the scenario where + a channel is closed while we still have an in-memory update for it. + * `lnd` [now properly handles a case where an erroneous force close attempt would impeded start up](https://github.com/lightningnetwork/lnd/pull/7985). - # New Features ## Functional Enhancements diff --git a/lnrpc/wtclientrpc/wtclient.go b/lnrpc/wtclientrpc/wtclient.go index 1f45335fb9..228877743a 100644 --- a/lnrpc/wtclientrpc/wtclient.go +++ b/lnrpc/wtclientrpc/wtclient.go @@ -390,6 +390,10 @@ func constructFunctionalOptions(includeSessions, return opts, ackCounts, committedUpdateCounts } + perNumRogueUpdates := func(s *wtdb.ClientSession, numUpdates uint16) { + ackCounts[s.ID] += numUpdates + } + perNumAckedUpdates := func(s *wtdb.ClientSession, id lnwire.ChannelID, numUpdates uint16) { @@ -405,6 +409,7 @@ func constructFunctionalOptions(includeSessions, opts = []wtdb.ClientSessionListOption{ wtdb.WithPerNumAckedUpdates(perNumAckedUpdates), wtdb.WithPerCommittedUpdate(perCommittedUpdate), + wtdb.WithPerRogueUpdateCount(perNumRogueUpdates), } if excludeExhaustedSessions { diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index f3b8d3307b..412412c1e3 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -977,6 +977,19 @@ func (c *TowerClient) handleClosableSessions( // and handle it. c.closableSessionQueue.Pop() + // Stop the session and remove it from the + // in-memory set. + err := c.activeSessions.StopAndRemove( + item.sessionID, + ) + if err != nil { + c.log.Errorf("could not remove "+ + "session(%s) from in-memory "+ + "set: %v", item.sessionID, err) + + return + } + // Fetch the session from the DB so that we can // extract the Tower info. sess, err := c.cfg.DB.GetClientSession( diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index 5cb998c9a1..51dd16e096 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -21,6 +21,7 @@ import ( "github.com/lightningnetwork/lnd/channelnotifier" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" + "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lntest/wait" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" @@ -72,7 +73,7 @@ var ( addrScript, _ = txscript.PayToAddrScript(addr) - waitTime = 5 * time.Second + waitTime = 15 * time.Second defaultTxPolicy = wtpolicy.TxPolicy{ BlobType: blob.TypeAltruistCommit, @@ -398,7 +399,7 @@ type testHarness struct { cfg harnessCfg signer *wtmock.MockSigner capacity lnwire.MilliSatoshi - clientDB *wtmock.ClientDB + clientDB *wtdb.ClientDB clientCfg *wtclient.Config client wtclient.Client server *serverHarness @@ -426,10 +427,26 @@ type harnessCfg struct { noServerStart bool } +func newClientDB(t *testing.T) *wtdb.ClientDB { + dbCfg := &kvdb.BoltConfig{ + DBTimeout: kvdb.DefaultDBTimeout, + } + + // Construct the ClientDB. + dir := t.TempDir() + bdb, err := wtdb.NewBoltBackendCreator(true, dir, "wtclient.db")(dbCfg) + require.NoError(t, err) + + clientDB, err := wtdb.OpenClientDB(bdb) + require.NoError(t, err) + + return clientDB +} + func newHarness(t *testing.T, cfg harnessCfg) *testHarness { signer := wtmock.NewMockSigner() mockNet := newMockNet() - clientDB := wtmock.NewClientDB() + clientDB := newClientDB(t) server := newServerHarness( t, mockNet, towerAddrStr, func(serverCfg *wtserver.Config) { @@ -509,6 +526,7 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { h.startClient() t.Cleanup(func() { require.NoError(t, h.client.Stop()) + require.NoError(t, h.clientDB.Close()) }) h.makeChannel(0, h.cfg.localBalance, h.cfg.remoteBalance) @@ -1342,7 +1360,7 @@ var clientTests = []clientTest{ // Wait for all the updates to be populated in the // server's database. - h.server.waitForUpdates(hints, 3*time.Second) + h.server.waitForUpdates(hints, waitTime) }, }, { @@ -2053,7 +2071,7 @@ var clientTests = []clientTest{ // Now stop the client and reset its database. require.NoError(h.t, h.client.Stop()) - db := wtmock.NewClientDB() + db := newClientDB(h.t) h.clientDB = db h.clientCfg.DB = db @@ -2398,6 +2416,140 @@ var clientTests = []clientTest{ server2.waitForUpdates(hints[numUpdates/2:], waitTime) }, }, + { + // This test shows that if a channel is closed while an update + // for that channel still exists in an in-memory queue + // somewhere then it is handled correctly by treating it as a + // rogue update. + name: "channel closed while update is un-acked", + cfg: harnessCfg{ + localBalance: localBalance, + remoteBalance: remoteBalance, + policy: wtpolicy.Policy{ + TxPolicy: defaultTxPolicy, + MaxUpdates: 5, + }, + }, + fn: func(h *testHarness) { + const ( + numUpdates = 10 + chanIDInt = 0 + ) + + h.sendUpdatesOn = true + + // Advance the channel with a few updates. + hints := h.advanceChannelN(chanIDInt, numUpdates) + + // Backup a few these updates and wait for them to + // arrive at the server. Note that we back up enough + // updates to saturate the session so that the session + // is considered closable when the channel is deleted. + h.backupStates(chanIDInt, 0, numUpdates/2, nil) + h.server.waitForUpdates(hints[:numUpdates/2], waitTime) + + // Now, restart the server in a state where it will not + // ack updates. This will allow us to wait for an + // update to be un-acked and persisted. + h.server.restart(func(cfg *wtserver.Config) { + cfg.NoAckUpdates = true + }) + + // Backup a few more of the update. These should remain + // in the client as un-acked. + h.backupStates( + chanIDInt, numUpdates/2, numUpdates-1, nil, + ) + + // Wait for the tasks to be bound to sessions. + fetchSessions := h.clientDB.FetchSessionCommittedUpdates + err := wait.Predicate(func() bool { + sessions, err := h.clientDB.ListClientSessions( + nil, + ) + require.NoError(h.t, err) + + var updates []wtdb.CommittedUpdate + for id := range sessions { + updates, err = fetchSessions(&id) + require.NoError(h.t, err) + + if len(updates) != numUpdates-1 { + return true + } + } + + return false + }, waitTime) + require.NoError(h.t, err) + + // Now we close this channel while the update for it has + // not yet been acked. + h.closeChannel(chanIDInt, 1) + + // Closable sessions should now be one. + err = wait.Predicate(func() bool { + cs, err := h.clientDB.ListClosableSessions() + require.NoError(h.t, err) + + return len(cs) == 1 + }, waitTime) + require.NoError(h.t, err) + + // Now, restart the server and allow it to ack updates + // again. + h.server.restart(func(cfg *wtserver.Config) { + cfg.NoAckUpdates = false + }) + + // Mine a few blocks so that the session close range is + // surpassed. + h.mine(3) + + // Wait for there to be no more closable sessions on the + // client side. + err = wait.Predicate(func() bool { + cs, err := h.clientDB.ListClosableSessions() + require.NoError(h.t, err) + + return len(cs) == 0 + }, waitTime) + require.NoError(h.t, err) + + // Wait for channel to be "unregistered". + chanID := chanIDFromInt(chanIDInt) + err = wait.Predicate(func() bool { + err := h.client.BackupState(&chanID, 0) + + return errors.Is( + err, wtclient.ErrUnregisteredChannel, + ) + }, waitTime) + require.NoError(h.t, err) + + // Show that the committed update for the closed channel + // is cleared from the DB. + err = wait.Predicate(func() bool { + sessions, err := h.clientDB.ListClientSessions( + nil, + ) + require.NoError(h.t, err) + + var updates []wtdb.CommittedUpdate + for id := range sessions { + updates, err = fetchSessions(&id) + require.NoError(h.t, err) + + if len(updates) != 0 { + return false + } + } + + return true + }, waitTime) + require.NoError(h.t, err) + }, + }, } // TestClient executes the client test suite, asserting the ability to backup diff --git a/watchtower/wtclient/queue_test.go b/watchtower/wtclient/queue_test.go index 81f96bb7f6..529acb49a4 100644 --- a/watchtower/wtclient/queue_test.go +++ b/watchtower/wtclient/queue_test.go @@ -18,51 +18,13 @@ const ( waitTime = time.Second * 2 ) -type initQueue func(t *testing.T) wtdb.Queue[*wtdb.BackupID] - // TestDiskOverflowQueue tests that the DiskOverflowQueue behaves as expected. func TestDiskOverflowQueue(t *testing.T) { t.Parallel() - dbs := []struct { - name string - init initQueue - }{ - { - name: "kvdb", - init: func(t *testing.T) wtdb.Queue[*wtdb.BackupID] { - dbCfg := &kvdb.BoltConfig{ - DBTimeout: kvdb.DefaultDBTimeout, - } - - bdb, err := wtdb.NewBoltBackendCreator( - true, t.TempDir(), "wtclient.db", - )(dbCfg) - require.NoError(t, err) - - db, err := wtdb.OpenClientDB(bdb) - require.NoError(t, err) - - t.Cleanup(func() { - db.Close() - }) - - return db.GetDBQueue([]byte("test-namespace")) - }, - }, - { - name: "mock", - init: func(t *testing.T) wtdb.Queue[*wtdb.BackupID] { - db := wtmock.NewClientDB() - - return db.GetDBQueue([]byte("test-namespace")) - }, - }, - } - tests := []struct { name string - run func(*testing.T, initQueue) + run func(*testing.T, wtdb.Queue[*wtdb.BackupID]) }{ { name: "overflow to disk", @@ -78,29 +40,43 @@ func TestDiskOverflowQueue(t *testing.T) { }, } - for _, database := range dbs { - db := database - t.Run(db.name, func(t *testing.T) { - t.Parallel() + initDB := func() wtdb.Queue[*wtdb.BackupID] { + dbCfg := &kvdb.BoltConfig{ + DBTimeout: kvdb.DefaultDBTimeout, + } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - test.run(t, db.init) - }) - } + bdb, err := wtdb.NewBoltBackendCreator( + true, t.TempDir(), "wtclient.db", + )(dbCfg) + require.NoError(t, err) + + db, err := wtdb.OpenClientDB(bdb) + require.NoError(t, err) + + t.Cleanup(func() { + require.NoError(t, db.Close()) + }) + + return db.GetDBQueue([]byte("test-namespace")) + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + + test.run(tt, initDB()) }) } } // testOverflowToDisk is a basic test that ensures that the queue correctly // overflows items to disk and then correctly reloads them. -func testOverflowToDisk(t *testing.T, initQueue initQueue) { +func testOverflowToDisk(t *testing.T, db wtdb.Queue[*wtdb.BackupID]) { // Generate some backup IDs that we want to add to the queue. tasks := genBackupIDs(10) - // Init the DB. - db := initQueue(t) - // New mock logger. log := newMockLogger(t.Logf) @@ -146,7 +122,9 @@ func testOverflowToDisk(t *testing.T, initQueue initQueue) { // testRestartWithSmallerBufferSize tests that if the queue is restarted with // a smaller in-memory buffer size that it was initially started with, then // tasks are still loaded in the correct order. -func testRestartWithSmallerBufferSize(t *testing.T, newQueue initQueue) { +func testRestartWithSmallerBufferSize(t *testing.T, + db wtdb.Queue[*wtdb.BackupID]) { + const ( firstMaxInMemItems = 5 secondMaxInMemItems = 2 @@ -155,9 +133,6 @@ func testRestartWithSmallerBufferSize(t *testing.T, newQueue initQueue) { // Generate some backup IDs that we want to add to the queue. tasks := genBackupIDs(10) - // Create a db. - db := newQueue(t) - // New mock logger. log := newMockLogger(t.Logf) @@ -223,14 +198,11 @@ func testRestartWithSmallerBufferSize(t *testing.T, newQueue initQueue) { // testStartStopQueue is a stress test that pushes a large number of tasks // through the queue while also restarting the queue a couple of times // throughout. -func testStartStopQueue(t *testing.T, newQueue initQueue) { +func testStartStopQueue(t *testing.T, db wtdb.Queue[*wtdb.BackupID]) { // Generate a lot of backup IDs that we want to add to the // queue one after the other. tasks := genBackupIDs(200_000) - // Construct the ClientDB. - db := newQueue(t) - // New mock logger. log := newMockLogger(t.Logf) diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index 41d80587d7..084f2dcfe0 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -50,6 +50,7 @@ var ( // => cSessionDBID -> db-assigned-id // => cSessionCommits => seqnum -> encoded CommittedUpdate // => cSessionAckRangeIndex => db-chan-id => start -> end + // => cSessionRogueUpdateCount -> count cSessionBkt = []byte("client-session-bucket") // cSessionDBID is a key used in the cSessionBkt to store the @@ -68,6 +69,12 @@ var ( // chan-id => start -> end cSessionAckRangeIndex = []byte("client-session-ack-range-index") + // cSessionRogueUpdateCount is a key in the cSessionBkt bucket storing + // the number of rogue updates that were backed up using the session. + // Rogue updates are updates for channels that have been closed already + // at the time of the back-up. + cSessionRogueUpdateCount = []byte("client-session-rogue-update-count") + // cChanIDIndexBkt is a top-level bucket storing: // db-assigned-id -> channel-ID cChanIDIndexBkt = []byte("client-channel-id-index") @@ -980,29 +987,8 @@ func getRangesReadBucket(tx kvdb.RTx, sID SessionID, chanID lnwire.ChannelID) ( // getRangesWriteBucket gets the range index bucket where the range index for // the given session-channel pair is stored. If any sub-buckets along the way do // not exist, then they are created. -func getRangesWriteBucket(tx kvdb.RwTx, sID SessionID, - chanID lnwire.ChannelID) (kvdb.RwBucket, error) { - - sessions := tx.ReadWriteBucket(cSessionBkt) - if sessions == nil { - return nil, ErrUninitializedDB - } - - chanDetailsBkt := tx.ReadBucket(cChanDetailsBkt) - if chanDetailsBkt == nil { - return nil, ErrUninitializedDB - } - - sessionBkt, err := sessions.CreateBucketIfNotExists(sID[:]) - if err != nil { - return nil, err - } - - // Get the DB representation of the channel-ID. - _, dbChanIDBytes, err := getDBChanID(chanDetailsBkt, chanID) - if err != nil { - return nil, err - } +func getRangesWriteBucket(sessionBkt kvdb.RwBucket, dbChanIDBytes []byte) ( + kvdb.RwBucket, error) { sessionAckRanges, err := sessionBkt.CreateBucketIfNotExists( cSessionAckRangeIndex, @@ -1263,10 +1249,23 @@ func (c *ClientDB) NumAckedUpdates(id *SessionID) (uint64, error) { } sessionBkt := sessions.NestedReadBucket(id[:]) - if sessionsBkt == nil { + if sessionBkt == nil { return nil } + // First, account for any rogue updates. + rogueCountBytes := sessionBkt.Get(cSessionRogueUpdateCount) + if len(rogueCountBytes) != 0 { + rogueCount, err := readBigSize(rogueCountBytes) + if err != nil { + return err + } + + numAcked += rogueCount + } + + // Then, check if the session-ack-ranges contains any entries + // to account for. sessionAckRanges := sessionBkt.NestedReadBucket( cSessionAckRangeIndex, ) @@ -1546,14 +1545,37 @@ func (c *ClientDB) DeleteSession(id SessionID) error { return err } - // Get the acked updates range index for the session. This is - // used to get the list of channels that the session has updates - // for. ackRanges := sessionBkt.NestedReadBucket(cSessionAckRangeIndex) + + // There is a small chance that the session only contains rogue + // updates. In that case, there will be no ack-ranges index but + // the rogue update count will be equal the MaxUpdates. + rogueCountBytes := sessionBkt.Get(cSessionRogueUpdateCount) + if len(rogueCountBytes) != 0 { + rogueCount, err := readBigSize(rogueCountBytes) + if err != nil { + return err + } + + maxUpdates := sess.ClientSessionBody.Policy.MaxUpdates + if rogueCount == uint64(maxUpdates) { + // Do a sanity check to ensure that the acked + // ranges bucket does not exist in this case. + if ackRanges != nil { + return fmt.Errorf("acked updates "+ + "exist for session with a "+ + "max-updates(%d) rogue count", + rogueCount) + } + + return sessionsBkt.DeleteNestedBucket(id[:]) + } + } + + // A session would only be considered closable if it was + // exhausted. Meaning that it should not be the case that it has + // no acked-updates. if ackRanges == nil { - // A session would only be considered closable if it - // was exhausted. Meaning that it should not be the - // case that it has no acked-updates. return fmt.Errorf("cannot delete session %s since it "+ "is not yet exhausted", id) } @@ -1784,6 +1806,22 @@ func isSessionClosable(sessionsBkt, chanDetailsBkt, chanIDIndexBkt kvdb.RBucket, return false, nil } + // Either the acked-update bucket should exist _or_ the rogue update + // count must be equal to the session's MaxUpdates value, otherwise + // something is wrong because the above check ensures that the session + // has been exhausted. + rogueCountBytes := sessBkt.Get(cSessionRogueUpdateCount) + if len(rogueCountBytes) != 0 { + rogueCount, err := readBigSize(rogueCountBytes) + if err != nil { + return false, err + } + + if rogueCount == uint64(session.Policy.MaxUpdates) { + return true, nil + } + } + // If the session has no acked-updates, then something is wrong since // the above check ensures that this session has been exhausted meaning // that it should have MaxUpdates acked updates. @@ -2026,18 +2064,92 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16, return err } + dbSessionID, dbSessIDBytes, err := getDBSessionID(sessions, *id) + if err != nil { + return err + } + chanID := committedUpdate.BackupID.ChanID height := committedUpdate.BackupID.CommitHeight - // Get the ranges write bucket before getting the range index to - // ensure that the session acks sub-bucket is initialized, so - // that we can insert an entry. - rangesBkt, err := getRangesWriteBucket(tx, *id, chanID) - if err != nil { + // Get the DB representation of the channel-ID. There is a + // chance that the channel corresponding to this update has been + // closed and that the details for this channel no longer exist + // in the tower client DB. In that case, we consider this a + // rogue update and all we do is make sure to keep track of the + // number of rogue updates for this session. + _, dbChanIDBytes, err := getDBChanID(chanDetailsBkt, chanID) + if errors.Is(err, ErrChannelNotRegistered) { + var ( + count uint64 + err error + ) + + rogueCountBytes := sessionBkt.Get( + cSessionRogueUpdateCount, + ) + if len(rogueCountBytes) != 0 { + count, err = readBigSize(rogueCountBytes) + if err != nil { + return err + } + } + + rogueCount := count + 1 + countBytes, err := writeBigSize(rogueCount) + if err != nil { + return err + } + + err = sessionBkt.Put( + cSessionRogueUpdateCount, countBytes, + ) + if err != nil { + return err + } + + // In the rare chance that this session only has rogue + // updates, we check here if the count is equal to the + // MaxUpdate of the session. If it is, then we mark the + // session as closable. + if rogueCount != uint64(session.Policy.MaxUpdates) { + return nil + } + + // Before we mark the session as closable, we do a + // sanity check to ensure that this session has no + // acked-update index. + sessionAckRanges := sessionBkt.NestedReadBucket( + cSessionAckRangeIndex, + ) + if sessionAckRanges != nil { + return fmt.Errorf("session(%s) has an "+ + "acked ranges index but has a rogue "+ + "count indicating saturation", + session.ID) + } + + closableSessBkt := tx.ReadWriteBucket( + cClosableSessionsBkt, + ) + if closableSessBkt == nil { + return ErrUninitializedDB + } + + var height [4]byte + byteOrder.PutUint32(height[:], 0) + + return closableSessBkt.Put(dbSessIDBytes, height[:]) + } else if err != nil { return err } - dbSessionID, _, err := getDBSessionID(sessions, *id) + // Get the ranges write bucket before getting the range index to + // ensure that the session acks sub-bucket is initialized, so + // that we can insert an entry. + rangesBkt, err := getRangesWriteBucket( + sessionBkt, dbChanIDBytes, + ) if err != nil { return err } @@ -2173,6 +2285,11 @@ type PerMaxHeightCB func(*ClientSession, lnwire.ChannelID, uint64) // number of updates that the session has for the channel. type PerNumAckedUpdatesCB func(*ClientSession, lnwire.ChannelID, uint16) +// PerRogueUpdateCountCB describes the signature of a callback function that can +// be called for each session with the number of rogue updates that the session +// has. +type PerRogueUpdateCountCB func(*ClientSession, uint16) + // PerAckedUpdateCB describes the signature of a callback function that can be // called for each of a session's acked updates. type PerAckedUpdateCB func(*ClientSession, uint16, BackupID) @@ -2195,6 +2312,10 @@ type ClientSessionListCfg struct { // channel. PerNumAckedUpdates PerNumAckedUpdatesCB + // PerRogueUpdateCount will, if set, be called with the number of rogue + // updates that the session has backed up. + PerRogueUpdateCount PerRogueUpdateCountCB + // PerMaxHeight will, if set, be called for each of the session's // channels to communicate the highest commit height of updates stored // for that channel. @@ -2242,6 +2363,15 @@ func WithPerNumAckedUpdates(cb PerNumAckedUpdatesCB) ClientSessionListOption { } } +// WithPerRogueUpdateCount constructs a functional option that will set a +// call-back function to be called with the number of rogue updates that the +// session has backed up. +func WithPerRogueUpdateCount(cb PerRogueUpdateCountCB) ClientSessionListOption { + return func(cfg *ClientSessionListCfg) { + cfg.PerRogueUpdateCount = cb + } +} + // WithPerCommittedUpdate constructs a functional option that will set a // call-back function to be called for each of a client's un-acked updates. func WithPerCommittedUpdate(cb PerCommittedUpdateCB) ClientSessionListOption { @@ -2310,7 +2440,7 @@ func (c *ClientDB) getClientSession(sessionsBkt, chanIDIndexBkt kvdb.RBucket, // provided. err = c.filterClientSessionAcks( sessionBkt, chanIDIndexBkt, session, cfg.PerMaxHeight, - cfg.PerNumAckedUpdates, + cfg.PerNumAckedUpdates, cfg.PerRogueUpdateCount, ) if err != nil { return nil, err @@ -2368,7 +2498,24 @@ func getClientSessionCommits(sessionBkt kvdb.RBucket, s *ClientSession, // call back if one is provided. func (c *ClientDB) filterClientSessionAcks(sessionBkt, chanIDIndexBkt kvdb.RBucket, s *ClientSession, perMaxCb PerMaxHeightCB, - perNumAckedUpdates PerNumAckedUpdatesCB) error { + perNumAckedUpdates PerNumAckedUpdatesCB, + perRogueUpdateCount PerRogueUpdateCountCB) error { + + if perRogueUpdateCount != nil { + var ( + count uint64 + err error + ) + rogueCountBytes := sessionBkt.Get(cSessionRogueUpdateCount) + if len(rogueCountBytes) != 0 { + count, err = readBigSize(rogueCountBytes) + if err != nil { + return err + } + } + + perRogueUpdateCount(s, uint16(count)) + } if perMaxCb == nil && perNumAckedUpdates == nil { return nil diff --git a/watchtower/wtdb/client_db_test.go b/watchtower/wtdb/client_db_test.go index 5bfb4dab5d..6d11a69728 100644 --- a/watchtower/wtdb/client_db_test.go +++ b/watchtower/wtdb/client_db_test.go @@ -13,7 +13,6 @@ import ( "github.com/lightningnetwork/lnd/watchtower/blob" "github.com/lightningnetwork/lnd/watchtower/wtclient" "github.com/lightningnetwork/lnd/watchtower/wtdb" - "github.com/lightningnetwork/lnd/watchtower/wtmock" "github.com/lightningnetwork/lnd/watchtower/wtpolicy" "github.com/stretchr/testify/require" ) @@ -676,6 +675,98 @@ func testCommitUpdate(h *clientDBHarness) { h.assertUpdates(session.ID, []wtdb.CommittedUpdate{}, nil) } +// testRogueUpdates asserts that rogue updates (updates for channels that are +// backed up after the channel has been closed and the channel details deleted +// from the DB) are handled correctly. +func testRogueUpdates(h *clientDBHarness) { + const maxUpdates = 5 + + tower := h.newTower() + + // Create and insert a new session. + session1 := h.randSession(h.t, tower.ID, maxUpdates) + h.insertSession(session1, nil) + + // Create a new channel and register it. + chanID1 := randChannelID(h.t) + h.registerChan(chanID1, nil, nil) + + // Num acked updates should be 0. + require.Zero(h.t, h.numAcked(&session1.ID, nil)) + + // Commit and ACK enough updates for this channel to fill the session. + for i := 1; i <= maxUpdates; i++ { + update := randCommittedUpdateForChanWithHeight( + h.t, chanID1, uint16(i), uint64(i), + ) + lastApplied := h.commitUpdate(&session1.ID, update, nil) + h.ackUpdate(&session1.ID, uint16(i), lastApplied, nil) + } + + // Num acked updates should now be 5. + require.EqualValues(h.t, maxUpdates, h.numAcked(&session1.ID, nil)) + + // Commit one more update for the channel but this time do not ACK it. + // This update will be put in a new session since the previous one has + // been exhausted. + session2 := h.randSession(h.t, tower.ID, maxUpdates) + sess2Seq := 1 + h.insertSession(session2, nil) + update := randCommittedUpdateForChanWithHeight( + h.t, chanID1, uint16(sess2Seq), uint64(maxUpdates+1), + ) + lastApplied := h.commitUpdate(&session2.ID, update, nil) + + // Session 2 should not have any acked updates yet. + require.Zero(h.t, h.numAcked(&session2.ID, nil)) + + // There should currently be no closable sessions. + require.Empty(h.t, h.listClosableSessions(nil)) + + // Now mark the channel as closed. + h.markChannelClosed(chanID1, 1, nil) + + // Assert that session 1 is now seen as closable. + closableSessionsMap := h.listClosableSessions(nil) + require.Len(h.t, closableSessionsMap, 1) + _, ok := closableSessionsMap[session1.ID] + require.True(h.t, ok) + + // Delete session 1. + h.deleteSession(session1.ID, nil) + + // Now try to ACK the update for the channel. This should succeed and + // the update should be considered a rogue update. + h.ackUpdate(&session2.ID, uint16(sess2Seq), lastApplied, nil) + + // Show that the number of acked updates is now 1. + require.EqualValues(h.t, 1, h.numAcked(&session2.ID, nil)) + + // We also want to test the extreme case where all the updates for a + // particular session are rogue updates. In this case, the session + // should be seen as closable if it is saturated. + + // First show that the session is not yet considered closable. + require.Empty(h.t, h.listClosableSessions(nil)) + + // Then, let's continue adding rogue updates for the closed channel to + // session 2. + for i := maxUpdates + 2; i <= maxUpdates*2; i++ { + sess2Seq++ + + update := randCommittedUpdateForChanWithHeight( + h.t, chanID1, uint16(sess2Seq), uint64(i), + ) + lastApplied := h.commitUpdate(&session2.ID, update, nil) + h.ackUpdate(&session2.ID, uint16(sess2Seq), lastApplied, nil) + } + + // At this point, session 2 is saturated with rogue updates. Assert that + // it is now closable. + closableSessionsMap = h.listClosableSessions(nil) + require.Len(h.t, closableSessionsMap, 1) +} + // testMarkChannelClosed asserts the behaviour of MarkChannelClosed. func testMarkChannelClosed(h *clientDBHarness) { tower := h.newTower() @@ -763,7 +854,7 @@ func testMarkChannelClosed(h *clientDBHarness) { require.EqualValues(h.t, 4, lastApplied) h.ackUpdate(&session1.ID, 5, 5, nil) - // The session is no exhausted. + // The session is now exhausted. // If we now close channel 5, session 1 should still not be closable // since it has an update for channel 6 which is still open. sl = h.markChannelClosed(chanID5, 1, nil) @@ -964,12 +1055,6 @@ func TestClientDB(t *testing.T) { return db }, }, - { - name: "mock", - init: func(t *testing.T) wtclient.DB { - return wtmock.NewClientDB() - }, - }, } tests := []struct { @@ -1008,6 +1093,10 @@ func TestClientDB(t *testing.T) { name: "mark channel closed", run: testMarkChannelClosed, }, + { + name: "rogue updates", + run: testRogueUpdates, + }, } for _, database := range dbs { @@ -1073,6 +1162,34 @@ func randCommittedUpdateForChannel(t *testing.T, chanID lnwire.ChannelID, } } +// randCommittedUpdateForChanWithHeight generates a random committed update for +// the given channel ID using the given commit height. +func randCommittedUpdateForChanWithHeight(t *testing.T, chanID lnwire.ChannelID, + seqNum uint16, height uint64) *wtdb.CommittedUpdate { + + t.Helper() + + var hint blob.BreachHint + _, err := io.ReadFull(crand.Reader, hint[:]) + require.NoError(t, err) + + encBlob := make([]byte, blob.Size(blob.FlagCommitOutputs.Type())) + _, err = io.ReadFull(crand.Reader, encBlob) + require.NoError(t, err) + + return &wtdb.CommittedUpdate{ + SeqNum: seqNum, + CommittedUpdateBody: wtdb.CommittedUpdateBody{ + BackupID: wtdb.BackupID{ + ChanID: chanID, + CommitHeight: height, + }, + Hint: hint, + EncryptedBlob: encBlob, + }, + } +} + func (h *clientDBHarness) randSession(t *testing.T, towerID wtdb.TowerID, maxUpdates uint16) *wtdb.ClientSession { diff --git a/watchtower/wtdb/queue_test.go b/watchtower/wtdb/queue_test.go index a864125cf7..02c7b272cb 100644 --- a/watchtower/wtdb/queue_test.go +++ b/watchtower/wtdb/queue_test.go @@ -4,9 +4,7 @@ import ( "testing" "github.com/lightningnetwork/lnd/kvdb" - "github.com/lightningnetwork/lnd/watchtower/wtclient" "github.com/lightningnetwork/lnd/watchtower/wtdb" - "github.com/lightningnetwork/lnd/watchtower/wtmock" "github.com/stretchr/testify/require" ) @@ -15,53 +13,24 @@ import ( func TestDiskQueue(t *testing.T) { t.Parallel() - dbs := []struct { - name string - init clientDBInit - }{ - { - name: "bbolt", - init: func(t *testing.T) wtclient.DB { - dbCfg := &kvdb.BoltConfig{ - DBTimeout: kvdb.DefaultDBTimeout, - } - - // Construct the ClientDB. - bdb, err := wtdb.NewBoltBackendCreator( - true, t.TempDir(), "wtclient.db", - )(dbCfg) - require.NoError(t, err) - - db, err := wtdb.OpenClientDB(bdb) - require.NoError(t, err) - - t.Cleanup(func() { - err = db.Close() - require.NoError(t, err) - }) - - return db - }, - }, - { - name: "mock", - init: func(t *testing.T) wtclient.DB { - return wtmock.NewClientDB() - }, - }, + dbCfg := &kvdb.BoltConfig{ + DBTimeout: kvdb.DefaultDBTimeout, } - for _, database := range dbs { - db := database - t.Run(db.name, func(t *testing.T) { - t.Parallel() + // Construct the ClientDB. + bdb, err := wtdb.NewBoltBackendCreator( + true, t.TempDir(), "wtclient.db", + )(dbCfg) + require.NoError(t, err) - testQueue(t, db.init(t)) - }) - } -} + db, err := wtdb.OpenClientDB(bdb) + require.NoError(t, err) + + t.Cleanup(func() { + err = db.Close() + require.NoError(t, err) + }) -func testQueue(t *testing.T, db wtclient.DB) { namespace := []byte("test-namespace") queue := db.GetDBQueue(namespace) diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go deleted file mode 100644 index f5625d35b3..0000000000 --- a/watchtower/wtmock/client_db.go +++ /dev/null @@ -1,887 +0,0 @@ -package wtmock - -import ( - "encoding/binary" - "net" - "sync" - "sync/atomic" - - "github.com/btcsuite/btcd/btcec/v2" - "github.com/lightningnetwork/lnd/lnwire" - "github.com/lightningnetwork/lnd/watchtower/blob" - "github.com/lightningnetwork/lnd/watchtower/wtdb" -) - -var byteOrder = binary.BigEndian - -type towerPK [33]byte - -type keyIndexKey struct { - towerID wtdb.TowerID - blobType blob.Type -} - -type rangeIndexArrayMap map[wtdb.SessionID]map[lnwire.ChannelID]*wtdb.RangeIndex - -type rangeIndexKVStore map[wtdb.SessionID]map[lnwire.ChannelID]*mockKVStore - -type channel struct { - summary *wtdb.ClientChanSummary - closedHeight uint32 - sessions map[wtdb.SessionID]bool -} - -// ClientDB is a mock, in-memory database or testing the watchtower client -// behavior. -type ClientDB struct { - nextTowerID uint64 // to be used atomically - - mu sync.Mutex - channels map[lnwire.ChannelID]*channel - activeSessions map[wtdb.SessionID]wtdb.ClientSession - ackedUpdates rangeIndexArrayMap - persistedAckedUpdates rangeIndexKVStore - committedUpdates map[wtdb.SessionID][]wtdb.CommittedUpdate - towerIndex map[towerPK]wtdb.TowerID - towers map[wtdb.TowerID]*wtdb.Tower - closableSessions map[wtdb.SessionID]uint32 - - nextIndex uint32 - indexes map[keyIndexKey]uint32 - legacyIndexes map[wtdb.TowerID]uint32 - - queues map[string]wtdb.Queue[*wtdb.BackupID] -} - -// NewClientDB initializes a new mock ClientDB. -func NewClientDB() *ClientDB { - return &ClientDB{ - channels: make(map[lnwire.ChannelID]*channel), - activeSessions: make( - map[wtdb.SessionID]wtdb.ClientSession, - ), - ackedUpdates: make(rangeIndexArrayMap), - persistedAckedUpdates: make(rangeIndexKVStore), - committedUpdates: make( - map[wtdb.SessionID][]wtdb.CommittedUpdate, - ), - towerIndex: make(map[towerPK]wtdb.TowerID), - towers: make(map[wtdb.TowerID]*wtdb.Tower), - indexes: make(map[keyIndexKey]uint32), - legacyIndexes: make(map[wtdb.TowerID]uint32), - closableSessions: make(map[wtdb.SessionID]uint32), - queues: make(map[string]wtdb.Queue[*wtdb.BackupID]), - } -} - -// CreateTower initialize an address record used to communicate with a -// watchtower. Each Tower is assigned a unique ID, that is used to amortize -// storage costs of the public key when used by multiple sessions. If the tower -// already exists, the address is appended to the list of all addresses used to -// that tower previously and its corresponding sessions are marked as active. -func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) { - m.mu.Lock() - defer m.mu.Unlock() - - var towerPubKey towerPK - copy(towerPubKey[:], lnAddr.IdentityKey.SerializeCompressed()) - - var tower *wtdb.Tower - towerID, ok := m.towerIndex[towerPubKey] - if ok { - tower = m.towers[towerID] - tower.AddAddress(lnAddr.Address) - - towerSessions, err := m.listClientSessions(&towerID) - if err != nil { - return nil, err - } - for id, session := range towerSessions { - session.Status = wtdb.CSessionActive - m.activeSessions[id] = *session - } - } else { - towerID = wtdb.TowerID(atomic.AddUint64(&m.nextTowerID, 1)) - tower = &wtdb.Tower{ - ID: towerID, - IdentityKey: lnAddr.IdentityKey, - Addresses: []net.Addr{lnAddr.Address}, - } - } - - m.towerIndex[towerPubKey] = towerID - m.towers[towerID] = tower - - return copyTower(tower), nil -} - -// RemoveTower modifies a tower's record within the database. If an address is -// provided, then _only_ the address record should be removed from the tower's -// persisted state. Otherwise, we'll attempt to mark the tower as inactive by -// marking all of its sessions inactive. If any of its sessions has unacked -// updates, then ErrTowerUnackedUpdates is returned. If the tower doesn't have -// any sessions at all, it'll be completely removed from the database. -// -// NOTE: An error is not returned if the tower doesn't exist. -func (m *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error { - m.mu.Lock() - defer m.mu.Unlock() - - tower, err := m.loadTower(pubKey) - if err == wtdb.ErrTowerNotFound { - return nil - } - if err != nil { - return err - } - - if addr != nil { - tower.RemoveAddress(addr) - if len(tower.Addresses) == 0 { - return wtdb.ErrLastTowerAddr - } - m.towers[tower.ID] = tower - return nil - } - - towerSessions, err := m.listClientSessions(&tower.ID) - if err != nil { - return err - } - if len(towerSessions) == 0 { - var towerPK towerPK - copy(towerPK[:], pubKey.SerializeCompressed()) - delete(m.towerIndex, towerPK) - delete(m.towers, tower.ID) - return nil - } - - for id, session := range towerSessions { - if len(m.committedUpdates[session.ID]) > 0 { - return wtdb.ErrTowerUnackedUpdates - } - session.Status = wtdb.CSessionInactive - m.activeSessions[id] = *session - } - - return nil -} - -// LoadTower retrieves a tower by its public key. -func (m *ClientDB) LoadTower(pubKey *btcec.PublicKey) (*wtdb.Tower, error) { - m.mu.Lock() - defer m.mu.Unlock() - return m.loadTower(pubKey) -} - -// loadTower retrieves a tower by its public key. -// -// NOTE: This method requires the database's lock to be acquired. -func (m *ClientDB) loadTower(pubKey *btcec.PublicKey) (*wtdb.Tower, error) { - var towerPK towerPK - copy(towerPK[:], pubKey.SerializeCompressed()) - - towerID, ok := m.towerIndex[towerPK] - if !ok { - return nil, wtdb.ErrTowerNotFound - } - tower, ok := m.towers[towerID] - if !ok { - return nil, wtdb.ErrTowerNotFound - } - - return copyTower(tower), nil -} - -// LoadTowerByID retrieves a tower by its tower ID. -func (m *ClientDB) LoadTowerByID(towerID wtdb.TowerID) (*wtdb.Tower, error) { - m.mu.Lock() - defer m.mu.Unlock() - - if tower, ok := m.towers[towerID]; ok { - return copyTower(tower), nil - } - - return nil, wtdb.ErrTowerNotFound -} - -// ListTowers retrieves the list of towers available within the database. -func (m *ClientDB) ListTowers() ([]*wtdb.Tower, error) { - m.mu.Lock() - defer m.mu.Unlock() - - towers := make([]*wtdb.Tower, 0, len(m.towers)) - for _, tower := range m.towers { - towers = append(towers, copyTower(tower)) - } - - return towers, nil -} - -// MarkBackupIneligible records that particular commit height is ineligible for -// backup. This allows the client to track which updates it should not attempt -// to retry after startup. -func (m *ClientDB) MarkBackupIneligible(_ lnwire.ChannelID, _ uint64) error { - return nil -} - -// ListClientSessions returns the set of all client sessions known to the db. An -// optional tower ID can be used to filter out any client sessions in the -// response that do not correspond to this tower. -func (m *ClientDB) ListClientSessions(tower *wtdb.TowerID, - opts ...wtdb.ClientSessionListOption) ( - map[wtdb.SessionID]*wtdb.ClientSession, error) { - - m.mu.Lock() - defer m.mu.Unlock() - - return m.listClientSessions(tower, opts...) -} - -// listClientSessions returns the set of all client sessions known to the db. An -// optional tower ID can be used to filter out any client sessions in the -// response that do not correspond to this tower. -func (m *ClientDB) listClientSessions(tower *wtdb.TowerID, - opts ...wtdb.ClientSessionListOption) ( - map[wtdb.SessionID]*wtdb.ClientSession, error) { - - cfg := wtdb.NewClientSessionCfg() - for _, o := range opts { - o(cfg) - } - - sessions := make(map[wtdb.SessionID]*wtdb.ClientSession) - for _, session := range m.activeSessions { - session := session - if tower != nil && *tower != session.TowerID { - continue - } - - if cfg.PreEvaluateFilterFn != nil && - !cfg.PreEvaluateFilterFn(&session) { - - continue - } - - if cfg.PerMaxHeight != nil { - for chanID, index := range m.ackedUpdates[session.ID] { - cfg.PerMaxHeight( - &session, chanID, index.MaxHeight(), - ) - } - } - - if cfg.PerNumAckedUpdates != nil { - for chanID, index := range m.ackedUpdates[session.ID] { - cfg.PerNumAckedUpdates( - &session, chanID, - uint16(index.NumInSet()), - ) - } - } - - if cfg.PerCommittedUpdate != nil { - for _, update := range m.committedUpdates[session.ID] { - update := update - cfg.PerCommittedUpdate(&session, &update) - } - } - - if cfg.PostEvaluateFilterFn != nil && - !cfg.PostEvaluateFilterFn(&session) { - - continue - } - - sessions[session.ID] = &session - } - - return sessions, nil -} - -// FetchSessionCommittedUpdates retrieves the current set of un-acked updates -// of the given session. -func (m *ClientDB) FetchSessionCommittedUpdates(id *wtdb.SessionID) ( - []wtdb.CommittedUpdate, error) { - - m.mu.Lock() - defer m.mu.Unlock() - - updates, ok := m.committedUpdates[*id] - if !ok { - return nil, wtdb.ErrClientSessionNotFound - } - - return updates, nil -} - -// IsAcked returns true if the given backup has been backed up using the given -// session. -func (m *ClientDB) IsAcked(id *wtdb.SessionID, backupID *wtdb.BackupID) (bool, - error) { - - m.mu.Lock() - defer m.mu.Unlock() - - index, ok := m.ackedUpdates[*id][backupID.ChanID] - if !ok { - return false, nil - } - - return index.IsInIndex(backupID.CommitHeight), nil -} - -// NumAckedUpdates returns the number of backups that have been successfully -// backed up using the given session. -func (m *ClientDB) NumAckedUpdates(id *wtdb.SessionID) (uint64, error) { - m.mu.Lock() - defer m.mu.Unlock() - - var numAcked uint64 - - for _, index := range m.ackedUpdates[*id] { - numAcked += index.NumInSet() - } - - return numAcked, nil -} - -// CreateClientSession records a newly negotiated client session in the set of -// active sessions. The session can be identified by its SessionID. -func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error { - m.mu.Lock() - defer m.mu.Unlock() - - // Ensure that we aren't overwriting an existing session. - if _, ok := m.activeSessions[session.ID]; ok { - return wtdb.ErrClientSessionAlreadyExists - } - - key := keyIndexKey{ - towerID: session.TowerID, - blobType: session.Policy.BlobType, - } - - // Ensure that a session key index has been reserved for this tower. - keyIndex, err := m.getSessionKeyIndex(key) - if err != nil { - return err - } - - // Ensure that the session's index matches the reserved index. - if keyIndex != session.KeyIndex { - return wtdb.ErrIncorrectKeyIndex - } - - // Remove the key index reservation for this tower. Once committed, this - // permits us to create another session with this tower. - delete(m.indexes, key) - if key.blobType == blob.TypeAltruistCommit { - delete(m.legacyIndexes, key.towerID) - } - - m.activeSessions[session.ID] = wtdb.ClientSession{ - ID: session.ID, - ClientSessionBody: wtdb.ClientSessionBody{ - SeqNum: session.SeqNum, - TowerLastApplied: session.TowerLastApplied, - TowerID: session.TowerID, - KeyIndex: session.KeyIndex, - Policy: session.Policy, - RewardPkScript: cloneBytes(session.RewardPkScript), - }, - } - m.ackedUpdates[session.ID] = make(map[lnwire.ChannelID]*wtdb.RangeIndex) - m.persistedAckedUpdates[session.ID] = make( - map[lnwire.ChannelID]*mockKVStore, - ) - m.committedUpdates[session.ID] = make([]wtdb.CommittedUpdate, 0) - - return nil -} - -// NextSessionKeyIndex reserves a new session key derivation index for a -// particular tower id. The index is reserved for that tower until -// CreateClientSession is invoked for that tower and index, at which point a new -// index for that tower can be reserved. Multiple calls to this method before -// CreateClientSession is invoked should return the same index unless forceNext -// is set to true. -func (m *ClientDB) NextSessionKeyIndex(towerID wtdb.TowerID, blobType blob.Type, - forceNext bool) (uint32, error) { - - m.mu.Lock() - defer m.mu.Unlock() - - key := keyIndexKey{ - towerID: towerID, - blobType: blobType, - } - - if !forceNext { - if index, err := m.getSessionKeyIndex(key); err == nil { - return index, nil - } - } - - // By default, we use the next available bucket sequence as the key - // index. But if forceNext is true, then it is assumed that some data - // loss occurred and so the sequence is incremented a by a jump of 1000 - // so that we can arrive at a brand new key index quicker. - nextIndex := m.nextIndex + 1 - if forceNext { - nextIndex = m.nextIndex + 1000 - } - m.nextIndex = nextIndex - m.indexes[key] = nextIndex - - return nextIndex, nil -} - -func (m *ClientDB) getSessionKeyIndex(key keyIndexKey) (uint32, error) { - if index, ok := m.indexes[key]; ok { - return index, nil - } - - if key.blobType == blob.TypeAltruistCommit { - if index, ok := m.legacyIndexes[key.towerID]; ok { - return index, nil - } - } - - return 0, wtdb.ErrNoReservedKeyIndex -} - -// CommitUpdate persists the CommittedUpdate provided in the slot for (session, -// seqNum). This allows the client to retransmit this update on startup. -func (m *ClientDB) CommitUpdate(id *wtdb.SessionID, - update *wtdb.CommittedUpdate) (uint16, error) { - - m.mu.Lock() - defer m.mu.Unlock() - - // Fail if session doesn't exist. - session, ok := m.activeSessions[*id] - if !ok { - return 0, wtdb.ErrClientSessionNotFound - } - - // Check if an update has already been committed for this state. - for _, dbUpdate := range m.committedUpdates[session.ID] { - if dbUpdate.SeqNum == update.SeqNum { - // If the breach hint matches, we'll just return the - // last applied value so the client can retransmit. - if dbUpdate.Hint == update.Hint { - return session.TowerLastApplied, nil - } - - // Otherwise, fail since the breach hint doesn't match. - return 0, wtdb.ErrUpdateAlreadyCommitted - } - } - - // Sequence number must increment. - if update.SeqNum != session.SeqNum+1 { - return 0, wtdb.ErrCommitUnorderedUpdate - } - - // Save the update and increment the sequence number. - m.committedUpdates[session.ID] = append( - m.committedUpdates[session.ID], *update, - ) - session.SeqNum++ - m.activeSessions[*id] = session - - return session.TowerLastApplied, nil -} - -// AckUpdate persists an acknowledgment for a given (session, seqnum) pair. This -// removes the update from the set of committed updates, and validates the -// lastApplied value returned from the tower. -func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, - lastApplied uint16) error { - - m.mu.Lock() - defer m.mu.Unlock() - - // Fail if session doesn't exist. - session, ok := m.activeSessions[*id] - if !ok { - return wtdb.ErrClientSessionNotFound - } - - // Ensure the returned last applied value does not exceed the highest - // allocated sequence number. - if lastApplied > session.SeqNum { - return wtdb.ErrUnallocatedLastApplied - } - - // Ensure the last applied value isn't lower than a previous one sent by - // the tower. - if lastApplied < session.TowerLastApplied { - return wtdb.ErrLastAppliedReversion - } - - // Retrieve the committed update, failing if none is found. We should - // only receive acks for state updates that we send. - updates := m.committedUpdates[session.ID] - for i, update := range updates { - if update.SeqNum != seqNum { - continue - } - - // Add sessionID to channel. - channel, ok := m.channels[update.BackupID.ChanID] - if !ok { - return wtdb.ErrChannelNotRegistered - } - channel.sessions[*id] = true - - // Remove the committed update from disk and mark the update as - // acked. The tower last applied value is also recorded to send - // along with the next update. - copy(updates[:i], updates[i+1:]) - updates[len(updates)-1] = wtdb.CommittedUpdate{} - m.committedUpdates[session.ID] = updates[:len(updates)-1] - - chanID := update.BackupID.ChanID - if _, ok := m.ackedUpdates[*id][update.BackupID.ChanID]; !ok { - index, err := wtdb.NewRangeIndex(nil) - if err != nil { - return err - } - - m.ackedUpdates[*id][chanID] = index - m.persistedAckedUpdates[*id][chanID] = newMockKVStore() - } - - err := m.ackedUpdates[*id][chanID].Add( - update.BackupID.CommitHeight, - m.persistedAckedUpdates[*id][chanID], - ) - if err != nil { - return err - } - - session.TowerLastApplied = lastApplied - - m.activeSessions[*id] = session - return nil - } - - return wtdb.ErrCommittedUpdateNotFound -} - -// GetDBQueue returns a BackupID Queue instance under the given name space. -func (m *ClientDB) GetDBQueue(namespace []byte) wtdb.Queue[*wtdb.BackupID] { - m.mu.Lock() - defer m.mu.Unlock() - - if q, ok := m.queues[string(namespace)]; ok { - return q - } - - q := NewQueueDB[*wtdb.BackupID]() - m.queues[string(namespace)] = q - - return q -} - -// DeleteCommittedUpdate deletes the committed update with the given sequence -// number from the given session. -func (m *ClientDB) DeleteCommittedUpdate(id *wtdb.SessionID, - seqNum uint16) error { - - m.mu.Lock() - defer m.mu.Unlock() - - // Fail if session doesn't exist. - session, ok := m.activeSessions[*id] - if !ok { - return wtdb.ErrClientSessionNotFound - } - - // Retrieve the committed update, failing if none is found. - updates := m.committedUpdates[session.ID] - for i, update := range updates { - if update.SeqNum != seqNum { - continue - } - - // Remove the committed update from "disk". - updates = append(updates[:i], updates[i+1:]...) - m.committedUpdates[session.ID] = updates - - return nil - } - - return wtdb.ErrCommittedUpdateNotFound -} - -// ListClosableSessions fetches and returns the IDs for all sessions marked as -// closable. -func (m *ClientDB) ListClosableSessions() (map[wtdb.SessionID]uint32, error) { - m.mu.Lock() - defer m.mu.Unlock() - - cs := make(map[wtdb.SessionID]uint32, len(m.closableSessions)) - for id, height := range m.closableSessions { - cs[id] = height - } - - return cs, nil -} - -// FetchChanSummaries loads a mapping from all registered channels to their -// channel summaries. Only the channels that have not yet been marked as closed -// will be loaded. -func (m *ClientDB) FetchChanSummaries() (wtdb.ChannelSummaries, error) { - m.mu.Lock() - defer m.mu.Unlock() - - summaries := make(map[lnwire.ChannelID]wtdb.ClientChanSummary) - for chanID, channel := range m.channels { - // Don't load the channel if it has been marked as closed. - if channel.closedHeight > 0 { - continue - } - - summaries[chanID] = wtdb.ClientChanSummary{ - SweepPkScript: cloneBytes( - channel.summary.SweepPkScript, - ), - } - } - - return summaries, nil -} - -// MarkChannelClosed will mark a registered channel as closed by setting -// its closed-height as the given block height. It returns a list of -// session IDs for sessions that are now considered closable due to the -// close of this channel. -func (m *ClientDB) MarkChannelClosed(chanID lnwire.ChannelID, - blockHeight uint32) ([]wtdb.SessionID, error) { - - m.mu.Lock() - defer m.mu.Unlock() - - channel, ok := m.channels[chanID] - if !ok { - return nil, wtdb.ErrChannelNotRegistered - } - - // If there are no sessions for this channel, the channel details can be - // deleted. - if len(channel.sessions) == 0 { - delete(m.channels, chanID) - return nil, nil - } - - // Mark the channel as closed. - channel.closedHeight = blockHeight - - // Now iterate through all the sessions of the channel to check if any - // of them are closeable. - var closableSessions []wtdb.SessionID - for sessID := range channel.sessions { - isClosable, err := m.isSessionClosable(sessID) - if err != nil { - return nil, err - } - - if !isClosable { - continue - } - - closableSessions = append(closableSessions, sessID) - - // Add session to "closableSessions" list and add the block - // height that this last channel was closed in. This will be - // used in future to determine when we should delete the - // session. - m.closableSessions[sessID] = blockHeight - } - - return closableSessions, nil -} - -// isSessionClosable returns true if a session is considered closable. A session -// is considered closable only if: -// 1) It has no un-acked updates -// 2) It is exhausted (ie it cant accept any more updates) -// 3) All the channels that it has acked-updates for are closed. -func (m *ClientDB) isSessionClosable(id wtdb.SessionID) (bool, error) { - // The session is not closable if it has un-acked updates. - if len(m.committedUpdates[id]) > 0 { - return false, nil - } - - sess, ok := m.activeSessions[id] - if !ok { - return false, wtdb.ErrClientSessionNotFound - } - - // The session is not closable if it is not yet exhausted. - if sess.SeqNum != sess.Policy.MaxUpdates { - return false, nil - } - - // Iterate over each of the channels that the session has acked-updates - // for. If any of those channels are not closed, then the session is - // not yet closable. - for chanID := range m.ackedUpdates[id] { - channel, ok := m.channels[chanID] - if !ok { - continue - } - - // Channel is not yet closed, and so we can not yet delete the - // session. - if channel.closedHeight == 0 { - return false, nil - } - } - - return true, nil -} - -// GetClientSession loads the ClientSession with the given ID from the DB. -func (m *ClientDB) GetClientSession(id wtdb.SessionID, - opts ...wtdb.ClientSessionListOption) (*wtdb.ClientSession, error) { - - cfg := wtdb.NewClientSessionCfg() - for _, o := range opts { - o(cfg) - } - - session, ok := m.activeSessions[id] - if !ok { - return nil, wtdb.ErrClientSessionNotFound - } - - if cfg.PerMaxHeight != nil { - for chanID, index := range m.ackedUpdates[session.ID] { - cfg.PerMaxHeight(&session, chanID, index.MaxHeight()) - } - } - - if cfg.PerCommittedUpdate != nil { - for _, update := range m.committedUpdates[session.ID] { - update := update - cfg.PerCommittedUpdate(&session, &update) - } - } - - return &session, nil -} - -// DeleteSession can be called when a session should be deleted from the DB. -// All references to the session will also be deleted from the DB. Note that a -// session will only be deleted if it is considered closable. -func (m *ClientDB) DeleteSession(id wtdb.SessionID) error { - m.mu.Lock() - defer m.mu.Unlock() - - _, ok := m.closableSessions[id] - if !ok { - return wtdb.ErrSessionNotClosable - } - - // For each of the channels, delete the session ID entry. - for chanID := range m.ackedUpdates[id] { - c, ok := m.channels[chanID] - if !ok { - return wtdb.ErrChannelNotRegistered - } - - delete(c.sessions, id) - } - - delete(m.closableSessions, id) - delete(m.activeSessions, id) - - return nil -} - -// RegisterChannel registers a channel for use within the client database. For -// now, all that is stored in the channel summary is the sweep pkscript that -// we'd like any tower sweeps to pay into. In the future, this will be extended -// to contain more info to allow the client efficiently request historical -// states to be backed up under the client's active policy. -func (m *ClientDB) RegisterChannel(chanID lnwire.ChannelID, - sweepPkScript []byte) error { - - m.mu.Lock() - defer m.mu.Unlock() - - if _, ok := m.channels[chanID]; ok { - return wtdb.ErrChannelAlreadyRegistered - } - - m.channels[chanID] = &channel{ - summary: &wtdb.ClientChanSummary{ - SweepPkScript: cloneBytes(sweepPkScript), - }, - sessions: make(map[wtdb.SessionID]bool), - } - - return nil -} - -func cloneBytes(b []byte) []byte { - if b == nil { - return nil - } - - bb := make([]byte, len(b)) - copy(bb, b) - - return bb -} - -func copyTower(tower *wtdb.Tower) *wtdb.Tower { - t := &wtdb.Tower{ - ID: tower.ID, - IdentityKey: tower.IdentityKey, - Addresses: make([]net.Addr, len(tower.Addresses)), - } - copy(t.Addresses, tower.Addresses) - - return t -} - -type mockKVStore struct { - kv map[uint64]uint64 - - err error -} - -func newMockKVStore() *mockKVStore { - return &mockKVStore{ - kv: make(map[uint64]uint64), - } -} - -func (m *mockKVStore) Put(key, value []byte) error { - if m.err != nil { - return m.err - } - - k := byteOrder.Uint64(key) - v := byteOrder.Uint64(value) - - m.kv[k] = v - - return nil -} - -func (m *mockKVStore) Delete(key []byte) error { - if m.err != nil { - return m.err - } - - k := byteOrder.Uint64(key) - delete(m.kv, k) - - return nil -}