From c3c10f1615631c3cca4d3c6be88c9c875e961d37 Mon Sep 17 00:00:00 2001 From: Matt Whitlock Date: Tue, 14 Nov 2023 22:08:29 -0500 Subject: [PATCH] wallet: implement optional cache validation Add a validate_cache parameter to the five principal caching methods: - _get_key_from_path - _get_keypair_from_path - _get_pubkey_from_path - get_script_from_path - get_address_from_path and to the five convenience methods that wrap the above: - get_script - get_addr - script_to_addr - get_new_script - get_new_addr The value of this new parameter defaults to False in all but the last two methods, where we are willing to sacrifice speed for the sake of extra confidence in the correctness of *new* scripts and addresses to be used for new deposits and new transactions. --- src/jmclient/wallet.py | 215 ++++++++++++++++++++++++------------ test/jmclient/test_taker.py | 9 +- 2 files changed, 149 insertions(+), 75 deletions(-) diff --git a/src/jmclient/wallet.py b/src/jmclient/wallet.py index a580f96..b417f60 100644 --- a/src/jmclient/wallet.py +++ b/src/jmclient/wallet.py @@ -589,9 +589,11 @@ class BaseWallet(object): """ return cls._ENGINE.pubkey_to_address(pubkey) - def script_to_addr(self, script): + def script_to_addr(self, script, + validate_cache: bool = False): path = self.script_to_path(script) - return self.get_address_from_path(path) + return self.get_address_from_path(path, + validate_cache=validate_cache) def get_script_code(self, script): """ @@ -617,30 +619,42 @@ class BaseWallet(object): def get_key(self, mixdepth, address_type, index): raise NotImplementedError() - def get_addr(self, mixdepth, address_type, index): + def get_addr(self, mixdepth, address_type, index, + validate_cache: bool = False): path = self.get_path(mixdepth, address_type, index) - return self.get_address_from_path(path) + return self.get_address_from_path(path, + validate_cache=validate_cache) - def get_address_from_path(self, path): + def get_address_from_path(self, path, + validate_cache: bool = False): cache = self._get_cache_for_path(path) addr = cache.get(b'A') - if addr is None: - engine = self._get_pubkey_from_path(path)[1] - script = self.get_script_from_path(path) - addr = engine.script_to_address(script) - cache[b'A'] = addr.encode('ascii') - else: + if addr is not None: addr = addr.decode('ascii') + if addr is None or validate_cache: + engine = self._get_pubkey_from_path(path)[1] + script = self.get_script_from_path(path, + validate_cache=validate_cache) + new_addr = engine.script_to_address(script) + if addr is None: + addr = new_addr + cache[b'A'] = addr.encode('ascii') + elif addr != new_addr: + raise WalletError("Wallet cache validation failed") return addr - def get_new_addr(self, mixdepth, address_type): + def get_new_addr(self, mixdepth, address_type, + validate_cache: bool = True): """ use get_external_addr/get_internal_addr """ - script = self.get_new_script(mixdepth, address_type) - return self.script_to_addr(script) + script = self.get_new_script(mixdepth, address_type, + validate_cache=validate_cache) + return self.script_to_addr(script, + validate_cache=validate_cache) - def get_new_script(self, mixdepth, address_type): + def get_new_script(self, mixdepth, address_type, + validate_cache: bool = True): raise NotImplementedError() def get_wif(self, mixdepth, address_type, index): @@ -938,7 +952,8 @@ class BaseWallet(object): def _get_mixdepth_from_path(self, path): raise NotImplementedError() - def get_script_from_path(self, path): + def get_script_from_path(self, path, + validate_cache: bool = False): """ internal note: This is the final sink for all operations that somehow need to derive a script. If anything goes wrong when deriving a @@ -951,30 +966,43 @@ class BaseWallet(object): """ cache = self._get_cache_for_path(path) script = cache.get(b'S') - if script is None: - pubkey, engine = self._get_pubkey_from_path(path) - script = engine.pubkey_to_script(pubkey) - cache[b'S'] = script + if script is None or validate_cache: + pubkey, engine = self._get_pubkey_from_path(path, + validate_cache=validate_cache) + new_script = engine.pubkey_to_script(pubkey) + if script is None: + cache[b'S'] = script = new_script + elif script != new_script: + raise WalletError("Wallet cache validation failed") return script - def get_script(self, mixdepth, address_type, index): + def get_script(self, mixdepth, address_type, index, + validate_cache: bool = False): path = self.get_path(mixdepth, address_type, index) - return self.get_script_from_path(path) + return self.get_script_from_path(path, validate_cache=validate_cache) - def _get_key_from_path(self, path): + def _get_key_from_path(self, path, + validate_cache: bool = False): raise NotImplementedError() - def _get_keypair_from_path(self, path): - privkey, engine = self._get_key_from_path(path) + def _get_keypair_from_path(self, path, + validate_cache: bool = False): + privkey, engine = self._get_key_from_path(path, + validate_cache=validate_cache) cache = self._get_cache_for_path(path) pubkey = cache.get(b'P') - if pubkey is None: - pubkey = engine.privkey_to_pubkey(privkey) - cache[b'P'] = pubkey + if pubkey is None or validate_cache: + new_pubkey = engine.privkey_to_pubkey(privkey) + if pubkey is None: + cache[b'P'] = pubkey = new_pubkey + elif pubkey != new_pubkey: + raise WalletError("Wallet cache validation failed") return privkey, pubkey, engine - def _get_pubkey_from_path(self, path): - privkey, pubkey, engine = self._get_keypair_from_path(path) + def _get_pubkey_from_path(self, path, + validate_cache: bool = False): + privkey, pubkey, engine = self._get_keypair_from_path(path, + validate_cache=validate_cache) return pubkey, engine def _get_cache_keys_for_path(self, path): @@ -1952,9 +1980,11 @@ class ImportWalletMixin(object): assert len(path) == 3 return path[1] - def _get_key_from_path(self, path): + def _get_key_from_path(self, path, + validate_cache: bool = False): if not self._is_imported_path(path): - return super()._get_key_from_path(path) + return super()._get_key_from_path(path, + validate_cache=validate_cache) assert len(path) == 3 md, i = path[1], path[2] @@ -2201,15 +2231,19 @@ class BIP32Wallet(BaseWallet): self._set_index_cache(md, address_type, current_index + 1) self._populate_maps((path,)) - def get_script_from_path(self, path): + def get_script_from_path(self, path, + validate_cache: bool = False): if self._is_my_bip32_path(path): self._check_path(path) - return super().get_script_from_path(path) + return super().get_script_from_path(path, + validate_cache=validate_cache) - def get_address_from_path(self, path): + def get_address_from_path(self, path, + validate_cache: bool = False): if self._is_my_bip32_path(path): self._check_path(path) - return super().get_address_from_path(path) + return super().get_address_from_path(path, + validate_cache=validate_cache) def get_path(self, mixdepth=None, address_type=None, index=None): if mixdepth is not None: @@ -2273,28 +2307,40 @@ class BIP32Wallet(BaseWallet): return path[len(self._get_bip32_base_path())] - def _get_key_from_path(self, path): + def _get_key_from_path(self, path, + validate_cache: bool = False): if not self._is_my_bip32_path(path): raise WalletError("Invalid path, unknown root: {}".format(path)) cache = self._get_cache_for_path(path) privkey = cache.get(b'p') - if privkey is None: - privkey = self._ENGINE.derive_bip32_privkey(self._master_key, path) - cache[b'p'] = privkey + if privkey is None or validate_cache: + new_privkey = self._ENGINE.derive_bip32_privkey(self._master_key, path) + if privkey is None: + cache[b'p'] = privkey = new_privkey + elif privkey != new_privkey: + raise WalletError("Wallet cache validation failed") return privkey, self._ENGINE - def _get_keypair_from_path(self, path): + def _get_keypair_from_path(self, path, + validate_cache: bool = False): if not self._is_my_bip32_path(path): - return super()._get_keypair_from_path(path) + return super()._get_keypair_from_path(path, + validate_cache=validate_cache) cache = self._get_cache_for_path(path) privkey = cache.get(b'p') - if privkey is None: - privkey = self._ENGINE.derive_bip32_privkey(self._master_key, path) - cache[b'p'] = privkey + if privkey is None or validate_cache: + new_privkey = self._ENGINE.derive_bip32_privkey(self._master_key, path) + if privkey is None: + cache[b'p'] = privkey = new_privkey + elif privkey != new_privkey: + raise WalletError("Wallet cache validation failed") pubkey = cache.get(b'P') - if pubkey is None: - pubkey = self._ENGINE.privkey_to_pubkey(privkey) - cache[b'P'] = pubkey + if pubkey is None or validate_cache: + new_pubkey = self._ENGINE.privkey_to_pubkey(privkey) + if pubkey is None: + cache[b'P'] = pubkey = new_pubkey + elif pubkey != new_pubkey: + raise WalletError("Wallet cache validation failed") return privkey, pubkey, self._ENGINE def _get_cache_keys_for_path(self, path): @@ -2309,12 +2355,14 @@ class BIP32Wallet(BaseWallet): def is_standard_wallet_script(self, path): return self._is_my_bip32_path(path) - def get_new_script(self, mixdepth, address_type): + def get_new_script(self, mixdepth, address_type, + validate_cache: bool = True): if self.disable_new_scripts: raise RuntimeError("Obtaining new wallet addresses " + "disabled, due to nohistory mode") index = self._index_cache[mixdepth][address_type] - return self.get_script(mixdepth, address_type, index) + return self.get_script(mixdepth, address_type, index, + validate_cache=validate_cache) def _set_index_cache(self, mixdepth, address_type, index): """ Ensures that any update to index_cache dict only applies @@ -2560,35 +2608,47 @@ class FidelityBondMixin(object): def _get_supported_address_types(cls): return (cls.BIP32_EXT_ID, cls.BIP32_INT_ID, cls.BIP32_TIMELOCK_ID, cls.BIP32_BURN_ID) - def _get_key_from_path(self, path): + def _get_key_from_path(self, path, + validate_cache: bool = False): if self.is_timelocked_path(path): key_path = path[:-1] locktime = path[-1] engine = self._TIMELOCK_ENGINE cache = super()._get_cache_for_path(key_path) privkey = cache.get(b'p') - if privkey is None: - privkey = engine.derive_bip32_privkey(self._master_key, key_path) - cache[b'p'] = privkey + if privkey is None or validate_cache: + new_privkey = engine.derive_bip32_privkey(self._master_key, key_path) + if privkey is None: + cache[b'p'] = privkey = new_privkey + elif privkey != new_privkey: + raise WalletError("Wallet cache validation failed") return (privkey, locktime), engine else: return super()._get_key_from_path(path) - def _get_keypair_from_path(self, path): + def _get_keypair_from_path(self, path, + validate_cache: bool = False): if not self.is_timelocked_path(path): - return super()._get_keypair_from_path(path) + return super()._get_keypair_from_path(path, + validate_cache=validate_cache) key_path = path[:-1] locktime = path[-1] engine = self._TIMELOCK_ENGINE cache = super()._get_cache_for_path(key_path) privkey = cache.get(b'p') - if privkey is None: - privkey = engine.derive_bip32_privkey(self._master_key, key_path) - cache[b'p'] = privkey + if privkey is None or validate_cache: + new_privkey = engine.derive_bip32_privkey(self._master_key, key_path) + if privkey is None: + cache[b'p'] = privkey = new_privkey + elif privkey != new_privkey: + raise WalletError("Wallet cache validation failed") pubkey = cache.get(b'P') - if pubkey is None: - pubkey = engine.privkey_to_pubkey(privkey) - cache[b'P'] = pubkey + if pubkey is None or validate_cache: + new_pubkey = engine.privkey_to_pubkey(privkey) + if pubkey is None: + cache[b'P'] = pubkey = new_pubkey + elif pubkey != new_pubkey: + raise WalletError("Wallet cache validation failed") return (privkey, locktime), (pubkey, locktime), engine def _get_cache_for_path(self, path): @@ -2740,30 +2800,41 @@ class FidelityBondWatchonlyWallet(FidelityBondMixin, BIP84Wallet): path = super()._get_bip32_export_path(mixdepth, address_type) return path - def _get_key_from_path(self, path): + def _get_key_from_path(self, path, + validate_cache: bool = False): raise WalletError("Cannot get a private key from a watch-only wallet") - def _get_keypair_from_path(self, path): + def _get_keypair_from_path(self, path, + validate_cache: bool = False): raise WalletError("Cannot get a private key from a watch-only wallet") - def _get_pubkey_from_path(self, path): + def _get_pubkey_from_path(self, path, + validate_cache: bool = False): if not self._is_my_bip32_path(path): - return super()._get_pubkey_from_path(path) + return super()._get_pubkey_from_path(path, + validate_cache=validate_cache) if self.is_timelocked_path(path): key_path = path[:-1] locktime = path[-1] cache = self._get_cache_for_path(key_path) pubkey = cache.get(b'P') - if pubkey is None: - pubkey = self._TIMELOCK_ENGINE.derive_bip32_privkey( + if pubkey is None or validate_cache: + new_pubkey = self._TIMELOCK_ENGINE.derive_bip32_privkey( self._master_key, key_path) - cache[b'P'] = pubkey + if pubkey is None: + cache[b'P'] = pubkey = new_pubkey + elif pubkey != new_pubkey: + raise WalletError("Wallet cache validation failed") return (pubkey, locktime), self._TIMELOCK_ENGINE cache = self._get_cache_for_path(path) pubkey = cache.get(b'P') - if pubkey is None: - pubkey = self._ENGINE.derive_bip32_privkey(self._master_key, path) - cache[b'P'] = pubkey + if pubkey is None or validate_cache: + new_pubkey = self._ENGINE.derive_bip32_privkey( + self._master_key, path) + if pubkey is None: + cache[b'P'] = pubkey = new_pubkey + elif pubkey != new_pubkey: + raise WalletError("Wallet cache validation failed") return pubkey, self._ENGINE diff --git a/test/jmclient/test_taker.py b/test/jmclient/test_taker.py index da902f3..7067382 100644 --- a/test/jmclient/test_taker.py +++ b/test/jmclient/test_taker.py @@ -121,7 +121,8 @@ class DummyWallet(LegacyWallet): """ return 'p2wpkh' - def _get_key_from_path(self, path): + def _get_key_from_path(self, path, + validate_cache: bool = False): if path[0] == b'dummy': return struct.pack(b'B', path[2] + 1)*32 + b'\x01', self._ENGINE raise NotImplementedError() @@ -152,10 +153,12 @@ class DummyWallet(LegacyWallet): return False return True - def script_to_addr(self, script): + def script_to_addr(self, script, + validate_cache: bool = False): if self.script_to_path(script)[0] == "nonstandard_path": return "dummyaddr" - return super().script_to_addr(script) + return super().script_to_addr(script, + validate_cache=validate_cache) def dummy_order_chooser():