Skip to content

Commit

Permalink
preserve all PSBT fields when signing with SD
Browse files Browse the repository at this point in the history
fix lint and tests
  • Loading branch information
odudex committed Dec 26, 2024
1 parent 7e117a3 commit d467330
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 64 deletions.
41 changes: 2 additions & 39 deletions src/krux/pages/home_pages/home.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,7 @@
from ...qr import FORMAT_NONE, FORMAT_PMOFN
from ...krux_settings import t, Settings
from ...format import replace_decimal_separator
from ...key import TYPE_SINGLESIG, P2WSH, P2TR


MAX_POLICY_COSIGNERS_DISPLAYED = 5
from ...key import TYPE_SINGLESIG


class Home(Page):
Expand Down Expand Up @@ -329,41 +326,7 @@ def sign_psbt(self):
not self.ctx.wallet.is_loaded()
and not self.ctx.wallet.key.policy_type == TYPE_SINGLESIG
):
from ...key import Key
from ...psbt import is_multisig

policy_str = "PSBT policy:\n"
policy_str += signer.policy["type"] + "\n"
if is_multisig(signer.policy):
policy_str += (
str(signer.policy["m"]) + " of " + str(signer.policy["n"]) + "\n"
)
fingerprints = []
for inp in signer.psbt.inputs:
# Do we need to loop through all the inputs or just one?
if signer.policy["type"] == P2WSH:
for pub in inp.bip32_derivations:
fingerprint_srt = Key.format_fingerprint(
inp.bip32_derivations[pub].fingerprint, True
)
if fingerprint_srt not in fingerprints:
if len(fingerprints) > MAX_POLICY_COSIGNERS_DISPLAYED:
fingerprints[-1] = "..."
break
fingerprints.append(fingerprint_srt)
elif signer.policy["type"] == P2TR:
for pub in inp.taproot_bip32_derivations:
_, derivation_path = inp.taproot_bip32_derivations[pub]
fingerprint_srt = Key.format_fingerprint(
derivation_path.fingerprint, True
)
if fingerprint_srt not in fingerprints:
if len(fingerprints) > MAX_POLICY_COSIGNERS_DISPLAYED:
fingerprints[-1] = "..."
break
fingerprints.append(fingerprint_srt)

policy_str += "\n".join(fingerprints)
policy_str = signer.psbt_policy_string()
self.ctx.display.clear()
self.ctx.display.draw_centered_text(policy_str)
if not self.prompt(t("Proceed?"), BOTTOM_PROMPT_LINE):
Expand Down
67 changes: 46 additions & 21 deletions src/krux/psbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
import gc
from embit.psbt import PSBT
from embit.psbt import PSBT, CompressMode
from embit.bip32 import HARDENED_INDEX
from ur.ur import UR
import urtypes
Expand All @@ -40,6 +40,8 @@
# We always uses thin spaces after the ₿ in this file
BTC_SYMBOL = "₿"

MAX_POLICY_COSIGNERS_DISPLAYED = 5


class Counter(dict):
"""Helper class for dict"""
Expand Down Expand Up @@ -68,7 +70,7 @@ def __init__(self, wallet, psbt_data, qr_format, psbt_filename=None):
file_path = "/%s/%s" % (SD_PATH, psbt_filename)
try:
with open(file_path, "rb") as file:
self.psbt = PSBT.read_from(file, compress=1)
self.psbt = PSBT.read_from(file)
self.validate()
except:
try:
Expand All @@ -81,11 +83,12 @@ def __init__(self, wallet, psbt_data, qr_format, psbt_filename=None):
psbt_data = file.read()
self.psbt = PSBT.parse(base_decode(psbt_data, 64))
else:
# Legacy will fail to get policy from compressed PSBT
# so we load it uncompressed
# Try to load the PSBT in compressed mode
with open(file_path, "rb") as file:
file.seek(0) # Reset the file pointer to the beginning
self.psbt = PSBT.read_from(file)
self.psbt = PSBT.read_from(
file, compress=CompressMode.CLEAR_ALL
)
except Exception as e:
raise ValueError("Error loading PSBT file: %s" % e)
self.base_encoding = 64 # In case it is exported as QR code
Expand Down Expand Up @@ -513,10 +516,11 @@ def sign(self, trim=True):

