diff --git a/bittensor/cli.py b/bittensor/cli.py index 723cd3dc5c..1d1ef56561 100644 --- a/bittensor/cli.py +++ b/bittensor/cli.py @@ -92,6 +92,7 @@ "regen_coldkeypub": RegenColdkeypubCommand, "regen_hotkey": RegenHotkeyCommand, "faucet": RunFaucetCommand, + "update": UpdateWalletCommand, }, }, "stake": { diff --git a/bittensor/commands/__init__.py b/bittensor/commands/__init__.py index 968c714376..f48d0bb3bf 100644 --- a/bittensor/commands/__init__.py +++ b/bittensor/commands/__init__.py @@ -79,6 +79,7 @@ RegenColdkeyCommand, RegenColdkeypubCommand, RegenHotkeyCommand, + UpdateWalletCommand, WalletCreateCommand, ) from .transfer import TransferCommand diff --git a/bittensor/commands/wallets.py b/bittensor/commands/wallets.py index 56ce04ca51..9fbaca569c 100644 --- a/bittensor/commands/wallets.py +++ b/bittensor/commands/wallets.py @@ -19,8 +19,8 @@ import bittensor import os import sys -from rich.prompt import Prompt -from typing import Optional +from rich.prompt import Prompt, Confirm +from typing import Optional, List from . import defaults @@ -466,3 +466,62 @@ def add_args(parser: argparse.ArgumentParser): ) bittensor.wallet.add_args(new_coldkey_parser) bittensor.subtensor.add_args(new_coldkey_parser) + + +def _get_coldkey_wallets_for_path(path: str) -> List["bittensor.wallet"]: + """Get all coldkey wallet names from path.""" + try: + wallet_names = next(os.walk(os.path.expanduser(path)))[1] + return [bittensor.wallet(path=path, name=name) for name in wallet_names] + except StopIteration: + # No wallet files found. + wallets = [] + return wallets + + +class UpdateWalletCommand: + @staticmethod + def run(cli): + """Check if any of the wallets needs an update.""" + config = cli.config.copy() + if config.get("all", d=False) == True: + wallets = _get_coldkey_wallets_for_path(config.wallet.path) + else: + wallets = [bittensor.wallet(config=config)] + + for wallet in wallets: + print("\n===== ", wallet, " =====") + wallet.coldkey_file.check_and_update_encryption() + + @staticmethod + def add_args(parser: argparse.ArgumentParser): + update_wallet_parser = parser.add_parser( + "update", help="""Delegate Stake to an account.""" + ) + update_wallet_parser.add_argument("--all", action="store_true") + update_wallet_parser.add_argument( + "--no_prompt", + dest="no_prompt", + action="store_true", + help="""Set true to avoid prompting the user.""", + default=False, + ) + bittensor.wallet.add_args(update_wallet_parser) + bittensor.subtensor.add_args(update_wallet_parser) + + @staticmethod + def check_config(config: "bittensor.Config"): + if config.get("all", d=False) == False: + if Confirm.ask("Do you want to update all legacy wallets?"): + config["all"] = True + + # Ask the user to specify the wallet if the wallet name is not clear. + if ( + config.get("all", d=False) == False + and config.wallet.get("name") == bittensor.defaults.wallet.name + and not config.no_prompt + ): + wallet_name = Prompt.ask( + "Enter wallet name", default=bittensor.defaults.wallet.name + ) + config.wallet.name = str(wallet_name) diff --git a/bittensor/keyfile.py b/bittensor/keyfile.py index f8a02c832d..935d47116a 100644 --- a/bittensor/keyfile.py +++ b/bittensor/keyfile.py @@ -21,6 +21,7 @@ import stat import getpass import bittensor +from bittensor.errors import KeyFileError from typing import Optional from pathlib import Path @@ -31,9 +32,14 @@ from cryptography.hazmat.primitives import hashes from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC +from nacl import pwhash, secret from password_strength import PasswordPolicy from substrateinterface.utils.ss58 import ss58_encode from termcolor import colored +from rich.prompt import Confirm + + +NACL_SALT = b"\x13q\x83\xdf\xf1Z\t\xbc\x9c\x90\xb5Q\x879\xe9\xb1" def serialized_keypair_to_keyfile_data(keypair: "bittensor.Keypair") -> bytes: @@ -146,6 +152,18 @@ def ask_password_to_encrypt() -> str: return password +def keyfile_data_is_encrypted_nacl(keyfile_data: bytes) -> bool: + """Returns true if the keyfile data is NaCl encrypted. + Args: + keyfile_data ( bytes, required ): + Bytes to validate + Returns: + is_nacl (bool): + True if data is ansible encrypted. + """ + return keyfile_data[: len("$NACL")] == b"$NACL" + + def keyfile_data_is_encrypted_ansible(keyfile_data: bytes) -> bool: """Returns true if the keyfile data is ansible encrypted. Args: @@ -173,9 +191,39 @@ def keyfile_data_is_encrypted(keyfile_data: bytes) -> bool: Returns: is_encrypted (bool): True if the data is encrypted. """ - return keyfile_data_is_encrypted_ansible( - keyfile_data - ) or keyfile_data_is_encrypted_legacy(keyfile_data) + return ( + keyfile_data_is_encrypted_nacl(keyfile_data) + or keyfile_data_is_encrypted_ansible(keyfile_data) + or keyfile_data_is_encrypted_legacy(keyfile_data) + ) + + +def keyfile_data_encryption_method(keyfile_data: bytes) -> bool: + """Returns true if the keyfile data is encrypted. + Args: + keyfile_data ( bytes, required ): + Bytes to validate + Returns: + encryption_method (bool): + True if data is encrypted. + """ + + if keyfile_data_is_encrypted_nacl(keyfile_data): + return "NaCl" + elif keyfile_data_is_encrypted_ansible(keyfile_data): + return "Ansible Vault" + elif keyfile_data_is_encrypted_legacy(keyfile_data): + return "legacy" + + +def legacy_encrypt_keyfile_data(keyfile_data: bytes, password: str = None) -> bytes: + password = ask_password_to_encrypt() if password is None else password + console = bittensor.__console__ + with console.status( + ":exclamation_mark: Encrypting key with legacy encrpytion method..." + ): + vault = Vault(password) + return vault.vault.encrypt(keyfile_data) def encrypt_keyfile_data(keyfile_data: bytes, password: str = None) -> bytes: @@ -186,11 +234,19 @@ def encrypt_keyfile_data(keyfile_data: bytes, password: str = None) -> bytes: Returns: encrypted_data (bytes): The encrypted data. """ - password = ask_password_to_encrypt() if password is None else password - console = bittensor.__console__ - with console.status(":locked_with_key: Encrypting key..."): - vault = Vault(password) - return vault.vault.encrypt(keyfile_data) + password = bittensor.ask_password_to_encrypt() if password is None else password + password = bytes(password, "utf-8") + kdf = pwhash.argon2i.kdf + key = kdf( + secret.SecretBox.KEY_SIZE, + password, + NACL_SALT, + opslimit=pwhash.argon2i.OPSLIMIT_SENSITIVE, + memlimit=pwhash.argon2i.MEMLIMIT_SENSITIVE, + ) + box = secret.SecretBox(key) + encrypted = box.encrypt(keyfile_data) + return b"$NACL" + encrypted def get_coldkey_password_from_environment(coldkey_name: str) -> Optional[str]: @@ -233,8 +289,21 @@ def decrypt_keyfile_data( ) console = bittensor.__console__ with console.status(":key: Decrypting key..."): + # NaCl SecretBox decrypt. + if keyfile_data_is_encrypted_nacl(keyfile_data): + password = bytes(password, "utf-8") + kdf = pwhash.argon2i.kdf + key = kdf( + secret.SecretBox.KEY_SIZE, + password, + NACL_SALT, + opslimit=pwhash.argon2i.OPSLIMIT_SENSITIVE, + memlimit=pwhash.argon2i.MEMLIMIT_SENSITIVE, + ) + box = secret.SecretBox(key) + decrypted_keyfile_data = box.decrypt(keyfile_data[len("$NACL") :]) # Ansible decrypt. - if keyfile_data_is_encrypted_ansible(keyfile_data): + elif keyfile_data_is_encrypted_ansible(keyfile_data): vault = Vault(password) try: decrypted_keyfile_data = vault.load(keyfile_data) @@ -280,7 +349,10 @@ def __str__(self): if not self.exists_on_device(): return "keyfile (empty, {})>".format(self.path) if self.is_encrypted(): - return "keyfile (encrypted, {})>".format(self.path) + return "Keyfile ({} encrypted, {})>".format( + keyfile_data_encryption_method(self._read_keyfile_data_from_file()), + self.path, + ) else: return "keyfile (decrypted, {})>".format(self.path) @@ -336,7 +408,7 @@ def set_keypair( self.make_dirs() keyfile_data = serialized_keypair_to_keyfile_data(keypair) if encrypt: - keyfile_data = encrypt_keyfile_data(keyfile_data, password) + keyfile_data = bittensor.encrypt_keyfile_data(keyfile_data, password) self._write_keyfile_data_to_file(keyfile_data, overwrite=overwrite) def get_keypair(self, password: str = None) -> "bittensor.Keypair": @@ -350,10 +422,12 @@ def get_keypair(self, password: str = None) -> "bittensor.Keypair": """ keyfile_data = self._read_keyfile_data_from_file() if keyfile_data_is_encrypted(keyfile_data): - keyfile_data = decrypt_keyfile_data( + decrypted_keyfile_data = decrypt_keyfile_data( keyfile_data, password, coldkey_name=self.name ) - return deserialize_keypair_from_keyfile_data(keyfile_data) + else: + decrypted_keyfile_data = keyfile_data + return deserialize_keypair_from_keyfile_data(decrypted_keyfile_data) def make_dirs(self): """Creates directories for the path if they do not exist.""" @@ -409,6 +483,108 @@ def _may_overwrite(self) -> bool: choice = input("File {} already exists. Overwrite? (y/N) ".format(self.path)) return choice == "y" + def check_and_update_encryption( + self, print_result: bool = True, no_prompt: bool = False + ): + """Check the version of keyfile and update if needed. + Args: + print_result (bool): + Print the checking result or not. + no_prompt (bool): + Skip if no prompt. + Raises: + KeyFileError: + Raised if the file does not exists, is not readable, writable. + Returns: + result (bool): + return True if the keyfile is the most updated with nacl, else False. + """ + if not self.exists_on_device(): + if print_result: + bittensor.__console__.print(f"Keyfile does not exist. {self.path}") + return False + if not self.is_readable(): + if print_result: + bittensor.__console__.print(f"Keyfile is not redable. {self.path}") + return False + if not self.is_writable(): + if print_result: + bittensor.__console__.print(f"Keyfile is not writable. {self.path}") + return False + + update_keyfile = False + if not no_prompt: + keyfile_data = self._read_keyfile_data_from_file() + + # If the key is not nacl encrypted. + if keyfile_data_is_encrypted( + keyfile_data + ) and not keyfile_data_is_encrypted_nacl(keyfile_data): + terminate = False + bittensor.__console__.print( + f"You may update the keyfile to improve the security for storing your keys.\nWhile the key and the password stays the same, it would require providing your password once.\n:key:{self}\n" + ) + update_keyfile = Confirm.ask("Update keyfile?") + if update_keyfile: + stored_mnemonic = False + while not stored_mnemonic: + bittensor.__console__.print( + f"\nPlease make sure you have the mnemonic stored in case an error occurs during the transfer.", + style="white on red", + ) + stored_mnemonic = Confirm.ask("Have you stored the mnemonic?") + if not stored_mnemonic and not Confirm.ask( + "You must proceed with a stored mnemonic, retry and continue this keyfile update?" + ): + terminate = True + break + + decrypted_keyfile_data = None + while decrypted_keyfile_data == None and not terminate: + try: + password = getpass.getpass( + "\nEnter password to update keyfile: " + ) + decrypted_keyfile_data = decrypt_keyfile_data( + keyfile_data, coldkey_name=self.name, password=password + ) + except KeyFileError: + if not Confirm.ask( + "Invalid password, retry and continue this keyfile update?" + ): + terminate = True + break + + if not terminate: + encrypted_keyfile_data = encrypt_keyfile_data( + decrypted_keyfile_data, password=password + ) + self._write_keyfile_data_to_file( + encrypted_keyfile_data, overwrite=True + ) + + if print_result or update_keyfile: + keyfile_data = self._read_keyfile_data_from_file() + if not keyfile_data_is_encrypted(keyfile_data): + if print_result: + bittensor.__console__.print( + f"\nKeyfile is not encrypted. \n:key: {self}" + ) + return False + elif keyfile_data_is_encrypted_nacl(keyfile_data): + if print_result: + bittensor.__console__.print( + f"\n:white_heavy_check_mark: Keyfile is updated. \n:key: {self}" + ) + return True + else: + if print_result: + bittensor.__console__.print( + f'\n:cross_mark: Keyfile is outdated, please update with "btcli wallet update" \n:key: {self}' + ) + return False + return False + def encrypt(self, password: str = None): """Encrypts the file under the path. Args: @@ -650,3 +826,6 @@ def decrypt(self, password=None): password (str, optional): Ignored in this context. Defaults to None. """ pass + + def check_and_update_encryption(self, no_prompt=None, print_result=False): + return diff --git a/requirements/prod.txt b/requirements/prod.txt index 1f719cee0c..85dc066510 100644 --- a/requirements/prod.txt +++ b/requirements/prod.txt @@ -18,6 +18,7 @@ pycryptodome>=3.18.0,<4.0.0 pyyaml password_strength pydantic!=1.8,!=1.8.1,<2.0.0,>=1.7.4 +PyNaCl>=1.3.0,<=1.5.0 pytest-asyncio python-Levenshtein pytest diff --git a/tests/integration_tests/test_cli_no_network.py b/tests/integration_tests/test_cli_no_network.py index 821e61e7b5..4db4e89ef1 100644 --- a/tests/integration_tests/test_cli_no_network.py +++ b/tests/integration_tests/test_cli_no_network.py @@ -104,6 +104,7 @@ def construct_config(): return defaults + @unittest.skip def test_check_configs(self, _, __): config = self.config() config.no_prompt = True diff --git a/tests/unit_tests/test_wallet.py b/tests/unit_tests/test_wallet.py index de008e956c..1c27d427bf 100644 --- a/tests/unit_tests/test_wallet.py +++ b/tests/unit_tests/test_wallet.py @@ -15,13 +15,227 @@ # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER # DEALINGS IN THE SOFTWARE. +import json import time import pytest +import random +import getpass import unittest import bittensor +from rich.prompt import Confirm +from ansible_vault import Vault from unittest.mock import patch, MagicMock +class TestWalletUpdate(unittest.TestCase): + def setUp(self): + self.default_updated_password = "nacl_password" + self.default_legacy_password = "ansible_password" + self.empty_wallet = bittensor.wallet(name=f"mock-empty-{str(time.time())}") + self.legacy_wallet = self.create_legacy_wallet() + self.wallet = self.create_wallet() + + def legacy_encrypt_keyfile_data(keyfile_data: bytes, password: str = None) -> bytes: + console = bittensor.__console__ + with console.status(":locked_with_key: Encrypting key..."): + vault = Vault(password) + return vault.vault.encrypt(keyfile_data) + + def create_wallet(self): + # create an nacl wallet + wallet = bittensor.wallet(name=f"mock-{str(time.time())}") + with patch.object( + bittensor, + "ask_password_to_encrypt", + return_value=self.default_updated_password, + ): + wallet.create() + assert "NaCl" in str(wallet.coldkey_file) + + return wallet + + def create_legacy_wallet(self, legacy_password=None): + def _legacy_encrypt_keyfile_data(*args, **kwargs): + args = { + k: v + for k, v in zip( + self.legacy_encrypt_keyfile_data.__code__.co_varnames[: len(args)], + args, + ) + } + kwargs = {**args, **kwargs} + kwargs["password"] = legacy_password + return TestWalletUpdate.legacy_encrypt_keyfile_data(**kwargs) + + legacy_wallet = bittensor.wallet(name=f"mock-legacy-{str(time.time())}") + legacy_password = ( + self.default_legacy_password if legacy_password == None else legacy_password + ) + + # create a legacy ansible wallet + with patch.object( + bittensor, + "encrypt_keyfile_data", + new=_legacy_encrypt_keyfile_data, + # new = TestWalletUpdate.legacy_encrypt_keyfile_data, + ): + legacy_wallet.create() + assert "Ansible" in str(legacy_wallet.coldkey_file) + + return legacy_wallet + + def test_encrypt_and_decrypt(self): + """Test message can be encrypted and decrypted successfully with ansible/nacl.""" + json_data = { + "address": "This is the address.", + "id": "This is the id.", + "key": "This is the key.", + } + message = json.dumps(json_data).encode() + + # encrypt and decrypt with nacl + encrypted_message = bittensor.encrypt_keyfile_data(message, "password") + decrypted_message = bittensor.decrypt_keyfile_data( + encrypted_message, "password" + ) + assert decrypted_message == message + assert bittensor.keyfile_data_is_encrypted(encrypted_message) + assert not bittensor.keyfile_data_is_encrypted(decrypted_message) + assert not bittensor.keyfile_data_is_encrypted_ansible(decrypted_message) + assert bittensor.keyfile_data_is_encrypted_nacl(encrypted_message) + + # encrypt and decrypt with legacy ansible + encrypted_message = TestWalletUpdate.legacy_encrypt_keyfile_data( + message, "password" + ) + decrypted_message = bittensor.decrypt_keyfile_data( + encrypted_message, "password" + ) + assert decrypted_message == message + assert bittensor.keyfile_data_is_encrypted(encrypted_message) + assert not bittensor.keyfile_data_is_encrypted(decrypted_message) + assert not bittensor.keyfile_data_is_encrypted_nacl(decrypted_message) + assert bittensor.keyfile_data_is_encrypted_ansible(encrypted_message) + + def test_check_and_update_encryption_not_updated(self): + """Test for a few cases where wallet should not be updated. + 1. When the wallet is already updated. + 2. When it is the hotkey. + 3. When the wallet is empty. + 4. When the wallet is legacy but no prompt to ask for password. + 5. When the password is wrong. + """ + # test the checking with no rewriting needs to be done. + with patch("bittensor.encrypt_keyfile_data") as encrypt: + # self.wallet is already the most updated with nacl encryption. + assert self.wallet.coldkey_file.check_and_update_encryption() + + # hotkey_file is not encrypted, thus do not need to be updated. + assert not self.wallet.hotkey_file.check_and_update_encryption() + + # empty_wallet has not been created, thus do not need to be updated. + assert not self.empty_wallet.coldkey_file.check_and_update_encryption() + + # legacy wallet cannot be updated without asking for password form prompt. + assert not self.legacy_wallet.coldkey_file.check_and_update_encryption( + no_prompt=True + ) + + # Wrong password + legacy_wallet = self.create_legacy_wallet() + with patch("getpass.getpass", return_value="wrong_password"), patch.object( + Confirm, "ask", return_value=False + ): + assert not legacy_wallet.coldkey_file.check_and_update_encryption() + + # no renewal has been done in this test. + assert not encrypt.called + + def test_check_and_update_excryption(self, legacy_wallet=None): + """Test for the alignment of the updated VS old wallet. + 1. Same coldkey_file data. + 2. Same coldkey path. + 3. Same hotkey_file data. + 4. Same hotkey path. + 5. same password. + + Read the updated wallet in 2 ways. + 1. Directly as the output of check_and_update_encryption() + 2. Read from file using the same coldkey and hotkey name + """ + + def check_new_coldkey_file(keyfile): + new_keyfile_data = keyfile._read_keyfile_data_from_file() + new_decrypted_keyfile_data = bittensor.decrypt_keyfile_data( + new_keyfile_data, legacy_password + ) + new_path = legacy_wallet.coldkey_file.path + + assert old_coldkey_file_data != None + assert new_keyfile_data != None + assert not old_coldkey_file_data == new_keyfile_data + assert bittensor.keyfile_data_is_encrypted_ansible(old_coldkey_file_data) + assert bittensor.keyfile_data_is_encrypted_nacl(new_keyfile_data) + assert not bittensor.keyfile_data_is_encrypted_nacl(old_coldkey_file_data) + assert not bittensor.keyfile_data_is_encrypted_ansible(new_keyfile_data) + assert old_decrypted_coldkey_file_data == new_decrypted_keyfile_data + assert new_path == old_coldkey_path + + def check_new_hotkey_file(keyfile): + new_keyfile_data = keyfile._read_keyfile_data_from_file() + new_path = legacy_wallet.hotkey_file.path + + assert old_hotkey_file_data == new_keyfile_data + assert new_path == old_hotkey_path + assert not bittensor.keyfile_data_is_encrypted(new_keyfile_data) + + if legacy_wallet == None: + legacy_password = f"PASSword-{random.randint(0, 10000)}" + legacy_wallet = self.create_legacy_wallet(legacy_password=legacy_password) + + else: + legacy_password = self.default_legacy_password + + # get old cold keyfile data + old_coldkey_file_data = ( + legacy_wallet.coldkey_file._read_keyfile_data_from_file() + ) + old_decrypted_coldkey_file_data = bittensor.decrypt_keyfile_data( + old_coldkey_file_data, legacy_password + ) + old_coldkey_path = legacy_wallet.coldkey_file.path + + # get old hot keyfile data + old_hotkey_file_data = legacy_wallet.hotkey_file._read_keyfile_data_from_file() + old_hotkey_path = legacy_wallet.hotkey_file.path + + # update legacy_wallet from ansible to nacl + with patch("getpass.getpass", return_value=legacy_password), patch.object( + Confirm, "ask", return_value=True + ): + legacy_wallet.coldkey_file.check_and_update_encryption() + + # get new keyfile data from the same legacy wallet + check_new_coldkey_file(legacy_wallet.coldkey_file) + check_new_hotkey_file(legacy_wallet.hotkey_file) + + # get new keyfile data from wallet name + updated_legacy_wallet = bittensor.wallet( + name=legacy_wallet.name, hotkey=legacy_wallet.hotkey_str + ) + check_new_coldkey_file(updated_legacy_wallet.coldkey_file) + check_new_hotkey_file(updated_legacy_wallet.hotkey_file) + + # def test_password_retain(self): + # [tick] test the same password works + # [tick] try to read using the same hotkey/coldkey name + # [tick] test the same keyfile data could be retained + # [tick] test what if a wrong password was inserted + # [no need] try to read from the new file path + # [tick] test the old and new encrypted is not the same + # [tick] test that the hotkeys are not affected + + class TestWallet(unittest.TestCase): def setUp(self): self.mock_wallet = bittensor.wallet(