From a842915ba216d5eeac20b31eada51767ccd5536b Mon Sep 17 00:00:00 2001 From: sstone Date: Mon, 6 Jan 2025 14:27:47 +0100 Subject: [PATCH] Use specific segwit and taproot input info types We now use specific subtypes for segwit inputs (which include a redeem script) and taproot inputs (which include a script tree and an internal key). Older codecs have been modified to always return a SegwitInput. v4 codec is modified and uses an empty redeem script as a marker to specify that a script tree is being used, which makes it compatible with the current v4 codec. --- .../fr/acinq/eclair/channel/Commitments.scala | 5 +- .../fr/acinq/eclair/channel/Helpers.scala | 4 +- .../channel/publish/ReplaceableTxFunder.scala | 4 +- .../eclair/transactions/Transactions.scala | 140 +++++++++++------- .../channel/version0/ChannelCodecs0.scala | 8 +- .../channel/version1/ChannelCodecs1.scala | 8 +- .../channel/version2/ChannelCodecs2.scala | 8 +- .../channel/version3/ChannelCodecs3.scala | 8 +- .../channel/version4/ChannelCodecs4.scala | 22 ++- .../publish/ReplaceableTxFunderSpec.scala | 2 +- .../eclair/transactions/TestVectorsSpec.scala | 8 +- .../channel/version4/ChannelCodecs4Spec.scala | 5 + 12 files changed, 130 insertions(+), 92 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/channel/Commitments.scala b/eclair-core/src/main/scala/fr/acinq/eclair/channel/Commitments.scala index 36e5b0dfac..8329a3534d 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/channel/Commitments.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/channel/Commitments.scala @@ -1142,7 +1142,10 @@ case class Commitments(params: ChannelParams, val localFundingKey = keyManager.fundingPublicKey(params.localParams.fundingKeyPath, commitment.fundingTxIndex).publicKey val remoteFundingKey = commitment.remoteFundingPubKey val fundingScript = Script.write(Scripts.multiSig2of2(localFundingKey, remoteFundingKey)) - commitment.commitInput.redeemScriptOrScriptTree == Left(fundingScript) + commitment.commitInput match { + case InputInfo.SegwitInput(_, _, redeemScript) => redeemScript == fundingScript + case _ => false + } } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/channel/Helpers.scala b/eclair-core/src/main/scala/fr/acinq/eclair/channel/Helpers.scala index 12b3b2cd90..a99f9af743 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/channel/Helpers.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/channel/Helpers.scala @@ -378,10 +378,10 @@ object Helpers { def makeFundingPubKeyScript(localFundingKey: PublicKey, remoteFundingKey: PublicKey): ByteVector = write(pay2wsh(multiSig2of2(localFundingKey, remoteFundingKey))) - def makeFundingInputInfo(fundingTxId: TxId, fundingTxOutputIndex: Int, fundingSatoshis: Satoshi, fundingPubkey1: PublicKey, fundingPubkey2: PublicKey): InputInfo = { + def makeFundingInputInfo(fundingTxId: TxId, fundingTxOutputIndex: Int, fundingSatoshis: Satoshi, fundingPubkey1: PublicKey, fundingPubkey2: PublicKey): InputInfo.SegwitInput = { val fundingScript = multiSig2of2(fundingPubkey1, fundingPubkey2) val fundingTxOut = TxOut(fundingSatoshis, pay2wsh(fundingScript)) - InputInfo(OutPoint(fundingTxId, fundingTxOutputIndex), fundingTxOut, write(fundingScript)) + InputInfo.SegwitInput(OutPoint(fundingTxId, fundingTxOutputIndex), fundingTxOut, write(fundingScript)) } /** diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/channel/publish/ReplaceableTxFunder.scala b/eclair-core/src/main/scala/fr/acinq/eclair/channel/publish/ReplaceableTxFunder.scala index e9d80acf2b..48a411ff09 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/channel/publish/ReplaceableTxFunder.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/channel/publish/ReplaceableTxFunder.scala @@ -358,8 +358,8 @@ private class ReplaceableTxFunder(nodeParams: NodeParams, import fr.acinq.bitcoin.scalacompat.KotlinUtils._ // We create a PSBT with the non-wallet input already signed: - val witnessScript = locallySignedTx.txInfo.input.redeemScriptOrScriptTree match { - case Left(redeemScript) => fr.acinq.bitcoin.Script.parse(redeemScript) + val witnessScript = locallySignedTx.txInfo.input match { + case InputInfo.SegwitInput(_, _, redeemScript) => fr.acinq.bitcoin.Script.parse(redeemScript) case _ => null } val psbt = new Psbt(locallySignedTx.txInfo.tx) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/transactions/Transactions.scala b/eclair-core/src/main/scala/fr/acinq/eclair/transactions/Transactions.scala index cc4b0b2960..48e423d3fc 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/transactions/Transactions.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/transactions/Transactions.scala @@ -26,6 +26,7 @@ import fr.acinq.eclair._ import fr.acinq.eclair.blockchain.fee.{ConfirmationTarget, FeeratePerKw} import fr.acinq.eclair.transactions.CommitmentOutput._ import fr.acinq.eclair.transactions.Scripts._ +import fr.acinq.eclair.transactions.Transactions.InputInfo.SegwitInput import fr.acinq.eclair.wire.protocol.UpdateAddHtlc import scodec.bits.ByteVector @@ -102,14 +103,18 @@ object Transactions { val publicKeyScript: ByteVector = Script.write(Script.pay2tr(internalKey, Some(scriptTree))) } - case class InputInfo(outPoint: OutPoint, txOut: TxOut, redeemScriptOrScriptTree: Either[ByteVector, ScriptTreeAndInternalKey]) { - val redeemScriptOrEmptyScript: ByteVector = redeemScriptOrScriptTree.swap.getOrElse(ByteVector.empty) // TODO: use the actual script tree for taproot transactions, once we implement them + sealed trait InputInfo { + val outPoint: OutPoint + val txOut: TxOut } object InputInfo { - def apply(outPoint: OutPoint, txOut: TxOut, redeemScript: ByteVector) = new InputInfo(outPoint, txOut, Left(redeemScript)) - def apply(outPoint: OutPoint, txOut: TxOut, redeemScript: Seq[ScriptElt]) = new InputInfo(outPoint, txOut, Left(Script.write(redeemScript))) - def apply(outPoint: OutPoint, txOut: TxOut, scriptTree: ScriptTreeAndInternalKey) = new InputInfo(outPoint, txOut, Right(scriptTree)) + case class SegwitInput(outPoint: OutPoint, txOut: TxOut, redeemScript: ByteVector) extends InputInfo + case class TaprootInput(outPoint: OutPoint, txOut: TxOut, scriptTreeAndInternalKey: ScriptTreeAndInternalKey) extends InputInfo + + def apply(outPoint: OutPoint, txOut: TxOut, redeemScript: ByteVector): SegwitInput = SegwitInput(outPoint, txOut, redeemScript) + def apply(outPoint: OutPoint, txOut: TxOut, redeemScript: Seq[ScriptElt]): SegwitInput = SegwitInput(outPoint, txOut, Script.write(redeemScript)) + def apply(outPoint: OutPoint, txOut: TxOut, scriptTree: ScriptTreeAndInternalKey): TaprootInput = TaprootInput(outPoint, txOut, scriptTree) } /** Owner of a given transaction (local/remote). */ @@ -138,24 +143,29 @@ object Transactions { sign(key, sighash(txOwner, commitmentFormat)) } - def sign(key: PrivateKey, sighashType: Int): ByteVector64 = { - // NB: the tx may have multiple inputs, we will only sign the one provided in txinfo.input. Bear in mind that the - // signature will be invalidated if other inputs are added *afterwards* and sighashType was SIGHASH_ALL. - val inputIndex = tx.txIn.indexWhere(_.outPoint == input.outPoint) - val sigDER = Transaction.signInput(tx, inputIndex, input.redeemScriptOrEmptyScript, sighashType, input.txOut.amount, SIGVERSION_WITNESS_V0, key) - val sig64 = Crypto.der2compact(sigDER) - sig64 + def sign(key: PrivateKey, sighashType: Int): ByteVector64 = input match { + case _:InputInfo.TaprootInput => ByteVector64.Zeroes + case InputInfo.SegwitInput(outPoint, txOut, redeemScript) => + // NB: the tx may have multiple inputs, we will only sign the one provided in txinfo.input. Bear in mind that the + // signature will be invalidated if other inputs are added *afterwards* and sighashType was SIGHASH_ALL. + val inputIndex = tx.txIn.indexWhere(_.outPoint == outPoint) + val sigDER = Transaction.signInput(tx, inputIndex, redeemScript, sighashType, txOut.amount, SIGVERSION_WITNESS_V0, key) + val sig64 = Crypto.der2compact(sigDER) + sig64 } - def checkSig(sig: ByteVector64, pubKey: PublicKey, txOwner: TxOwner, commitmentFormat: CommitmentFormat): Boolean = { - val sighash = this.sighash(txOwner, commitmentFormat) - val inputIndex = tx.txIn.indexWhere(_.outPoint == input.outPoint) - if (inputIndex >= 0) { - val data = Transaction.hashForSigning(tx, inputIndex, input.redeemScriptOrEmptyScript, sighash, input.txOut.amount, SIGVERSION_WITNESS_V0) - Crypto.verifySignature(data, sig, pubKey) - } else { - false - } + def checkSig(sig: ByteVector64, pubKey: PublicKey, txOwner: TxOwner, commitmentFormat: CommitmentFormat): Boolean = input match { + + case _:InputInfo.TaprootInput => false + case InputInfo.SegwitInput(outPoint, txOut, redeemScript) => + val sighash = this.sighash(txOwner, commitmentFormat) + val inputIndex = tx.txIn.indexWhere(_.outPoint == outPoint) + if (inputIndex >= 0) { + val data = Transaction.hashForSigning(tx, inputIndex, redeemScript, sighash, txOut.amount, SIGVERSION_WITNESS_V0) + Crypto.verifySignature(data, sig, pubKey) + } else { + false + } } } @@ -983,34 +993,46 @@ object Transactions { commitTx.copy(tx = commitTx.tx.updateWitness(0, witness)) } - def addSigs(mainPenaltyTx: MainPenaltyTx, revocationSig: ByteVector64): MainPenaltyTx = { - val witness = Scripts.witnessToLocalDelayedWithRevocationSig(revocationSig, mainPenaltyTx.input.redeemScriptOrEmptyScript) - mainPenaltyTx.copy(tx = mainPenaltyTx.tx.updateWitness(0, witness)) + def addSigs(mainPenaltyTx: MainPenaltyTx, revocationSig: ByteVector64): MainPenaltyTx = mainPenaltyTx.input match { + case InputInfo.SegwitInput(_, _, redeemScript) => + val witness = Scripts.witnessToLocalDelayedWithRevocationSig(revocationSig, redeemScript) + mainPenaltyTx.copy(tx = mainPenaltyTx.tx.updateWitness(0, witness)) + case _ => mainPenaltyTx } - def addSigs(htlcPenaltyTx: HtlcPenaltyTx, revocationSig: ByteVector64, revocationPubkey: PublicKey): HtlcPenaltyTx = { - val witness = Scripts.witnessHtlcWithRevocationSig(revocationSig, revocationPubkey, htlcPenaltyTx.input.redeemScriptOrEmptyScript) - htlcPenaltyTx.copy(tx = htlcPenaltyTx.tx.updateWitness(0, witness)) + def addSigs(htlcPenaltyTx: HtlcPenaltyTx, revocationSig: ByteVector64, revocationPubkey: PublicKey): HtlcPenaltyTx = htlcPenaltyTx.input match { + case InputInfo.SegwitInput(_, _, redeemScript) => + val witness = Scripts.witnessHtlcWithRevocationSig(revocationSig, revocationPubkey, redeemScript) + htlcPenaltyTx.copy(tx = htlcPenaltyTx.tx.updateWitness(0, witness)) + case _ => htlcPenaltyTx } - def addSigs(htlcSuccessTx: HtlcSuccessTx, localSig: ByteVector64, remoteSig: ByteVector64, paymentPreimage: ByteVector32, commitmentFormat: CommitmentFormat): HtlcSuccessTx = { - val witness = witnessHtlcSuccess(localSig, remoteSig, paymentPreimage, htlcSuccessTx.input.redeemScriptOrEmptyScript, commitmentFormat) - htlcSuccessTx.copy(tx = htlcSuccessTx.tx.updateWitness(0, witness)) + def addSigs(htlcSuccessTx: HtlcSuccessTx, localSig: ByteVector64, remoteSig: ByteVector64, paymentPreimage: ByteVector32, commitmentFormat: CommitmentFormat): HtlcSuccessTx = htlcSuccessTx.input match { + case InputInfo.SegwitInput(_, _, redeemScript) => + val witness = witnessHtlcSuccess(localSig, remoteSig, paymentPreimage, redeemScript, commitmentFormat) + htlcSuccessTx.copy(tx = htlcSuccessTx.tx.updateWitness(0, witness)) + case _ => htlcSuccessTx } - def addSigs(htlcTimeoutTx: HtlcTimeoutTx, localSig: ByteVector64, remoteSig: ByteVector64, commitmentFormat: CommitmentFormat): HtlcTimeoutTx = { - val witness = witnessHtlcTimeout(localSig, remoteSig, htlcTimeoutTx.input.redeemScriptOrEmptyScript, commitmentFormat) - htlcTimeoutTx.copy(tx = htlcTimeoutTx.tx.updateWitness(0, witness)) + def addSigs(htlcTimeoutTx: HtlcTimeoutTx, localSig: ByteVector64, remoteSig: ByteVector64, commitmentFormat: CommitmentFormat): HtlcTimeoutTx = htlcTimeoutTx.input match { + case InputInfo.SegwitInput(_, _, redeemScript) => + val witness = witnessHtlcTimeout(localSig, remoteSig, redeemScript, commitmentFormat) + htlcTimeoutTx.copy(tx = htlcTimeoutTx.tx.updateWitness(0, witness)) + case _ => htlcTimeoutTx } - def addSigs(claimHtlcSuccessTx: ClaimHtlcSuccessTx, localSig: ByteVector64, paymentPreimage: ByteVector32): ClaimHtlcSuccessTx = { - val witness = witnessClaimHtlcSuccessFromCommitTx(localSig, paymentPreimage, claimHtlcSuccessTx.input.redeemScriptOrEmptyScript) - claimHtlcSuccessTx.copy(tx = claimHtlcSuccessTx.tx.updateWitness(0, witness)) + def addSigs(claimHtlcSuccessTx: ClaimHtlcSuccessTx, localSig: ByteVector64, paymentPreimage: ByteVector32): ClaimHtlcSuccessTx = claimHtlcSuccessTx.input match { + case InputInfo.SegwitInput(_, _, redeemScript) => + val witness = witnessClaimHtlcSuccessFromCommitTx(localSig, paymentPreimage, redeemScript) + claimHtlcSuccessTx.copy(tx = claimHtlcSuccessTx.tx.updateWitness(0, witness)) + case _ => claimHtlcSuccessTx } - def addSigs(claimHtlcTimeoutTx: ClaimHtlcTimeoutTx, localSig: ByteVector64): ClaimHtlcTimeoutTx = { - val witness = witnessClaimHtlcTimeoutFromCommitTx(localSig, claimHtlcTimeoutTx.input.redeemScriptOrEmptyScript) - claimHtlcTimeoutTx.copy(tx = claimHtlcTimeoutTx.tx.updateWitness(0, witness)) + def addSigs(claimHtlcTimeoutTx: ClaimHtlcTimeoutTx, localSig: ByteVector64): ClaimHtlcTimeoutTx = claimHtlcTimeoutTx.input match { + case InputInfo.SegwitInput(_, _, redeemScript) => + val witness = witnessClaimHtlcTimeoutFromCommitTx(localSig, redeemScript) + claimHtlcTimeoutTx.copy(tx = claimHtlcTimeoutTx.tx.updateWitness(0, witness)) + case _ => claimHtlcTimeoutTx } def addSigs(claimP2WPKHOutputTx: ClaimP2WPKHOutputTx, localPaymentPubkey: PublicKey, localSig: ByteVector64): ClaimP2WPKHOutputTx = { @@ -1018,29 +1040,39 @@ object Transactions { claimP2WPKHOutputTx.copy(tx = claimP2WPKHOutputTx.tx.updateWitness(0, witness)) } - def addSigs(claimRemoteDelayedOutputTx: ClaimRemoteDelayedOutputTx, localSig: ByteVector64): ClaimRemoteDelayedOutputTx = { - val witness = witnessClaimToRemoteDelayedFromCommitTx(localSig, claimRemoteDelayedOutputTx.input.redeemScriptOrEmptyScript) - claimRemoteDelayedOutputTx.copy(tx = claimRemoteDelayedOutputTx.tx.updateWitness(0, witness)) + def addSigs(claimRemoteDelayedOutputTx: ClaimRemoteDelayedOutputTx, localSig: ByteVector64): ClaimRemoteDelayedOutputTx = claimRemoteDelayedOutputTx.input match { + case InputInfo.SegwitInput(_, _, redeemScript) => + val witness = witnessClaimToRemoteDelayedFromCommitTx(localSig, redeemScript) + claimRemoteDelayedOutputTx.copy(tx = claimRemoteDelayedOutputTx.tx.updateWitness(0, witness)) + case _ => claimRemoteDelayedOutputTx } - def addSigs(claimDelayedOutputTx: ClaimLocalDelayedOutputTx, localSig: ByteVector64): ClaimLocalDelayedOutputTx = { - val witness = witnessToLocalDelayedAfterDelay(localSig, claimDelayedOutputTx.input.redeemScriptOrEmptyScript) - claimDelayedOutputTx.copy(tx = claimDelayedOutputTx.tx.updateWitness(0, witness)) + def addSigs(claimDelayedOutputTx: ClaimLocalDelayedOutputTx, localSig: ByteVector64): ClaimLocalDelayedOutputTx = claimDelayedOutputTx.input match { + case InputInfo.SegwitInput(_, _, redeemScript) => + val witness = witnessToLocalDelayedAfterDelay(localSig, redeemScript) + claimDelayedOutputTx.copy(tx = claimDelayedOutputTx.tx.updateWitness(0, witness)) + case _ => claimDelayedOutputTx } - def addSigs(htlcDelayedTx: HtlcDelayedTx, localSig: ByteVector64): HtlcDelayedTx = { - val witness = witnessToLocalDelayedAfterDelay(localSig, htlcDelayedTx.input.redeemScriptOrEmptyScript) - htlcDelayedTx.copy(tx = htlcDelayedTx.tx.updateWitness(0, witness)) + def addSigs(htlcDelayedTx: HtlcDelayedTx, localSig: ByteVector64): HtlcDelayedTx = htlcDelayedTx.input match { + case InputInfo.SegwitInput(_, _, redeemScript) => + val witness = witnessToLocalDelayedAfterDelay(localSig, redeemScript) + htlcDelayedTx.copy(tx = htlcDelayedTx.tx.updateWitness(0, witness)) + case _ => htlcDelayedTx } - def addSigs(claimAnchorOutputTx: ClaimLocalAnchorOutputTx, localSig: ByteVector64): ClaimLocalAnchorOutputTx = { - val witness = witnessAnchor(localSig, claimAnchorOutputTx.input.redeemScriptOrEmptyScript) - claimAnchorOutputTx.copy(tx = claimAnchorOutputTx.tx.updateWitness(0, witness)) + def addSigs(claimAnchorOutputTx: ClaimLocalAnchorOutputTx, localSig: ByteVector64): ClaimLocalAnchorOutputTx = claimAnchorOutputTx.input match { + case InputInfo.SegwitInput(_, _, redeemScript) => + val witness = witnessAnchor(localSig, redeemScript) + claimAnchorOutputTx.copy(tx = claimAnchorOutputTx.tx.updateWitness(0, witness)) + case _ => claimAnchorOutputTx } - def addSigs(claimHtlcDelayedPenalty: ClaimHtlcDelayedOutputPenaltyTx, revocationSig: ByteVector64): ClaimHtlcDelayedOutputPenaltyTx = { - val witness = Scripts.witnessToLocalDelayedWithRevocationSig(revocationSig, claimHtlcDelayedPenalty.input.redeemScriptOrEmptyScript) - claimHtlcDelayedPenalty.copy(tx = claimHtlcDelayedPenalty.tx.updateWitness(0, witness)) + def addSigs(claimHtlcDelayedPenalty: ClaimHtlcDelayedOutputPenaltyTx, revocationSig: ByteVector64): ClaimHtlcDelayedOutputPenaltyTx = claimHtlcDelayedPenalty.input match { + case InputInfo.SegwitInput(_, _, redeemScript) => + val witness = Scripts.witnessToLocalDelayedWithRevocationSig(revocationSig, redeemScript) + claimHtlcDelayedPenalty.copy(tx = claimHtlcDelayedPenalty.tx.updateWitness(0, witness)) + case _ => claimHtlcDelayedPenalty } def addSigs(closingTx: ClosingTx, localFundingPubkey: PublicKey, remoteFundingPubkey: PublicKey, localSig: ByteVector64, remoteSig: ByteVector64): ClosingTx = { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/internal/channel/version0/ChannelCodecs0.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/internal/channel/version0/ChannelCodecs0.scala index 4e08376c92..375e9afd6e 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/internal/channel/version0/ChannelCodecs0.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/internal/channel/version0/ChannelCodecs0.scala @@ -125,14 +125,12 @@ private[channel] object ChannelCodecs0 { closingTx => closingTx.tx ) - private case class InputInfoLegacy(outPoint: OutPoint, txOut: TxOut, redeemScript: ByteVector) - - private val inputInfoLegacyCodec: Codec[InputInfoLegacy] = ( + private val legacyInputInfoCodec: Codec[InputInfo.SegwitInput] = ( ("outPoint" | outPointCodec) :: ("txOut" | txOutCodec) :: - ("redeemScript" | varsizebinarydata)).as[InputInfoLegacy] + ("redeemScript" | varsizebinarydata)).as[InputInfo.SegwitInput].decodeOnly - val inputInfoCodec: Codec[InputInfo] = inputInfoLegacyCodec.map(legacy => InputInfo(legacy.outPoint, legacy.txOut, Left(legacy.redeemScript))).decodeOnly + val inputInfoCodec: Codec[InputInfo] = legacyInputInfoCodec.upcast[InputInfo] private val defaultConfirmationTarget: Codec[ConfirmationTarget.Absolute] = provide(ConfirmationTarget.Absolute(BlockHeight(0))) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/internal/channel/version1/ChannelCodecs1.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/internal/channel/version1/ChannelCodecs1.scala index 8e2f22056e..4bfa88f80a 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/internal/channel/version1/ChannelCodecs1.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/internal/channel/version1/ChannelCodecs1.scala @@ -97,14 +97,12 @@ private[channel] object ChannelCodecs1 { closingTx => closingTx.tx ) - private case class InputInfoLegacy(outPoint: OutPoint, txOut: TxOut, redeemScript: ByteVector) - - private val inputInfoLegacyCodec: Codec[InputInfoLegacy] = ( + private val legacyInputInfoCodec: Codec[InputInfo.SegwitInput] = ( ("outPoint" | outPointCodec) :: ("txOut" | txOutCodec) :: - ("redeemScript" | lengthDelimited(bytes))).as[InputInfoLegacy] + ("redeemScript" | lengthDelimited(bytes))).as[InputInfo.SegwitInput].decodeOnly - val inputInfoCodec: Codec[InputInfo] = inputInfoLegacyCodec.xmap[InputInfo](legacy => InputInfo(legacy.outPoint, legacy.txOut, Left(legacy.redeemScript)), _ => ???).decodeOnly + val inputInfoCodec: Codec[InputInfo] = legacyInputInfoCodec.upcast[InputInfo] private val defaultConfirmationTarget: Codec[ConfirmationTarget.Absolute] = provide(ConfirmationTarget.Absolute(BlockHeight(0))) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/internal/channel/version2/ChannelCodecs2.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/internal/channel/version2/ChannelCodecs2.scala index c85b07feff..8da17d47b9 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/internal/channel/version2/ChannelCodecs2.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/internal/channel/version2/ChannelCodecs2.scala @@ -101,14 +101,12 @@ private[channel] object ChannelCodecs2 { val txCodec: Codec[Transaction] = lengthDelimited(bytes.xmap(d => Transaction.read(d.toArray), d => Transaction.write(d))) - private case class InputInfoLegacy(outPoint: OutPoint, txOut: TxOut, redeemScript: ByteVector) - - private val inputInfoLegacyCodec: Codec[InputInfoLegacy] = ( + private val legacyInputInfoCodec: Codec[InputInfo.SegwitInput] = ( ("outPoint" | outPointCodec) :: ("txOut" | txOutCodec) :: - ("redeemScript" | lengthDelimited(bytes))).as[InputInfoLegacy] + ("redeemScript" | lengthDelimited(bytes))).as[InputInfo.SegwitInput].decodeOnly - val inputInfoCodec: Codec[InputInfo] = inputInfoLegacyCodec.xmap[InputInfo](legacy => InputInfo(legacy.outPoint, legacy.txOut, Left(legacy.redeemScript)), _ => ???).decodeOnly + val inputInfoCodec: Codec[InputInfo] = legacyInputInfoCodec.upcast[InputInfo] val outputInfoCodec: Codec[OutputInfo] = ( ("index" | uint32) :: diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/internal/channel/version3/ChannelCodecs3.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/internal/channel/version3/ChannelCodecs3.scala index 36521f3db7..344b98b6b0 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/internal/channel/version3/ChannelCodecs3.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/internal/channel/version3/ChannelCodecs3.scala @@ -113,14 +113,12 @@ private[channel] object ChannelCodecs3 { val txCodec: Codec[Transaction] = lengthDelimited(bytes.xmap(d => Transaction.read(d.toArray), d => Transaction.write(d))) - private case class InputInfoLegacy(outPoint: OutPoint, txOut: TxOut, redeemScript: ByteVector) - - private val inputInfoLegacyCodec: Codec[InputInfoLegacy] = ( + private val legacyInputInfoCodec: Codec[InputInfo.SegwitInput] = ( ("outPoint" | outPointCodec) :: ("txOut" | txOutCodec) :: - ("redeemScript" | lengthDelimited(bytes))).as[InputInfoLegacy] + ("redeemScript" | lengthDelimited(bytes))).as[InputInfo.SegwitInput].decodeOnly - val inputInfoCodec: Codec[InputInfo] = inputInfoLegacyCodec.xmap[InputInfo](legacy => InputInfo(legacy.outPoint, legacy.txOut, Left(legacy.redeemScript)), _ => ???).decodeOnly + val inputInfoCodec: Codec[InputInfo] = legacyInputInfoCodec.upcast[InputInfo] val outputInfoCodec: Codec[OutputInfo] = ( ("index" | uint32) :: diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/internal/channel/version4/ChannelCodecs4.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/internal/channel/version4/ChannelCodecs4.scala index 44e406c647..cfa62ee6de 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/internal/channel/version4/ChannelCodecs4.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/internal/channel/version4/ChannelCodecs4.scala @@ -115,21 +115,27 @@ private[channel] object ChannelCodecs4 { val scriptTreeAndInternalKey: Codec[ScriptTreeAndInternalKey] = (scriptTreeCodec :: xonlyPublicKey).as[ScriptTreeAndInternalKey] - private case class InputInfoEx(outPoint: OutPoint, txOut: TxOut, redeemScript: ByteVector, redeemScriptOrScriptTree: Either[ByteVector, ScriptTreeAndInternalKey], dummy: Boolean) + private case class InputInfoEx(outPoint: OutPoint, txOut: TxOut, redeemScript: ByteVector, redeemScriptOrScriptTree: Either[ByteVector, ScriptTreeAndInternalKey]) // To support the change from redeemScript to "either redeem script or script tree" while remaining backwards-compatible with the previous version 4 codec, we use // the redeem script itself as a left/write indicator: empty -> right, not empty -> left - private val inputInfoExCodec: Codec[InputInfoEx] = ( + private def scriptOrTreeCodec(redeemScript: ByteVector): Codec[Either[ByteVector, ScriptTreeAndInternalKey]] = either(provide(redeemScript.isEmpty), provide(redeemScript), scriptTreeAndInternalKey) + + private val inputInfoExCodec: Codec[InputInfoEx] = { ("outPoint" | outPointCodec) :: ("txOut" | txOutCodec) :: - (("redeemScript" | lengthDelimited(bytes)) >>:~ { redeemScript => - ("redeemScriptOrScriptTree" | either(provide(redeemScript.isEmpty), provide(redeemScript), scriptTreeAndInternalKey)) :: ("dummy" | provide(false)) - }) - ).as[InputInfoEx] + (("redeemScript" | lengthDelimited(bytes)) >>:~ { redeemScript => scriptOrTreeCodec(redeemScript).hlist }) + }.as[InputInfoEx] val inputInfoCodec: Codec[InputInfo] = inputInfoExCodec.xmap( - iex => InputInfo(iex.outPoint, iex.txOut, iex.redeemScriptOrScriptTree), - i => InputInfoEx(i.outPoint, i.txOut, i.redeemScriptOrScriptTree.swap.toOption.getOrElse(ByteVector.empty), i.redeemScriptOrScriptTree, false) + iex => iex.redeemScriptOrScriptTree match { + case Left(redeemScript) => InputInfo.SegwitInput(iex.outPoint, iex.txOut, redeemScript) + case Right(scriptTreeAndInternalKey) => InputInfo.TaprootInput(iex.outPoint, iex.txOut, scriptTreeAndInternalKey) + }, + i => i match { + case InputInfo.SegwitInput(_, _, redeemScript) => InputInfoEx(i.outPoint, i.txOut, redeemScript, Left(redeemScript)) + case InputInfo.TaprootInput(_, _, scriptTreeAndInternalKey) => InputInfoEx(i.outPoint, i.txOut, ByteVector.empty, Right(scriptTreeAndInternalKey)) + } ) val outputInfoCodec: Codec[OutputInfo] = ( diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/publish/ReplaceableTxFunderSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/publish/ReplaceableTxFunderSpec.scala index 1577b1bcea..ff7591e67a 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/publish/ReplaceableTxFunderSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/publish/ReplaceableTxFunderSpec.scala @@ -42,7 +42,7 @@ class ReplaceableTxFunderSpec extends TestKitBaseClass with AnyFunSuiteLike { val commitInput = Funding.makeFundingInputInfo(randomTxId(), 1, 500 sat, PlaceHolderPubKey, PlaceHolderPubKey) val commitTx = Transaction( 2, - Seq(TxIn(commitInput.outPoint, commitInput.redeemScriptOrEmptyScript, 0, Scripts.witness2of2(PlaceHolderSig, PlaceHolderSig, PlaceHolderPubKey, PlaceHolderPubKey))), + Seq(TxIn(commitInput.outPoint, commitInput.redeemScript, 0, Scripts.witness2of2(PlaceHolderSig, PlaceHolderSig, PlaceHolderPubKey, PlaceHolderPubKey))), Seq(TxOut(330 sat, Script.pay2wsh(anchorScript))), 0 ) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/transactions/TestVectorsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/transactions/TestVectorsSpec.scala index ddedba2fad..336ce4cc23 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/transactions/TestVectorsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/transactions/TestVectorsSpec.scala @@ -140,8 +140,8 @@ trait TestVectorsSpec extends AnyFunSuite with Logging { logger.info(s"remotekey: ${Remote.payment_privkey.publicKey}") logger.info(s"local_delayedkey: ${Local.delayed_payment_privkey.publicKey}") logger.info(s"local_revocation_key: ${Local.revocation_pubkey}") - logger.info(s"# funding wscript = ${commitmentInput.redeemScriptOrScriptTree}") - assert(commitmentInput.redeemScriptOrScriptTree == Left(hex"5221023da092f6980e58d2c037173180e9a465476026ee50f96695963e8efe436f54eb21030e9f7b623d2ccc7c9bd44d66d5ce21ce504c0acf6385a132cec6d3c39fa711c152ae")) + logger.info(s"# funding wscript = ${commitmentInput.redeemScript}") + assert(commitmentInput.redeemScript == hex"5221023da092f6980e58d2c037173180e9a465476026ee50f96695963e8efe436f54eb21030e9f7b623d2ccc7c9bd44d66d5ce21ce504c0acf6385a132cec6d3c39fa711c152ae") val paymentPreimages = Seq( ByteVector32(hex"0000000000000000000000000000000000000000000000000000000000000000"), @@ -250,7 +250,7 @@ trait TestVectorsSpec extends AnyFunSuite with Logging { case tx: HtlcSuccessTx => val localSig = tx.sign(Local.htlc_privkey, TxOwner.Local, commitmentFormat) val remoteSig = tx.sign(Remote.htlc_privkey, TxOwner.Remote, commitmentFormat) - val htlcIndex = htlcScripts.indexOf(Script.parse(tx.input.redeemScriptOrEmptyScript)) + val htlcIndex = htlcScripts.indexOf(Script.parse(tx.input.asInstanceOf[InputInfo.SegwitInput].redeemScript)) val preimage = paymentPreimages.find(p => Crypto.sha256(p) == tx.paymentHash).get val tx1 = Transactions.addSigs(tx, localSig, remoteSig, preimage, commitmentFormat) Transaction.correctlySpends(tx1.tx, Seq(commitTx.tx), ScriptFlags.STANDARD_SCRIPT_VERIFY_FLAGS) @@ -262,7 +262,7 @@ trait TestVectorsSpec extends AnyFunSuite with Logging { case tx: HtlcTimeoutTx => val localSig = tx.sign(Local.htlc_privkey, TxOwner.Local, commitmentFormat) val remoteSig = tx.sign(Remote.htlc_privkey, TxOwner.Remote, commitmentFormat) - val htlcIndex = htlcScripts.indexOf(Script.parse(tx.input.redeemScriptOrEmptyScript)) + val htlcIndex = htlcScripts.indexOf(Script.parse(tx.input.asInstanceOf[InputInfo.SegwitInput].redeemScript)) val tx1 = Transactions.addSigs(tx, localSig, remoteSig, commitmentFormat) Transaction.correctlySpends(tx1.tx, Seq(commitTx.tx), ScriptFlags.STANDARD_SCRIPT_VERIFY_FLAGS) logger.info(s"# signature for output #${tx.input.outPoint.index} (htlc-timeout for htlc #$htlcIndex)") diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/internal/channel/version4/ChannelCodecs4Spec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/internal/channel/version4/ChannelCodecs4Spec.scala index cc15f5289f..9f3349414c 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/internal/channel/version4/ChannelCodecs4Spec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/internal/channel/version4/ChannelCodecs4Spec.scala @@ -252,4 +252,9 @@ class ChannelCodecs4Spec extends AnyFunSuite { } } + test("decode InputInfo backwards compatibility") { + val pub = PrivateKey(ByteVector.fromValidHex("01" * 32)).publicKey + val nonreg = inputInfoCodec.decode(hex"0x241368efd75e267a7463c9f81bb2b6e6812f02ffc325f57b6fe4a6773f2ff0ce882a0000001f10a400000000000016001479b000887626b294a914501a4cd226b58b2359831976a91479b000887626b294a914501a4cd226b58b23598388ac".bits).require.value + assert(nonreg == InputInfo.SegwitInput(OutPoint(TxId(ByteVector32.fromValidHex("88cef02f3f77a6e46f7bf525c3ff022f81e6b6b21bf8c963747a265ed7ef6813")), 42), TxOut(Satoshi(42000), Script.pay2wpkh(pub)), Script.write(Script.pay2pkh(pub)))) + } }