diff --git a/src/jmclient/wallet.py b/src/jmclient/wallet.py index a580f96..b417f60 100644 --- a/src/jmclient/wallet.py +++ b/src/jmclient/wallet.py @@ -589,9 +589,11 @@ class BaseWallet(object): """ return cls._ENGINE.pubkey_to_address(pubkey) - def script_to_addr(self, script): + def script_to_addr(self, script, + validate_cache: bool = False): path = self.script_to_path(script) - return self.get_address_from_path(path) + return self.get_address_from_path(path, + validate_cache=validate_cache) def get_script_code(self, script): """ @@ -617,30 +619,42 @@ class BaseWallet(object): def get_key(self, mixdepth, address_type, index): raise NotImplementedError() - def get_addr(self, mixdepth, address_type, index): + def get_addr(self, mixdepth, address_type, index, + validate_cache: bool = False): path = self.get_path(mixdepth, address_type, index) - return self.get_address_from_path(path) + return self.get_address_from_path(path, + validate_cache=validate_cache) - def get_address_from_path(self, path): + def get_address_from_path(self, path, + validate_cache: bool = False): cache = self._get_cache_for_path(path) addr = cache.get(b'A') - if addr is None: - engine = self._get_pubkey_from_path(path)[1] - script = self.get_script_from_path(path) - addr = engine.script_to_address(script) - cache[b'A'] = addr.encode('ascii') - else: + if addr is not None: addr = addr.decode('ascii') + if addr is None or validate_cache: + engine = self._get_pubkey_from_path(path)[1] + script = self.get_script_from_path(path, + validate_cache=validate_cache) + new_addr = engine.script_to_address(script) + if addr is None: + addr = new_addr + cache[b'A'] = addr.encode('ascii') + elif addr != new_addr: + raise WalletError("Wallet cache validation failed") return addr - def get_new_addr(self, mixdepth, address_type): + def get_new_addr(self, mixdepth, address_type, + validate_cache: bool = True): """ use get_external_addr/get_internal_addr """ - script = self.get_new_script(mixdepth, address_type) - return self.script_to_addr(script) + script = self.get_new_script(mixdepth, address_type, + validate_cache=validate_cache) + return self.script_to_addr(script, + validate_cache=validate_cache) - def get_new_script(self, mixdepth, address_type): + def get_new_script(self, mixdepth, address_type, + validate_cache: bool = True): raise NotImplementedError() def get_wif(self, mixdepth, address_type, index): @@ -938,7 +952,8 @@ class BaseWallet(object): def _get_mixdepth_from_path(self, path): raise NotImplementedError() - def get_script_from_path(self, path): + def get_script_from_path(self, path, + validate_cache: bool = False): """ internal note: This is the final sink for all operations that somehow need to derive a script. If anything goes wrong when deriving a @@ -951,30 +966,43 @@ class BaseWallet(object): """ cache = self._get_cache_for_path(path) script = cache.get(b'S') - if script is None: - pubkey, engine = self._get_pubkey_from_path(path) - script = engine.pubkey_to_script(pubkey) - cache[b'S'] = script + if script is None or validate_cache: + pubkey, engine = self._get_pubkey_from_path(path, + validate_cache=validate_cache) + new_script = engine.pubkey_to_script(pubkey) + if script is None: + cache[b'S'] = script = new_script + elif script != new_script: + raise WalletError("Wallet cache validation failed") return script - def get_script(self, mixdepth, address_type, index): + def get_script(self, mixdepth, address_type, index, + validate_cache: bool = False): path = self.get_path(mixdepth, address_type, index) - return self.get_script_from_path(path) + return self.get_script_from_path(path, validate_cache=validate_cache) - def _get_key_from_path(self, path): + def _get_key_from_path(self, path, + validate_cache: bool = False): raise NotImplementedError() - def _get_keypair_from_path(self, path): - privkey, engine = self._get_key_from_path(path) + def _get_keypair_from_path(self, path, + validate_cache: bool = False): + privkey, engine = self._get_key_from_path(path, + validate_cache=validate_cache) cache = self._get_cache_for_path(path) pubkey = cache.get(b'P') - if pubkey is None: - pubkey = engine.privkey_to_pubkey(privkey) - cache[b'P'] = pubkey + if pubkey is None or validate_cache: + new_pubkey = engine.privkey_to_pubkey(privkey) + if pubkey is None: + cache[b'P'] = pubkey = new_pubkey + elif pubkey != new_pubkey: + raise WalletError("Wallet cache validation failed") return privkey, pubkey, engine - def _get_pubkey_from_path(self, path): - privkey, pubkey, engine = self._get_keypair_from_path(path) + def _get_pubkey_from_path(self, path, + validate_cache: bool = False): + privkey, pubkey, engine = self._get_keypair_from_path(path, + validate_cache=validate_cache) return pubkey, engine def _get_cache_keys_for_path(self, path): @@ -1952,9 +1980,11 @@ class ImportWalletMixin(object): assert len(path) == 3 return path[1] - def _get_key_from_path(self, path): + def _get_key_from_path(self, path, + validate_cache: bool = False): if not self._is_imported_path(path): - return super()._get_key_from_path(path) + return super()._get_key_from_path(path, + validate_cache=validate_cache) assert len(path) == 3 md, i = path[1], path[2] @@ -2201,15 +2231,19 @@ class BIP32Wallet(BaseWallet): self._set_index_cache(md, address_type, current_index + 1) self._populate_maps((path,)) - def get_script_from_path(self, path): + def get_script_from_path(self, path, + validate_cache: bool = False): if self._is_my_bip32_path(path): self._check_path(path) - return super().get_script_from_path(path) + return super().get_script_from_path(path, + validate_cache=validate_cache) - def get_address_from_path(self, path): + def get_address_from_path(self, path, + validate_cache: bool = False): if self._is_my_bip32_path(path): self._check_path(path) - return super().get_address_from_path(path) + return super().get_address_from_path(path, + validate_cache=validate_cache) def get_path(self, mixdepth=None, address_type=None, index=None): if mixdepth is not None: @@ -2273,28 +2307,40 @@ class BIP32Wallet(BaseWallet): return path[len(self._get_bip32_base_path())] - def _get_key_from_path(self, path): + def _get_key_from_path(self, path, + validate_cache: bool = False): if not self._is_my_bip32_path(path): raise WalletError("Invalid path, unknown root: {}".format(path)) cache = self._get_cache_for_path(path) privkey = cache.get(b'p') - if privkey is None: - privkey = self._ENGINE.derive_bip32_privkey(self._master_key, path) - cache[b'p'] = privkey + if privkey is None or validate_cache: + new_privkey = self._ENGINE.derive_bip32_privkey(self._master_key, path) + if privkey is None: + cache[b'p'] = privkey = new_privkey + elif privkey != new_privkey: + raise WalletError("Wallet cache validation failed") return privkey, self._ENGINE - def _get_keypair_from_path(self, path): + def _get_keypair_from_path(self, path, + validate_cache: bool = False): if not self._is_my_bip32_path(path): - return super()._get_keypair_from_path(path) + return super()._get_keypair_from_path(path, + validate_cache=validate_cache) cache = self._get_cache_for_path(path) privkey = cache.get(b'p') - if privkey is None: - privkey = self._ENGINE.derive_bip32_privkey(self._master_key, path) - cache[b'p'] = privkey + if privkey is None or validate_cache: + new_privkey = self._ENGINE.derive_bip32_privkey(self._master_key, path) + if privkey is None: + cache[b'p'] = privkey = new_privkey + elif privkey != new_privkey: + raise WalletError("Wallet cache validation failed") pubkey = cache.get(b'P') - if pubkey is None: - pubkey = self._ENGINE.privkey_to_pubkey(privkey) - cache[b'P'] = pubkey + if pubkey is None or validate_cache: + new_pubkey = self._ENGINE.privkey_to_pubkey(privkey) + if pubkey is None: + cache[b'P'] = pubkey = new_pubkey + elif pubkey != new_pubkey: + raise WalletError("Wallet cache validation failed") return privkey, pubkey, self._ENGINE def _get_cache_keys_for_path(self, path): @@ -2309,12 +2355,14 @@ class BIP32Wallet(BaseWallet): def is_standard_wallet_script(self, path): return self._is_my_bip32_path(path) - def get_new_script(self, mixdepth, address_type): + def get_new_script(self, mixdepth, address_type, + validate_cache: bool = True): if self.disable_new_scripts: raise RuntimeError("Obtaining new wallet addresses " + "disabled, due to nohistory mode") index = self._index_cache[mixdepth][address_type] - return self.get_script(mixdepth, address_type, index) + return self.get_script(mixdepth, address_type, index, + validate_cache=validate_cache) def _set_index_cache(self, mixdepth, address_type, index): """ Ensures that any update to index_cache dict only applies @@ -2560,35 +2608,47 @@ class FidelityBondMixin(object): def _get_supported_address_types(cls): return (cls.BIP32_EXT_ID, cls.BIP32_INT_ID, cls.BIP32_TIMELOCK_ID, cls.BIP32_BURN_ID) - def _get_key_from_path(self, path): + def _get_key_from_path(self, path, + validate_cache: bool = False): if self.is_timelocked_path(path): key_path = path[:-1] locktime = path[-1] engine = self._TIMELOCK_ENGINE cache = super()._get_cache_for_path(key_path) privkey = cache.get(b'p') - if privkey is None: - privkey = engine.derive_bip32_privkey(self._master_key, key_path) - cache[b'p'] = privkey + if privkey is None or validate_cache: + new_privkey = engine.derive_bip32_privkey(self._master_key, key_path) + if privkey is None: + cache[b'p'] = privkey = new_privkey + elif privkey != new_privkey: + raise WalletError("Wallet cache validation failed") return (privkey, locktime), engine else: return super()._get_key_from_path(path) - def _get_keypair_from_path(self, path): + def _get_keypair_from_path(self, path, + validate_cache: bool = False): if not self.is_timelocked_path(path): - return super()._get_keypair_from_path(path) + return super()._get_keypair_from_path(path, + validate_cache=validate_cache) key_path = path[:-1] locktime = path[-1] engine = self._TIMELOCK_ENGINE cache = super()._get_cache_for_path(key_path) privkey = cache.get(b'p') - if privkey is None: - privkey = engine.derive_bip32_privkey(self._master_key, key_path) - cache[b'p'] = privkey + if privkey is None or validate_cache: + new_privkey = engine.derive_bip32_privkey(self._master_key, key_path) + if privkey is None: + cache[b'p'] = privkey = new_privkey + elif privkey != new_privkey: + raise WalletError("Wallet cache validation failed") pubkey = cache.get(b'P') - if pubkey is None: - pubkey = engine.privkey_to_pubkey(privkey) - cache[b'P'] = pubkey + if pubkey is None or validate_cache: + new_pubkey = engine.privkey_to_pubkey(privkey) + if pubkey is None: + cache[b'P'] = pubkey = new_pubkey + elif pubkey != new_pubkey: + raise WalletError("Wallet cache validation failed") return (privkey, locktime), (pubkey, locktime), engine def _get_cache_for_path(self, path): @@ -2740,30 +2800,41 @@ class FidelityBondWatchonlyWallet(FidelityBondMixin, BIP84Wallet): path = super()._get_bip32_export_path(mixdepth, address_type) return path - def _get_key_from_path(self, path): + def _get_key_from_path(self, path, + validate_cache: bool = False): raise WalletError("Cannot get a private key from a watch-only wallet") - def _get_keypair_from_path(self, path): + def _get_keypair_from_path(self, path, + validate_cache: bool = False): raise WalletError("Cannot get a private key from a watch-only wallet") - def _get_pubkey_from_path(self, path): + def _get_pubkey_from_path(self, path, + validate_cache: bool = False): if not self._is_my_bip32_path(path): - return super()._get_pubkey_from_path(path) + return super()._get_pubkey_from_path(path, + validate_cache=validate_cache) if self.is_timelocked_path(path): key_path = path[:-1] locktime = path[-1] cache = self._get_cache_for_path(key_path) pubkey = cache.get(b'P') - if pubkey is None: - pubkey = self._TIMELOCK_ENGINE.derive_bip32_privkey( + if pubkey is None or validate_cache: + new_pubkey = self._TIMELOCK_ENGINE.derive_bip32_privkey( self._master_key, key_path) - cache[b'P'] = pubkey + if pubkey is None: + cache[b'P'] = pubkey = new_pubkey + elif pubkey != new_pubkey: + raise WalletError("Wallet cache validation failed") return (pubkey, locktime), self._TIMELOCK_ENGINE cache = self._get_cache_for_path(path) pubkey = cache.get(b'P') - if pubkey is None: - pubkey = self._ENGINE.derive_bip32_privkey(self._master_key, path) - cache[b'P'] = pubkey + if pubkey is None or validate_cache: + new_pubkey = self._ENGINE.derive_bip32_privkey( + self._master_key, path) + if pubkey is None: + cache[b'P'] = pubkey = new_pubkey + elif pubkey != new_pubkey: + raise WalletError("Wallet cache validation failed") return pubkey, self._ENGINE diff --git a/test/jmclient/test_taker.py b/test/jmclient/test_taker.py index da902f3..7067382 100644 --- a/test/jmclient/test_taker.py +++ b/test/jmclient/test_taker.py @@ -121,7 +121,8 @@ class DummyWallet(LegacyWallet): """ return 'p2wpkh' - def _get_key_from_path(self, path): + def _get_key_from_path(self, path, + validate_cache: bool = False): if path[0] == b'dummy': return struct.pack(b'B', path[2] + 1)*32 + b'\x01', self._ENGINE raise NotImplementedError() @@ -152,10 +153,12 @@ class DummyWallet(LegacyWallet): return False return True - def script_to_addr(self, script): + def script_to_addr(self, script, + validate_cache: bool = False): if self.script_to_path(script)[0] == "nonstandard_path": return "dummyaddr" - return super().script_to_addr(script) + return super().script_to_addr(script, + validate_cache=validate_cache) def dummy_order_chooser():