From 5bc7eb4b8e3057ed3d60537c518f6f8918f4431d Mon Sep 17 00:00:00 2001 From: Matt Whitlock Date: Sat, 4 Nov 2023 18:56:46 -0400 Subject: [PATCH] wallet: add persistent cache, mapping path->(priv, pub, script, addr) Deriving private keys from BIP32 paths, public keys from private keys, scripts from public keys, and addresses from scripts are some of the most CPU-intensive tasks the wallet performs. Once the wallet inevitably accumulates thousands of used paths, startup times become painful due to needing to re-derive these data items for every used path in the wallet upon every startup. Introduce a persistent cache to avoid the need to re-derive these items every time the wallet is opened. Introduce _get_keypair_from_path and _get_pubkey_from_path methods to allow cached public keys to be used rather than always deriving them on the fly. Change many code paths that were calling CPU-intensive methods of BTCEngine so that instead they call _get_key_from_path, _get_keypair_from_path, _get_pubkey_from_path, get_script_from_path, and/or get_address_from_path, all of which can take advantage of the new cache. --- src/jmclient/wallet.py | 194 ++++++++++++++++++++++++++++++------ test/jmclient/test_taker.py | 9 +- 2 files changed, 171 insertions(+), 32 deletions(-) diff --git a/src/jmclient/wallet.py b/src/jmclient/wallet.py index 27c1d55..a580f96 100644 --- a/src/jmclient/wallet.py +++ b/src/jmclient/wallet.py @@ -374,6 +374,7 @@ class BaseWallet(object): self._storage = storage self._utxos = None self._addr_labels = None + self._cache = None # highest mixdepth ever used in wallet, important for synching self.max_mixdepth = None # effective maximum mixdepth to be used by joinmarket @@ -388,6 +389,7 @@ class BaseWallet(object): self._load_storage() assert self._utxos is not None + assert self._cache is not None assert self.max_mixdepth is not None assert self.max_mixdepth >= 0 assert self.network in ('mainnet', 'testnet', 'signet') @@ -424,6 +426,7 @@ class BaseWallet(object): self.network = self._storage.data[b'network'].decode('ascii') self._utxos = UTXOManager(self._storage, self.merge_algorithm) self._addr_labels = AddressLabelsManager(self._storage) + self._cache = self._storage.data.setdefault(b'cache', {}) def get_storage_location(self): """ Return the location of the @@ -564,21 +567,31 @@ class BaseWallet(object): @classmethod def addr_to_script(cls, addr): + """ + Try not to call this slow method. Instead, call addr_to_path, + followed by get_script_from_path, as those are cached. + """ return cls._ENGINE.address_to_script(addr) @classmethod def pubkey_to_script(cls, pubkey): + """ + Try not to call this slow method. Instead, call + get_script_from_path if possible, as that is cached. + """ return cls._ENGINE.pubkey_to_script(pubkey) @classmethod def pubkey_to_addr(cls, pubkey): + """ + Try not to call this slow method. Instead, call + get_address_from_path if possible, as that is cached. + """ return cls._ENGINE.pubkey_to_address(pubkey) def script_to_addr(self, script): - assert self.is_known_script(script) path = self.script_to_path(script) - engine = self._get_key_from_path(path)[1] - return engine.script_to_address(script) + return self.get_address_from_path(path) def get_script_code(self, script): """ @@ -589,8 +602,7 @@ class BaseWallet(object): For non-segwit wallets, raises EngineError. """ path = self.script_to_path(script) - priv, engine = self._get_key_from_path(path) - pub = engine.privkey_to_pubkey(priv) + pub, engine = self._get_pubkey_from_path(path) return engine.pubkey_to_script_code(pub) @classmethod @@ -606,12 +618,20 @@ class BaseWallet(object): raise NotImplementedError() def get_addr(self, mixdepth, address_type, index): - script = self.get_script(mixdepth, address_type, index) - return self.script_to_addr(script) + path = self.get_path(mixdepth, address_type, index) + return self.get_address_from_path(path) def get_address_from_path(self, path): - script = self.get_script_from_path(path) - return self.script_to_addr(script) + cache = self._get_cache_for_path(path) + addr = cache.get(b'A') + if addr is None: + engine = self._get_pubkey_from_path(path)[1] + script = self.get_script_from_path(path) + addr = engine.script_to_address(script) + cache[b'A'] = addr.encode('ascii') + else: + addr = addr.decode('ascii') + return addr def get_new_addr(self, mixdepth, address_type): """ @@ -929,8 +949,13 @@ class BaseWallet(object): returns: script """ - priv, engine = self._get_key_from_path(path) - return engine.key_to_script(priv) + cache = self._get_cache_for_path(path) + script = cache.get(b'S') + if script is None: + pubkey, engine = self._get_pubkey_from_path(path) + script = engine.pubkey_to_script(pubkey) + cache[b'S'] = script + return script def get_script(self, mixdepth, address_type, index): path = self.get_path(mixdepth, address_type, index) @@ -939,6 +964,44 @@ class BaseWallet(object): def _get_key_from_path(self, path): raise NotImplementedError() + def _get_keypair_from_path(self, path): + privkey, engine = self._get_key_from_path(path) + cache = self._get_cache_for_path(path) + pubkey = cache.get(b'P') + if pubkey is None: + pubkey = engine.privkey_to_pubkey(privkey) + cache[b'P'] = pubkey + return privkey, pubkey, engine + + def _get_pubkey_from_path(self, path): + privkey, pubkey, engine = self._get_keypair_from_path(path) + return pubkey, engine + + def _get_cache_keys_for_path(self, path): + return path[:1] + tuple(map(_int_to_bytestr, path[1:])) + + def _get_cache_for_path(self, path): + assert len(path) > 0 + cache = self._cache + for k in self._get_cache_keys_for_path(path): + cache = cache.setdefault(k, {}) + return cache + + def _delete_cache_for_path(self, path) -> bool: + assert len(path) > 0 + def recurse(cache, itr): + k = next(itr, None) + if k is None: + cache.clear() + else: + child = cache.get(k) + if not child or not recurse(child, itr): + return False + if not child: + del cache[k] + return True + return recurse(self._cache, iter(self._get_cache_keys_for_path(path))) + def get_path_repr(self, path): """ Get a human-readable representation of the wallet path. @@ -993,7 +1056,7 @@ class BaseWallet(object): signature as base64-encoded string """ priv, engine = self._get_key_from_path(path) - addr = engine.privkey_to_address(priv) + addr = self.get_address_from_path(path) return addr, engine.sign_message(priv, message) def get_wallet_name(self): @@ -1083,9 +1146,8 @@ class BaseWallet(object): def _populate_maps(self, paths): for path in paths: - script = self.get_script_from_path(path) - self._script_map[script] = path - self._addr_map[self.script_to_addr(script)] = path + self._script_map[self.get_script_from_path(path)] = path + self._addr_map[self.get_address_from_path(path)] = path def addr_to_path(self, addr): assert isinstance(addr, str) @@ -1399,9 +1461,8 @@ class PSBTWalletMixin(object): # this happens when an input is provided but it's not in # this wallet; in this case, we cannot set the redeem script. continue - privkey, _ = self._get_key_from_path(path) - txinput.redeem_script = btc.pubkey_to_p2wpkh_script( - btc.privkey_to_pubkey(privkey)) + pubkey = self._get_pubkey_from_path(path)[0] + txinput.redeem_script = btc.pubkey_to_p2wpkh_script(pubkey) return new_psbt def sign_psbt(self, in_psbt, with_sign_result=False): @@ -1471,9 +1532,8 @@ class PSBTWalletMixin(object): # this happens when an input is provided but it's not in # this wallet; in this case, we cannot set the redeem script. continue - privkey, _ = self._get_key_from_path(path) - txinput.redeem_script = btc.pubkey_to_p2wpkh_script( - btc.privkey_to_pubkey(privkey)) + pubkey = self._get_pubkey_from_path(path)[0] + txinput.redeem_script = btc.pubkey_to_p2wpkh_script(pubkey) # no else branch; any other form of scriptPubKey will just be # ignored. try: @@ -1871,13 +1931,14 @@ class ImportWalletMixin(object): if not script: script = self.get_script_from_path(path) if not address: - address = self.script_to_addr(script) + address = self.get_address_from_path(path) # we need to retain indices self._imported[path[1]][path[2]] = (b'', -1) del self._script_map[script] del self._addr_map[address] + self._delete_cache_for_path(path) def _cache_imported_key(self, mixdepth, privkey, key_type, index): path = (self._IMPORTED_ROOT_PATH, mixdepth, index) @@ -1916,7 +1977,7 @@ class ImportWalletMixin(object): def is_standard_wallet_script(self, path): if self._is_imported_path(path): - engine = self._get_key_from_path(path)[1] + engine = self._get_pubkey_from_path(path)[1] return engine == self._ENGINE return super().is_standard_wallet_script(path) @@ -2164,7 +2225,6 @@ class BIP32Wallet(BaseWallet): assert isinstance(index, Integral) if address_type is None: raise Exception("address_type must be set if index is set") - assert index <= self._index_cache[mixdepth][address_type] assert index < self.BIP32_MAX_PATH_LEVEL return tuple(chain(self._get_bip32_export_path(mixdepth, address_type), (index,))) @@ -2216,9 +2276,32 @@ class BIP32Wallet(BaseWallet): def _get_key_from_path(self, path): if not self._is_my_bip32_path(path): raise WalletError("Invalid path, unknown root: {}".format(path)) - - return self._ENGINE.derive_bip32_privkey(self._master_key, path), \ - self._ENGINE + cache = self._get_cache_for_path(path) + privkey = cache.get(b'p') + if privkey is None: + privkey = self._ENGINE.derive_bip32_privkey(self._master_key, path) + cache[b'p'] = privkey + return privkey, self._ENGINE + + def _get_keypair_from_path(self, path): + if not self._is_my_bip32_path(path): + return super()._get_keypair_from_path(path) + cache = self._get_cache_for_path(path) + privkey = cache.get(b'p') + if privkey is None: + privkey = self._ENGINE.derive_bip32_privkey(self._master_key, path) + cache[b'p'] = privkey + pubkey = cache.get(b'P') + if pubkey is None: + pubkey = self._ENGINE.privkey_to_pubkey(privkey) + cache[b'P'] = pubkey + return privkey, pubkey, self._ENGINE + + def _get_cache_keys_for_path(self, path): + if not self._is_my_bip32_path(path): + return super()._get_cache_keys_for_path(path) + return path[:1] + tuple([self._path_level_to_repr(lvl).encode('ascii') + for lvl in path[1:]]) def _is_my_bip32_path(self, path): return len(path) > 0 and path[0] == self._key_ident @@ -2431,8 +2514,7 @@ class FidelityBondMixin(object): def _get_key_ident(self): first_path = self.get_path(0, BIP32Wallet.BIP32_EXT_ID) - priv, engine = self._get_key_from_path(first_path) - pub = engine.privkey_to_pubkey(priv) + pub = self._get_pubkey_from_path(first_path)[0] return sha256(sha256(pub).digest()).digest()[:3] def is_standard_wallet_script(self, path): @@ -2483,11 +2565,37 @@ class FidelityBondMixin(object): key_path = path[:-1] locktime = path[-1] engine = self._TIMELOCK_ENGINE - privkey = engine.derive_bip32_privkey(self._master_key, key_path) + cache = super()._get_cache_for_path(key_path) + privkey = cache.get(b'p') + if privkey is None: + privkey = engine.derive_bip32_privkey(self._master_key, key_path) + cache[b'p'] = privkey return (privkey, locktime), engine else: return super()._get_key_from_path(path) + def _get_keypair_from_path(self, path): + if not self.is_timelocked_path(path): + return super()._get_keypair_from_path(path) + key_path = path[:-1] + locktime = path[-1] + engine = self._TIMELOCK_ENGINE + cache = super()._get_cache_for_path(key_path) + privkey = cache.get(b'p') + if privkey is None: + privkey = engine.derive_bip32_privkey(self._master_key, key_path) + cache[b'p'] = privkey + pubkey = cache.get(b'P') + if pubkey is None: + pubkey = engine.privkey_to_pubkey(privkey) + cache[b'P'] = pubkey + return (privkey, locktime), (pubkey, locktime), engine + + def _get_cache_for_path(self, path): + if self.is_timelocked_path(path): + path = path[:-1] + return super()._get_cache_for_path(path) + def get_path(self, mixdepth=None, address_type=None, index=None): if address_type == None or address_type in (self.BIP32_EXT_ID, self.BIP32_INT_ID, self.BIP32_BURN_ID) or index == None: @@ -2632,6 +2740,32 @@ class FidelityBondWatchonlyWallet(FidelityBondMixin, BIP84Wallet): path = super()._get_bip32_export_path(mixdepth, address_type) return path + def _get_key_from_path(self, path): + raise WalletError("Cannot get a private key from a watch-only wallet") + + def _get_keypair_from_path(self, path): + raise WalletError("Cannot get a private key from a watch-only wallet") + + def _get_pubkey_from_path(self, path): + if not self._is_my_bip32_path(path): + return super()._get_pubkey_from_path(path) + if self.is_timelocked_path(path): + key_path = path[:-1] + locktime = path[-1] + cache = self._get_cache_for_path(key_path) + pubkey = cache.get(b'P') + if pubkey is None: + pubkey = self._TIMELOCK_ENGINE.derive_bip32_privkey( + self._master_key, key_path) + cache[b'P'] = pubkey + return (pubkey, locktime), self._TIMELOCK_ENGINE + cache = self._get_cache_for_path(path) + pubkey = cache.get(b'P') + if pubkey is None: + pubkey = self._ENGINE.derive_bip32_privkey(self._master_key, path) + cache[b'P'] = pubkey + return pubkey, self._ENGINE + WALLET_IMPLEMENTATIONS = { LegacyWallet.TYPE: LegacyWallet, diff --git a/test/jmclient/test_taker.py b/test/jmclient/test_taker.py index 8d20a43..da902f3 100644 --- a/test/jmclient/test_taker.py +++ b/test/jmclient/test_taker.py @@ -121,6 +121,11 @@ class DummyWallet(LegacyWallet): """ return 'p2wpkh' + def _get_key_from_path(self, path): + if path[0] == b'dummy': + return struct.pack(b'B', path[2] + 1)*32 + b'\x01', self._ENGINE + raise NotImplementedError() + def get_key_from_addr(self, addr): """usable addresses: privkey all 1s, 2s, 3s, ... :""" privs = [x*32 + b"\x01" for x in [struct.pack(b'B', y) for y in range(1,6)]] @@ -139,8 +144,8 @@ class DummyWallet(LegacyWallet): return p raise ValueError("No such keypair") - def _is_my_bip32_path(self, path): - return True + def get_path_repr(self, path): + return '/'.join(map(str, path)) def is_standard_wallet_script(self, path): if path[0] == "nonstandard_path":