trimmed_psbt = PSBT(self.psbt.tx)
for i, inp in enumerate(self.psbt.inputs):
# Copy the final_scriptwitness if present (for Taproot or other SegWit inputs)
# Copy the final_scriptwitness if present
if inp.final_scriptwitness:
trimmed_psbt.inputs[i].final_scriptwitness = inp.final_scriptwitness
# Copy any partial signatures (for multisig or other script types)

# Copy any partial signatures
if inp.partial_sigs:
trimmed_psbt.inputs[i].partial_sigs = inp.partial_sigs

Expand All @@ -536,20 +540,6 @@ def sign(self, trim=True):
if inp.witness_script:
trimmed_psbt.inputs[i].witness_script = inp.witness_script

# # # --- Taproot-specific fields ---

# # # Taproot BIP32 derivation paths (PSBT_IN_TAP_BIP32_DERIVATION)
# # if inp.taproot_bip32_derivations is not None:
# # trimmed_psbt.inputs[i].taproot_bip32_derivations = inp.taproot_bip32_derivations

# # # Internal key (PSBT_IN_TAP_INTERNAL_KEY)
# # if inp.taproot_internal_key is not None:
# # trimmed_psbt.inputs[i].taproot_internal_key = inp.taproot_internal_key

# # # Taproot leaf scripts (PSBT_IN_TAP_LEAF_SCRIPT)
# # if inp.taproot_scripts is not None:
# # trimmed_psbt.inputs[i].taproot_scripts = inp.taproot_scripts

self.psbt = trimmed_psbt

def psbt_qr(self):
Expand Down Expand Up @@ -606,6 +596,41 @@ def xpubs(self):
)
return xpubs

def psbt_policy_string(self):
"""Returns the policy string containing script type and cosigners' fingerprints"""

policy_str = "PSBT policy:\n"
policy_str += self.policy["type"] + "\n"
if is_multisig(self.policy):
policy_str += str(self.policy["m"]) + " of " + str(self.policy["n"]) + "\n"
fingerprints = []
for inp in self.psbt.inputs:
# Do we need to loop through all the inputs or just one?
if self.policy["type"] == P2WSH:
for pub in inp.bip32_derivations:
fingerprint_srt = Key.format_fingerprint(
inp.bip32_derivations[pub].fingerprint, True
)
if fingerprint_srt not in fingerprints:
if len(fingerprints) > MAX_POLICY_COSIGNERS_DISPLAYED:
fingerprints[-1] = "..."
break

Check warning on line 617 in src/krux/psbt.py

View check run for this annotation

Codecov / codecov/patch

src/krux/psbt.py#L616-L617

Added lines #L616 - L617 were not covered by tests
fingerprints.append(fingerprint_srt)
elif self.policy["type"] == P2TR:
for pub in inp.taproot_bip32_derivations:
_, derivation_path = inp.taproot_bip32_derivations[pub]
fingerprint_srt = Key.format_fingerprint(

Check warning on line 622 in src/krux/psbt.py

View check run for this annotation

Codecov / codecov/patch

src/krux/psbt.py#L619-L622

Added lines #L619 - L622 were not covered by tests
derivation_path.fingerprint, True
)
if fingerprint_srt not in fingerprints:
if len(fingerprints) > MAX_POLICY_COSIGNERS_DISPLAYED:
fingerprints[-1] = "..."
break
fingerprints.append(fingerprint_srt)

Check warning on line 629 in src/krux/psbt.py

View check run for this annotation

Codecov / codecov/patch

src/krux/psbt.py#L625-L629

Added lines #L625 - L629 were not covered by tests

policy_str += "\n".join(fingerprints)
return policy_str


def is_multisig(policy):
"""Returns a boolean indicating if the policy is a multisig"""
Expand Down
Loading

0 comments on commit d467330

Please sign in to comment.