Browse Source

wallet: implement optional cache validation

Add a validate_cache parameter to the five principal caching methods:

- _get_key_from_path
- _get_keypair_from_path
- _get_pubkey_from_path
- get_script_from_path
- get_address_from_path

and to the five convenience methods that wrap the above:

- get_script
- get_addr
- script_to_addr
- get_new_script
- get_new_addr

The value of this new parameter defaults to False in all but the last
two methods, where we are willing to sacrifice speed for the sake of
extra confidence in the correctness of *new* scripts and addresses to
be used for new deposits and new transactions.
master
Matt Whitlock 2 years ago
parent
commit
c3c10f1615
  1. 215
      src/jmclient/wallet.py
  2. 9
      test/jmclient/test_taker.py

215
src/jmclient/wallet.py

@ -589,9 +589,11 @@ class BaseWallet(object):
"""
return cls._ENGINE.pubkey_to_address(pubkey)
def script_to_addr(self, script):
def script_to_addr(self, script,
validate_cache: bool = False):
path = self.script_to_path(script)
return self.get_address_from_path(path)
return self.get_address_from_path(path,
validate_cache=validate_cache)
def get_script_code(self, script):
"""
@ -617,30 +619,42 @@ class BaseWallet(object):
def get_key(self, mixdepth, address_type, index):
raise NotImplementedError()
def get_addr(self, mixdepth, address_type, index):
def get_addr(self, mixdepth, address_type, index,
validate_cache: bool = False):
path = self.get_path(mixdepth, address_type, index)
return self.get_address_from_path(path)
return self.get_address_from_path(path,
validate_cache=validate_cache)
def get_address_from_path(self, path):
def get_address_from_path(self, path,
validate_cache: bool = False):
cache = self._get_cache_for_path(path)
addr = cache.get(b'A')
if addr is None:
engine = self._get_pubkey_from_path(path)[1]
script = self.get_script_from_path(path)
addr = engine.script_to_address(script)
cache[b'A'] = addr.encode('ascii')
else:
if addr is not None:
addr = addr.decode('ascii')
if addr is None or validate_cache:
engine = self._get_pubkey_from_path(path)[1]
script = self.get_script_from_path(path,
validate_cache=validate_cache)
new_addr = engine.script_to_address(script)
if addr is None:
addr = new_addr
cache[b'A'] = addr.encode('ascii')
elif addr != new_addr:
raise WalletError("Wallet cache validation failed")
return addr
def get_new_addr(self, mixdepth, address_type):
def get_new_addr(self, mixdepth, address_type,
validate_cache: bool = True):
"""
use get_external_addr/get_internal_addr
"""
script = self.get_new_script(mixdepth, address_type)
return self.script_to_addr(script)
script = self.get_new_script(mixdepth, address_type,
validate_cache=validate_cache)
return self.script_to_addr(script,
validate_cache=validate_cache)
def get_new_script(self, mixdepth, address_type):
def get_new_script(self, mixdepth, address_type,
validate_cache: bool = True):
raise NotImplementedError()
def get_wif(self, mixdepth, address_type, index):
@ -938,7 +952,8 @@ class BaseWallet(object):
def _get_mixdepth_from_path(self, path):
raise NotImplementedError()
def get_script_from_path(self, path):
def get_script_from_path(self, path,
validate_cache: bool = False):
"""
internal note: This is the final sink for all operations that somehow
need to derive a script. If anything goes wrong when deriving a
@ -951,30 +966,43 @@ class BaseWallet(object):
"""
cache = self._get_cache_for_path(path)
script = cache.get(b'S')
if script is None:
pubkey, engine = self._get_pubkey_from_path(path)
script = engine.pubkey_to_script(pubkey)
cache[b'S'] = script
if script is None or validate_cache:
pubkey, engine = self._get_pubkey_from_path(path,
validate_cache=validate_cache)
new_script = engine.pubkey_to_script(pubkey)
if script is None:
cache[b'S'] = script = new_script
elif script != new_script:
raise WalletError("Wallet cache validation failed")
return script
def get_script(self, mixdepth, address_type, index):
def get_script(self, mixdepth, address_type, index,
validate_cache: bool = False):
path = self.get_path(mixdepth, address_type, index)
return self.get_script_from_path(path)
return self.get_script_from_path(path, validate_cache=validate_cache)
def _get_key_from_path(self, path):
def _get_key_from_path(self, path,
validate_cache: bool = False):
raise NotImplementedError()
def _get_keypair_from_path(self, path):
privkey, engine = self._get_key_from_path(path)
def _get_keypair_from_path(self, path,
validate_cache: bool = False):
privkey, engine = self._get_key_from_path(path,
validate_cache=validate_cache)
cache = self._get_cache_for_path(path)
pubkey = cache.get(b'P')
if pubkey is None:
pubkey = engine.privkey_to_pubkey(privkey)
cache[b'P'] = pubkey
if pubkey is None or validate_cache:
new_pubkey = engine.privkey_to_pubkey(privkey)
if pubkey is None:
cache[b'P'] = pubkey = new_pubkey
elif pubkey != new_pubkey:
raise WalletError("Wallet cache validation failed")
return privkey, pubkey, engine
def _get_pubkey_from_path(self, path):
privkey, pubkey, engine = self._get_keypair_from_path(path)
def _get_pubkey_from_path(self, path,
validate_cache: bool = False):
privkey, pubkey, engine = self._get_keypair_from_path(path,
validate_cache=validate_cache)
return pubkey, engine
def _get_cache_keys_for_path(self, path):
@ -1952,9 +1980,11 @@ class ImportWalletMixin(object):
assert len(path) == 3
return path[1]
def _get_key_from_path(self, path):
def _get_key_from_path(self, path,
validate_cache: bool = False):
if not self._is_imported_path(path):
return super()._get_key_from_path(path)
return super()._get_key_from_path(path,
validate_cache=validate_cache)
assert len(path) == 3
md, i = path[1], path[2]
@ -2201,15 +2231,19 @@ class BIP32Wallet(BaseWallet):
self._set_index_cache(md, address_type, current_index + 1)
self._populate_maps((path,))
def get_script_from_path(self, path):
def get_script_from_path(self, path,
validate_cache: bool = False):
if self._is_my_bip32_path(path):
self._check_path(path)
return super().get_script_from_path(path)
return super().get_script_from_path(path,
validate_cache=validate_cache)
def get_address_from_path(self, path):
def get_address_from_path(self, path,
validate_cache: bool = False):
if self._is_my_bip32_path(path):
self._check_path(path)
return super().get_address_from_path(path)
return super().get_address_from_path(path,
validate_cache=validate_cache)
def get_path(self, mixdepth=None, address_type=None, index=None):
if mixdepth is not None:
@ -2273,28 +2307,40 @@ class BIP32Wallet(BaseWallet):
return path[len(self._get_bip32_base_path())]
def _get_key_from_path(self, path):
def _get_key_from_path(self, path,
validate_cache: bool = False):
if not self._is_my_bip32_path(path):
raise WalletError("Invalid path, unknown root: {}".format(path))
cache = self._get_cache_for_path(path)
privkey = cache.get(b'p')
if privkey is None:
privkey = self._ENGINE.derive_bip32_privkey(self._master_key, path)
cache[b'p'] = privkey
if privkey is None or validate_cache:
new_privkey = self._ENGINE.derive_bip32_privkey(self._master_key, path)
if privkey is None:
cache[b'p'] = privkey = new_privkey
elif privkey != new_privkey:
raise WalletError("Wallet cache validation failed")
return privkey, self._ENGINE
def _get_keypair_from_path(self, path):
def _get_keypair_from_path(self, path,
validate_cache: bool = False):
if not self._is_my_bip32_path(path):
return super()._get_keypair_from_path(path)
return super()._get_keypair_from_path(path,
validate_cache=validate_cache)
cache = self._get_cache_for_path(path)
privkey = cache.get(b'p')
if privkey is None:
privkey = self._ENGINE.derive_bip32_privkey(self._master_key, path)
cache[b'p'] = privkey
if privkey is None or validate_cache:
new_privkey = self._ENGINE.derive_bip32_privkey(self._master_key, path)
if privkey is None:
cache[b'p'] = privkey = new_privkey
elif privkey != new_privkey:
raise WalletError("Wallet cache validation failed")
pubkey = cache.get(b'P')
if pubkey is None:
pubkey = self._ENGINE.privkey_to_pubkey(privkey)
cache[b'P'] = pubkey
if pubkey is None or validate_cache:
new_pubkey = self._ENGINE.privkey_to_pubkey(privkey)
if pubkey is None:
cache[b'P'] = pubkey = new_pubkey
elif pubkey != new_pubkey:
raise WalletError("Wallet cache validation failed")
return privkey, pubkey, self._ENGINE
def _get_cache_keys_for_path(self, path):
@ -2309,12 +2355,14 @@ class BIP32Wallet(BaseWallet):
def is_standard_wallet_script(self, path):
return self._is_my_bip32_path(path)
def get_new_script(self, mixdepth, address_type):
def get_new_script(self, mixdepth, address_type,
validate_cache: bool = True):
if self.disable_new_scripts:
raise RuntimeError("Obtaining new wallet addresses "
+ "disabled, due to nohistory mode")
index = self._index_cache[mixdepth][address_type]
return self.get_script(mixdepth, address_type, index)
return self.get_script(mixdepth, address_type, index,
validate_cache=validate_cache)
def _set_index_cache(self, mixdepth, address_type, index):
""" Ensures that any update to index_cache dict only applies
@ -2560,35 +2608,47 @@ class FidelityBondMixin(object):
def _get_supported_address_types(cls):
return (cls.BIP32_EXT_ID, cls.BIP32_INT_ID, cls.BIP32_TIMELOCK_ID, cls.BIP32_BURN_ID)
def _get_key_from_path(self, path):
def _get_key_from_path(self, path,
validate_cache: bool = False):
if self.is_timelocked_path(path):
key_path = path[:-1]
locktime = path[-1]
engine = self._TIMELOCK_ENGINE
cache = super()._get_cache_for_path(key_path)
privkey = cache.get(b'p')
if privkey is None:
privkey = engine.derive_bip32_privkey(self._master_key, key_path)
cache[b'p'] = privkey
if privkey is None or validate_cache:
new_privkey = engine.derive_bip32_privkey(self._master_key, key_path)
if privkey is None:
cache[b'p'] = privkey = new_privkey
elif privkey != new_privkey:
raise WalletError("Wallet cache validation failed")
return (privkey, locktime), engine
else:
return super()._get_key_from_path(path)
def _get_keypair_from_path(self, path):
def _get_keypair_from_path(self, path,
validate_cache: bool = False):
if not self.is_timelocked_path(path):
return super()._get_keypair_from_path(path)
return super()._get_keypair_from_path(path,
validate_cache=validate_cache)
key_path = path[:-1]
locktime = path[-1]
engine = self._TIMELOCK_ENGINE
cache = super()._get_cache_for_path(key_path)
privkey = cache.get(b'p')
if privkey is None:
privkey = engine.derive_bip32_privkey(self._master_key, key_path)
cache[b'p'] = privkey
if privkey is None or validate_cache:
new_privkey = engine.derive_bip32_privkey(self._master_key, key_path)
if privkey is None:
cache[b'p'] = privkey = new_privkey
elif privkey != new_privkey:
raise WalletError("Wallet cache validation failed")
pubkey = cache.get(b'P')
if pubkey is None:
pubkey = engine.privkey_to_pubkey(privkey)
cache[b'P'] = pubkey
if pubkey is None or validate_cache:
new_pubkey = engine.privkey_to_pubkey(privkey)
if pubkey is None:
cache[b'P'] = pubkey = new_pubkey
elif pubkey != new_pubkey:
raise WalletError("Wallet cache validation failed")
return (privkey, locktime), (pubkey, locktime), engine
def _get_cache_for_path(self, path):
@ -2740,30 +2800,41 @@ class FidelityBondWatchonlyWallet(FidelityBondMixin, BIP84Wallet):
path = super()._get_bip32_export_path(mixdepth, address_type)
return path
def _get_key_from_path(self, path):
def _get_key_from_path(self, path,
validate_cache: bool = False):
raise WalletError("Cannot get a private key from a watch-only wallet")
def _get_keypair_from_path(self, path):
def _get_keypair_from_path(self, path,
validate_cache: bool = False):
raise WalletError("Cannot get a private key from a watch-only wallet")
def _get_pubkey_from_path(self, path):
def _get_pubkey_from_path(self, path,
validate_cache: bool = False):
if not self._is_my_bip32_path(path):
return super()._get_pubkey_from_path(path)
return super()._get_pubkey_from_path(path,
validate_cache=validate_cache)
if self.is_timelocked_path(path):
key_path = path[:-1]
locktime = path[-1]
cache = self._get_cache_for_path(key_path)
pubkey = cache.get(b'P')
if pubkey is None:
pubkey = self._TIMELOCK_ENGINE.derive_bip32_privkey(
if pubkey is None or validate_cache:
new_pubkey = self._TIMELOCK_ENGINE.derive_bip32_privkey(
self._master_key, key_path)
cache[b'P'] = pubkey
if pubkey is None:
cache[b'P'] = pubkey = new_pubkey
elif pubkey != new_pubkey:
raise WalletError("Wallet cache validation failed")
return (pubkey, locktime), self._TIMELOCK_ENGINE
cache = self._get_cache_for_path(path)
pubkey = cache.get(b'P')
if pubkey is None:
pubkey = self._ENGINE.derive_bip32_privkey(self._master_key, path)
cache[b'P'] = pubkey
if pubkey is None or validate_cache:
new_pubkey = self._ENGINE.derive_bip32_privkey(
self._master_key, path)
if pubkey is None:
cache[b'P'] = pubkey = new_pubkey
elif pubkey != new_pubkey:
raise WalletError("Wallet cache validation failed")
return pubkey, self._ENGINE

9
test/jmclient/test_taker.py

@ -121,7 +121,8 @@ class DummyWallet(LegacyWallet):
"""
return 'p2wpkh'
def _get_key_from_path(self, path):
def _get_key_from_path(self, path,
validate_cache: bool = False):
if path[0] == b'dummy':
return struct.pack(b'B', path[2] + 1)*32 + b'\x01', self._ENGINE
raise NotImplementedError()
@ -152,10 +153,12 @@ class DummyWallet(LegacyWallet):
return False
return True
def script_to_addr(self, script):
def script_to_addr(self, script,
validate_cache: bool = False):
if self.script_to_path(script)[0] == "nonstandard_path":
return "dummyaddr"
return super().script_to_addr(script)
return super().script_to_addr(script,
validate_cache=validate_cache)
def dummy_order_chooser():

Loading…
Cancel
Save