Skip to content

Commit

Permalink
Use specific segwit and taproot input info types
Browse files Browse the repository at this point in the history
We now use specific subtypes for segwit inputs (which include a redeem script) and taproot inputs (which include a script tree and an internal key).
Older codecs have been modified to always return a SegwitInput.
v4 codec is modified and uses an empty redeem script as a marker to specify that a script tree is being used, which makes it compatible with the current v4 codec.
  • Loading branch information
sstone committed Jan 6, 2025
1 parent 8f10d79 commit a842915
Show file tree
Hide file tree
Showing 12 changed files with 130 additions and 92 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1142,7 +1142,10 @@ case class Commitments(params: ChannelParams,
val localFundingKey = keyManager.fundingPublicKey(params.localParams.fundingKeyPath, commitment.fundingTxIndex).publicKey
val remoteFundingKey = commitment.remoteFundingPubKey
val fundingScript = Script.write(Scripts.multiSig2of2(localFundingKey, remoteFundingKey))
commitment.commitInput.redeemScriptOrScriptTree == Left(fundingScript)
commitment.commitInput match {
case InputInfo.SegwitInput(_, _, redeemScript) => redeemScript == fundingScript
case _ => false
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -378,10 +378,10 @@ object Helpers {

def makeFundingPubKeyScript(localFundingKey: PublicKey, remoteFundingKey: PublicKey): ByteVector = write(pay2wsh(multiSig2of2(localFundingKey, remoteFundingKey)))

def makeFundingInputInfo(fundingTxId: TxId, fundingTxOutputIndex: Int, fundingSatoshis: Satoshi, fundingPubkey1: PublicKey, fundingPubkey2: PublicKey): InputInfo = {
def makeFundingInputInfo(fundingTxId: TxId, fundingTxOutputIndex: Int, fundingSatoshis: Satoshi, fundingPubkey1: PublicKey, fundingPubkey2: PublicKey): InputInfo.SegwitInput = {
val fundingScript = multiSig2of2(fundingPubkey1, fundingPubkey2)
val fundingTxOut = TxOut(fundingSatoshis, pay2wsh(fundingScript))
InputInfo(OutPoint(fundingTxId, fundingTxOutputIndex), fundingTxOut, write(fundingScript))
InputInfo.SegwitInput(OutPoint(fundingTxId, fundingTxOutputIndex), fundingTxOut, write(fundingScript))
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -358,8 +358,8 @@ private class ReplaceableTxFunder(nodeParams: NodeParams,
import fr.acinq.bitcoin.scalacompat.KotlinUtils._

// We create a PSBT with the non-wallet input already signed:
val witnessScript = locallySignedTx.txInfo.input.redeemScriptOrScriptTree match {
case Left(redeemScript) => fr.acinq.bitcoin.Script.parse(redeemScript)
val witnessScript = locallySignedTx.txInfo.input match {
case InputInfo.SegwitInput(_, _, redeemScript) => fr.acinq.bitcoin.Script.parse(redeemScript)
case _ => null
}
val psbt = new Psbt(locallySignedTx.txInfo.tx)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import fr.acinq.eclair._
import fr.acinq.eclair.blockchain.fee.{ConfirmationTarget, FeeratePerKw}
import fr.acinq.eclair.transactions.CommitmentOutput._
import fr.acinq.eclair.transactions.Scripts._
import fr.acinq.eclair.transactions.Transactions.InputInfo.SegwitInput
import fr.acinq.eclair.wire.protocol.UpdateAddHtlc
import scodec.bits.ByteVector

Expand Down Expand Up @@ -102,14 +103,18 @@ object Transactions {
val publicKeyScript: ByteVector = Script.write(Script.pay2tr(internalKey, Some(scriptTree)))
}

case class InputInfo(outPoint: OutPoint, txOut: TxOut, redeemScriptOrScriptTree: Either[ByteVector, ScriptTreeAndInternalKey]) {
val redeemScriptOrEmptyScript: ByteVector = redeemScriptOrScriptTree.swap.getOrElse(ByteVector.empty) // TODO: use the actual script tree for taproot transactions, once we implement them
sealed trait InputInfo {
val outPoint: OutPoint
val txOut: TxOut
}

object InputInfo {
def apply(outPoint: OutPoint, txOut: TxOut, redeemScript: ByteVector) = new InputInfo(outPoint, txOut, Left(redeemScript))
def apply(outPoint: OutPoint, txOut: TxOut, redeemScript: Seq[ScriptElt]) = new InputInfo(outPoint, txOut, Left(Script.write(redeemScript)))
def apply(outPoint: OutPoint, txOut: TxOut, scriptTree: ScriptTreeAndInternalKey) = new InputInfo(outPoint, txOut, Right(scriptTree))
case class SegwitInput(outPoint: OutPoint, txOut: TxOut, redeemScript: ByteVector) extends InputInfo
case class TaprootInput(outPoint: OutPoint, txOut: TxOut, scriptTreeAndInternalKey: ScriptTreeAndInternalKey) extends InputInfo

def apply(outPoint: OutPoint, txOut: TxOut, redeemScript: ByteVector): SegwitInput = SegwitInput(outPoint, txOut, redeemScript)
def apply(outPoint: OutPoint, txOut: TxOut, redeemScript: Seq[ScriptElt]): SegwitInput = SegwitInput(outPoint, txOut, Script.write(redeemScript))
def apply(outPoint: OutPoint, txOut: TxOut, scriptTree: ScriptTreeAndInternalKey): TaprootInput = TaprootInput(outPoint, txOut, scriptTree)
}

/** Owner of a given transaction (local/remote). */
Expand Down Expand Up @@ -138,24 +143,29 @@ object Transactions {
sign(key, sighash(txOwner, commitmentFormat))
}

def sign(key: PrivateKey, sighashType: Int): ByteVector64 = {
// NB: the tx may have multiple inputs, we will only sign the one provided in txinfo.input. Bear in mind that the
// signature will be invalidated if other inputs are added *afterwards* and sighashType was SIGHASH_ALL.
val inputIndex = tx.txIn.indexWhere(_.outPoint == input.outPoint)
val sigDER = Transaction.signInput(tx, inputIndex, input.redeemScriptOrEmptyScript, sighashType, input.txOut.amount, SIGVERSION_WITNESS_V0, key)
val sig64 = Crypto.der2compact(sigDER)
sig64
def sign(key: PrivateKey, sighashType: Int): ByteVector64 = input match {
case _:InputInfo.TaprootInput => ByteVector64.Zeroes
case InputInfo.SegwitInput(outPoint, txOut, redeemScript) =>
// NB: the tx may have multiple inputs, we will only sign the one provided in txinfo.input. Bear in mind that the
// signature will be invalidated if other inputs are added *afterwards* and sighashType was SIGHASH_ALL.
val inputIndex = tx.txIn.indexWhere(_.outPoint == outPoint)
val sigDER = Transaction.signInput(tx, inputIndex, redeemScript, sighashType, txOut.amount, SIGVERSION_WITNESS_V0, key)
val sig64 = Crypto.der2compact(sigDER)
sig64
}

def checkSig(sig: ByteVector64, pubKey: PublicKey, txOwner: TxOwner, commitmentFormat: CommitmentFormat): Boolean = {
val sighash = this.sighash(txOwner, commitmentFormat)
val inputIndex = tx.txIn.indexWhere(_.outPoint == input.outPoint)
if (inputIndex >= 0) {
val data = Transaction.hashForSigning(tx, inputIndex, input.redeemScriptOrEmptyScript, sighash, input.txOut.amount, SIGVERSION_WITNESS_V0)
Crypto.verifySignature(data, sig, pubKey)
} else {
false
}
def checkSig(sig: ByteVector64, pubKey: PublicKey, txOwner: TxOwner, commitmentFormat: CommitmentFormat): Boolean = input match {

case _:InputInfo.TaprootInput => false
case InputInfo.SegwitInput(outPoint, txOut, redeemScript) =>
val sighash = this.sighash(txOwner, commitmentFormat)
val inputIndex = tx.txIn.indexWhere(_.outPoint == outPoint)
if (inputIndex >= 0) {
val data = Transaction.hashForSigning(tx, inputIndex, redeemScript, sighash, txOut.amount, SIGVERSION_WITNESS_V0)
Crypto.verifySignature(data, sig, pubKey)
} else {
false
}
}
}

Expand Down Expand Up @@ -983,64 +993,86 @@ object Transactions {
commitTx.copy(tx = commitTx.tx.updateWitness(0, witness))
}

def addSigs(mainPenaltyTx: MainPenaltyTx, revocationSig: ByteVector64): MainPenaltyTx = {
val witness = Scripts.witnessToLocalDelayedWithRevocationSig(revocationSig, mainPenaltyTx.input.redeemScriptOrEmptyScript)
mainPenaltyTx.copy(tx = mainPenaltyTx.tx.updateWitness(0, witness))
def addSigs(mainPenaltyTx: MainPenaltyTx, revocationSig: ByteVector64): MainPenaltyTx = mainPenaltyTx.input match {
case InputInfo.SegwitInput(_, _, redeemScript) =>
val witness = Scripts.witnessToLocalDelayedWithRevocationSig(revocationSig, redeemScript)
mainPenaltyTx.copy(tx = mainPenaltyTx.tx.updateWitness(0, witness))
case _ => mainPenaltyTx
}

def addSigs(htlcPenaltyTx: HtlcPenaltyTx, revocationSig: ByteVector64, revocationPubkey: PublicKey): HtlcPenaltyTx = {
val witness = Scripts.witnessHtlcWithRevocationSig(revocationSig, revocationPubkey, htlcPenaltyTx.input.redeemScriptOrEmptyScript)
htlcPenaltyTx.copy(tx = htlcPenaltyTx.tx.updateWitness(0, witness))
def addSigs(htlcPenaltyTx: HtlcPenaltyTx, revocationSig: ByteVector64, revocationPubkey: PublicKey): HtlcPenaltyTx = htlcPenaltyTx.input match {
case InputInfo.SegwitInput(_, _, redeemScript) =>
val witness = Scripts.witnessHtlcWithRevocationSig(revocationSig, revocationPubkey, redeemScript)
htlcPenaltyTx.copy(tx = htlcPenaltyTx.tx.updateWitness(0, witness))
case _ => htlcPenaltyTx
}

def addSigs(htlcSuccessTx: HtlcSuccessTx, localSig: ByteVector64, remoteSig: ByteVector64, paymentPreimage: ByteVector32, commitmentFormat: CommitmentFormat): HtlcSuccessTx = {
val witness = witnessHtlcSuccess(localSig, remoteSig, paymentPreimage, htlcSuccessTx.input.redeemScriptOrEmptyScript, commitmentFormat)
htlcSuccessTx.copy(tx = htlcSuccessTx.tx.updateWitness(0, witness))
def addSigs(htlcSuccessTx: HtlcSuccessTx, localSig: ByteVector64, remoteSig: ByteVector64, paymentPreimage: ByteVector32, commitmentFormat: CommitmentFormat): HtlcSuccessTx = htlcSuccessTx.input match {
case InputInfo.SegwitInput(_, _, redeemScript) =>
val witness = witnessHtlcSuccess(localSig, remoteSig, paymentPreimage, redeemScript, commitmentFormat)
htlcSuccessTx.copy(tx = htlcSuccessTx.tx.updateWitness(0, witness))
case _ => htlcSuccessTx
}

def addSigs(htlcTimeoutTx: HtlcTimeoutTx, localSig: ByteVector64, remoteSig: ByteVector64, commitmentFormat: CommitmentFormat): HtlcTimeoutTx = {
val witness = witnessHtlcTimeout(localSig, remoteSig, htlcTimeoutTx.input.redeemScriptOrEmptyScript, commitmentFormat)
htlcTimeoutTx.copy(tx = htlcTimeoutTx.tx.updateWitness(0, witness))
def addSigs(htlcTimeoutTx: HtlcTimeoutTx, localSig: ByteVector64, remoteSig: ByteVector64, commitmentFormat: CommitmentFormat): HtlcTimeoutTx = htlcTimeoutTx.input match {
case InputInfo.SegwitInput(_, _, redeemScript) =>
val witness = witnessHtlcTimeout(localSig, remoteSig, redeemScript, commitmentFormat)
htlcTimeoutTx.copy(tx = htlcTimeoutTx.tx.updateWitness(0, witness))
case _ => htlcTimeoutTx
}

def addSigs(claimHtlcSuccessTx: ClaimHtlcSuccessTx, localSig: ByteVector64, paymentPreimage: ByteVector32): ClaimHtlcSuccessTx = {
val witness = witnessClaimHtlcSuccessFromCommitTx(localSig, paymentPreimage, claimHtlcSuccessTx.input.redeemScriptOrEmptyScript)
claimHtlcSuccessTx.copy(tx = claimHtlcSuccessTx.tx.updateWitness(0, witness))
def addSigs(claimHtlcSuccessTx: ClaimHtlcSuccessTx, localSig: ByteVector64, paymentPreimage: ByteVector32): ClaimHtlcSuccessTx = claimHtlcSuccessTx.input match {
case InputInfo.SegwitInput(_, _, redeemScript) =>
val witness = witnessClaimHtlcSuccessFromCommitTx(localSig, paymentPreimage, redeemScript)
claimHtlcSuccessTx.copy(tx = claimHtlcSuccessTx.tx.updateWitness(0, witness))
case _ => claimHtlcSuccessTx
}

def addSigs(claimHtlcTimeoutTx: ClaimHtlcTimeoutTx, localSig: ByteVector64): ClaimHtlcTimeoutTx = {
val witness = witnessClaimHtlcTimeoutFromCommitTx(localSig, claimHtlcTimeoutTx.input.redeemScriptOrEmptyScript)
claimHtlcTimeoutTx.copy(tx = claimHtlcTimeoutTx.tx.updateWitness(0, witness))
def addSigs(claimHtlcTimeoutTx: ClaimHtlcTimeoutTx, localSig: ByteVector64): ClaimHtlcTimeoutTx = claimHtlcTimeoutTx.input match {
case InputInfo.SegwitInput(_, _, redeemScript) =>
val witness = witnessClaimHtlcTimeoutFromCommitTx(localSig, redeemScript)
claimHtlcTimeoutTx.copy(tx = claimHtlcTimeoutTx.tx.updateWitness(0, witness))
case _ => claimHtlcTimeoutTx
}

def addSigs(claimP2WPKHOutputTx: ClaimP2WPKHOutputTx, localPaymentPubkey: PublicKey, localSig: ByteVector64): ClaimP2WPKHOutputTx = {
val witness = ScriptWitness(Seq(der(localSig), localPaymentPubkey.value))
claimP2WPKHOutputTx.copy(tx = claimP2WPKHOutputTx.tx.updateWitness(0, witness))
}

def addSigs(claimRemoteDelayedOutputTx: ClaimRemoteDelayedOutputTx, localSig: ByteVector64): ClaimRemoteDelayedOutputTx = {
val witness = witnessClaimToRemoteDelayedFromCommitTx(localSig, claimRemoteDelayedOutputTx.input.redeemScriptOrEmptyScript)
claimRemoteDelayedOutputTx.copy(tx = claimRemoteDelayedOutputTx.tx.updateWitness(0, witness))
def addSigs(claimRemoteDelayedOutputTx: ClaimRemoteDelayedOutputTx, localSig: ByteVector64): ClaimRemoteDelayedOutputTx = claimRemoteDelayedOutputTx.input match {
case InputInfo.SegwitInput(_, _, redeemScript) =>
val witness = witnessClaimToRemoteDelayedFromCommitTx(localSig, redeemScript)
claimRemoteDelayedOutputTx.copy(tx = claimRemoteDelayedOutputTx.tx.updateWitness(0, witness))
case _ => claimRemoteDelayedOutputTx
}

def addSigs(claimDelayedOutputTx: ClaimLocalDelayedOutputTx, localSig: ByteVector64): ClaimLocalDelayedOutputTx = {
val witness = witnessToLocalDelayedAfterDelay(localSig, claimDelayedOutputTx.input.redeemScriptOrEmptyScript)
claimDelayedOutputTx.copy(tx = claimDelayedOutputTx.tx.updateWitness(0, witness))
def addSigs(claimDelayedOutputTx: ClaimLocalDelayedOutputTx, localSig: ByteVector64): ClaimLocalDelayedOutputTx = claimDelayedOutputTx.input match {
case InputInfo.SegwitInput(_, _, redeemScript) =>
val witness = witnessToLocalDelayedAfterDelay(localSig, redeemScript)
claimDelayedOutputTx.copy(tx = claimDelayedOutputTx.tx.updateWitness(0, witness))
case _ => claimDelayedOutputTx
}

def addSigs(htlcDelayedTx: HtlcDelayedTx, localSig: ByteVector64): HtlcDelayedTx = {
val witness = witnessToLocalDelayedAfterDelay(localSig, htlcDelayedTx.input.redeemScriptOrEmptyScript)
htlcDelayedTx.copy(tx = htlcDelayedTx.tx.updateWitness(0, witness))
def addSigs(htlcDelayedTx: HtlcDelayedTx, localSig: ByteVector64): HtlcDelayedTx = htlcDelayedTx.input match {
case InputInfo.SegwitInput(_, _, redeemScript) =>
val witness = witnessToLocalDelayedAfterDelay(localSig, redeemScript)
htlcDelayedTx.copy(tx = htlcDelayedTx.tx.updateWitness(0, witness))
case _ => htlcDelayedTx
}

def addSigs(claimAnchorOutputTx: ClaimLocalAnchorOutputTx, localSig: ByteVector64): ClaimLocalAnchorOutputTx = {
val witness = witnessAnchor(localSig, claimAnchorOutputTx.input.redeemScriptOrEmptyScript)
claimAnchorOutputTx.copy(tx = claimAnchorOutputTx.tx.updateWitness(0, witness))
def addSigs(claimAnchorOutputTx: ClaimLocalAnchorOutputTx, localSig: ByteVector64): ClaimLocalAnchorOutputTx = claimAnchorOutputTx.input match {
case InputInfo.SegwitInput(_, _, redeemScript) =>
val witness = witnessAnchor(localSig, redeemScript)
claimAnchorOutputTx.copy(tx = claimAnchorOutputTx.tx.updateWitness(0, witness))
case _ => claimAnchorOutputTx
}

def addSigs(claimHtlcDelayedPenalty: ClaimHtlcDelayedOutputPenaltyTx, revocationSig: ByteVector64): ClaimHtlcDelayedOutputPenaltyTx = {
val witness = Scripts.witnessToLocalDelayedWithRevocationSig(revocationSig, claimHtlcDelayedPenalty.input.redeemScriptOrEmptyScript)
claimHtlcDelayedPenalty.copy(tx = claimHtlcDelayedPenalty.tx.updateWitness(0, witness))
def addSigs(claimHtlcDelayedPenalty: ClaimHtlcDelayedOutputPenaltyTx, revocationSig: ByteVector64): ClaimHtlcDelayedOutputPenaltyTx = claimHtlcDelayedPenalty.input match {
case InputInfo.SegwitInput(_, _, redeemScript) =>
val witness = Scripts.witnessToLocalDelayedWithRevocationSig(revocationSig, redeemScript)
claimHtlcDelayedPenalty.copy(tx = claimHtlcDelayedPenalty.tx.updateWitness(0, witness))
case _ => claimHtlcDelayedPenalty
}

def addSigs(closingTx: ClosingTx, localFundingPubkey: PublicKey, remoteFundingPubkey: PublicKey, localSig: ByteVector64, remoteSig: ByteVector64): ClosingTx = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,12 @@ private[channel] object ChannelCodecs0 {
closingTx => closingTx.tx
)

private case class InputInfoLegacy(outPoint: OutPoint, txOut: TxOut, redeemScript: ByteVector)

private val inputInfoLegacyCodec: Codec[InputInfoLegacy] = (
private val legacyInputInfoCodec: Codec[InputInfo.SegwitInput] = (
("outPoint" | outPointCodec) ::
("txOut" | txOutCodec) ::
("redeemScript" | varsizebinarydata)).as[InputInfoLegacy]
("redeemScript" | varsizebinarydata)).as[InputInfo.SegwitInput].decodeOnly

val inputInfoCodec: Codec[InputInfo] = inputInfoLegacyCodec.map(legacy => InputInfo(legacy.outPoint, legacy.txOut, Left(legacy.redeemScript))).decodeOnly
val inputInfoCodec: Codec[InputInfo] = legacyInputInfoCodec.upcast[InputInfo]

private val defaultConfirmationTarget: Codec[ConfirmationTarget.Absolute] = provide(ConfirmationTarget.Absolute(BlockHeight(0)))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,12 @@ private[channel] object ChannelCodecs1 {
closingTx => closingTx.tx
)

private case class InputInfoLegacy(outPoint: OutPoint, txOut: TxOut, redeemScript: ByteVector)

private val inputInfoLegacyCodec: Codec[InputInfoLegacy] = (
private val legacyInputInfoCodec: Codec[InputInfo.SegwitInput] = (
("outPoint" | outPointCodec) ::
("txOut" | txOutCodec) ::
("redeemScript" | lengthDelimited(bytes))).as[InputInfoLegacy]
("redeemScript" | lengthDelimited(bytes))).as[InputInfo.SegwitInput].decodeOnly

val inputInfoCodec: Codec[InputInfo] = inputInfoLegacyCodec.xmap[InputInfo](legacy => InputInfo(legacy.outPoint, legacy.txOut, Left(legacy.redeemScript)), _ => ???).decodeOnly
val inputInfoCodec: Codec[InputInfo] = legacyInputInfoCodec.upcast[InputInfo]

private val defaultConfirmationTarget: Codec[ConfirmationTarget.Absolute] = provide(ConfirmationTarget.Absolute(BlockHeight(0)))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,12 @@ private[channel] object ChannelCodecs2 {

val txCodec: Codec[Transaction] = lengthDelimited(bytes.xmap(d => Transaction.read(d.toArray), d => Transaction.write(d)))

private case class InputInfoLegacy(outPoint: OutPoint, txOut: TxOut, redeemScript: ByteVector)

private val inputInfoLegacyCodec: Codec[InputInfoLegacy] = (
private val legacyInputInfoCodec: Codec[InputInfo.SegwitInput] = (
("outPoint" | outPointCodec) ::
("txOut" | txOutCodec) ::
("redeemScript" | lengthDelimited(bytes))).as[InputInfoLegacy]
("redeemScript" | lengthDelimited(bytes))).as[InputInfo.SegwitInput].decodeOnly

val inputInfoCodec: Codec[InputInfo] = inputInfoLegacyCodec.xmap[InputInfo](legacy => InputInfo(legacy.outPoint, legacy.txOut, Left(legacy.redeemScript)), _ => ???).decodeOnly
val inputInfoCodec: Codec[InputInfo] = legacyInputInfoCodec.upcast[InputInfo]

val outputInfoCodec: Codec[OutputInfo] = (
("index" | uint32) ::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,14 +113,12 @@ private[channel] object ChannelCodecs3 {

val txCodec: Codec[Transaction] = lengthDelimited(bytes.xmap(d => Transaction.read(d.toArray), d => Transaction.write(d)))

private case class InputInfoLegacy(outPoint: OutPoint, txOut: TxOut, redeemScript: ByteVector)

private val inputInfoLegacyCodec: Codec[InputInfoLegacy] = (
private val legacyInputInfoCodec: Codec[InputInfo.SegwitInput] = (
("outPoint" | outPointCodec) ::
("txOut" | txOutCodec) ::
("redeemScript" | lengthDelimited(bytes))).as[InputInfoLegacy]
("redeemScript" | lengthDelimited(bytes))).as[InputInfo.SegwitInput].decodeOnly

val inputInfoCodec: Codec[InputInfo] = inputInfoLegacyCodec.xmap[InputInfo](legacy => InputInfo(legacy.outPoint, legacy.txOut, Left(legacy.redeemScript)), _ => ???).decodeOnly
val inputInfoCodec: Codec[InputInfo] = legacyInputInfoCodec.upcast[InputInfo]

val outputInfoCodec: Codec[OutputInfo] = (
("index" | uint32) ::
Expand Down
Loading

0 comments on commit a842915

Please sign in to comment.