Browse Source

wallet: add persistent cache, mapping path->(priv, pub, script, addr)

Deriving private keys from BIP32 paths, public keys from private keys,
scripts from public keys, and addresses from scripts are some of the
most CPU-intensive tasks the wallet performs. Once the wallet inevitably
accumulates thousands of used paths, startup times become painful due to
needing to re-derive these data items for every used path in the wallet
upon every startup. Introduce a persistent cache to avoid the need to
re-derive these items every time the wallet is opened.

Introduce _get_keypair_from_path and _get_pubkey_from_path methods to
allow cached public keys to be used rather than always deriving them on
the fly.

Change many code paths that were calling CPU-intensive methods of
BTCEngine so that instead they call _get_key_from_path,
_get_keypair_from_path, _get_pubkey_from_path, get_script_from_path,
and/or get_address_from_path, all of which can take advantage of the new
cache.
master
Matt Whitlock 2 years ago
parent
commit
5bc7eb4b8e
  1. 194
      src/jmclient/wallet.py
  2. 9
      test/jmclient/test_taker.py

194
src/jmclient/wallet.py

@ -374,6 +374,7 @@ class BaseWallet(object):
self._storage = storage
self._utxos = None
self._addr_labels = None
self._cache = None
# highest mixdepth ever used in wallet, important for synching
self.max_mixdepth = None
# effective maximum mixdepth to be used by joinmarket
@ -388,6 +389,7 @@ class BaseWallet(object):
self._load_storage()
assert self._utxos is not None
assert self._cache is not None
assert self.max_mixdepth is not None
assert self.max_mixdepth >= 0
assert self.network in ('mainnet', 'testnet', 'signet')
@ -424,6 +426,7 @@ class BaseWallet(object):
self.network = self._storage.data[b'network'].decode('ascii')
self._utxos = UTXOManager(self._storage, self.merge_algorithm)
self._addr_labels = AddressLabelsManager(self._storage)
self._cache = self._storage.data.setdefault(b'cache', {})
def get_storage_location(self):
""" Return the location of the
@ -564,21 +567,31 @@ class BaseWallet(object):
@classmethod
def addr_to_script(cls, addr):
"""
Try not to call this slow method. Instead, call addr_to_path,
followed by get_script_from_path, as those are cached.
"""
return cls._ENGINE.address_to_script(addr)
@classmethod
def pubkey_to_script(cls, pubkey):
"""
Try not to call this slow method. Instead, call
get_script_from_path if possible, as that is cached.
"""
return cls._ENGINE.pubkey_to_script(pubkey)
@classmethod
def pubkey_to_addr(cls, pubkey):
"""
Try not to call this slow method. Instead, call
get_address_from_path if possible, as that is cached.
"""
return cls._ENGINE.pubkey_to_address(pubkey)
def script_to_addr(self, script):
assert self.is_known_script(script)
path = self.script_to_path(script)
engine = self._get_key_from_path(path)[1]
return engine.script_to_address(script)
return self.get_address_from_path(path)
def get_script_code(self, script):
"""
@ -589,8 +602,7 @@ class BaseWallet(object):
For non-segwit wallets, raises EngineError.
"""
path = self.script_to_path(script)
priv, engine = self._get_key_from_path(path)
pub = engine.privkey_to_pubkey(priv)
pub, engine = self._get_pubkey_from_path(path)
return engine.pubkey_to_script_code(pub)
@classmethod
@ -606,12 +618,20 @@ class BaseWallet(object):
raise NotImplementedError()
def get_addr(self, mixdepth, address_type, index):
script = self.get_script(mixdepth, address_type, index)
return self.script_to_addr(script)
path = self.get_path(mixdepth, address_type, index)
return self.get_address_from_path(path)
def get_address_from_path(self, path):
script = self.get_script_from_path(path)
return self.script_to_addr(script)
cache = self._get_cache_for_path(path)
addr = cache.get(b'A')
if addr is None:
engine = self._get_pubkey_from_path(path)[1]
script = self.get_script_from_path(path)
addr = engine.script_to_address(script)
cache[b'A'] = addr.encode('ascii')
else:
addr = addr.decode('ascii')
return addr
def get_new_addr(self, mixdepth, address_type):
"""
@ -929,8 +949,13 @@ class BaseWallet(object):
returns:
script
"""
priv, engine = self._get_key_from_path(path)
return engine.key_to_script(priv)
cache = self._get_cache_for_path(path)
script = cache.get(b'S')
if script is None:
pubkey, engine = self._get_pubkey_from_path(path)
script = engine.pubkey_to_script(pubkey)
cache[b'S'] = script
return script
def get_script(self, mixdepth, address_type, index):
path = self.get_path(mixdepth, address_type, index)
@ -939,6 +964,44 @@ class BaseWallet(object):
def _get_key_from_path(self, path):
raise NotImplementedError()
def _get_keypair_from_path(self, path):
privkey, engine = self._get_key_from_path(path)
cache = self._get_cache_for_path(path)
pubkey = cache.get(b'P')
if pubkey is None:
pubkey = engine.privkey_to_pubkey(privkey)
cache[b'P'] = pubkey
return privkey, pubkey, engine
def _get_pubkey_from_path(self, path):
privkey, pubkey, engine = self._get_keypair_from_path(path)
return pubkey, engine
def _get_cache_keys_for_path(self, path):
return path[:1] + tuple(map(_int_to_bytestr, path[1:]))
def _get_cache_for_path(self, path):
assert len(path) > 0
cache = self._cache
for k in self._get_cache_keys_for_path(path):
cache = cache.setdefault(k, {})
return cache
def _delete_cache_for_path(self, path) -> bool:
assert len(path) > 0
def recurse(cache, itr):
k = next(itr, None)
if k is None:
cache.clear()
else:
child = cache.get(k)
if not child or not recurse(child, itr):
return False
if not child:
del cache[k]
return True
return recurse(self._cache, iter(self._get_cache_keys_for_path(path)))
def get_path_repr(self, path):
"""
Get a human-readable representation of the wallet path.
@ -993,7 +1056,7 @@ class BaseWallet(object):
signature as base64-encoded string
"""
priv, engine = self._get_key_from_path(path)
addr = engine.privkey_to_address(priv)
addr = self.get_address_from_path(path)
return addr, engine.sign_message(priv, message)
def get_wallet_name(self):
@ -1083,9 +1146,8 @@ class BaseWallet(object):
def _populate_maps(self, paths):
for path in paths:
script = self.get_script_from_path(path)
self._script_map[script] = path
self._addr_map[self.script_to_addr(script)] = path
self._script_map[self.get_script_from_path(path)] = path
self._addr_map[self.get_address_from_path(path)] = path
def addr_to_path(self, addr):
assert isinstance(addr, str)
@ -1399,9 +1461,8 @@ class PSBTWalletMixin(object):
# this happens when an input is provided but it's not in
# this wallet; in this case, we cannot set the redeem script.
continue
privkey, _ = self._get_key_from_path(path)
txinput.redeem_script = btc.pubkey_to_p2wpkh_script(
btc.privkey_to_pubkey(privkey))
pubkey = self._get_pubkey_from_path(path)[0]
txinput.redeem_script = btc.pubkey_to_p2wpkh_script(pubkey)
return new_psbt
def sign_psbt(self, in_psbt, with_sign_result=False):
@ -1471,9 +1532,8 @@ class PSBTWalletMixin(object):
# this happens when an input is provided but it's not in
# this wallet; in this case, we cannot set the redeem script.
continue
privkey, _ = self._get_key_from_path(path)
txinput.redeem_script = btc.pubkey_to_p2wpkh_script(
btc.privkey_to_pubkey(privkey))
pubkey = self._get_pubkey_from_path(path)[0]
txinput.redeem_script = btc.pubkey_to_p2wpkh_script(pubkey)
# no else branch; any other form of scriptPubKey will just be
# ignored.
try:
@ -1871,13 +1931,14 @@ class ImportWalletMixin(object):
if not script:
script = self.get_script_from_path(path)
if not address:
address = self.script_to_addr(script)
address = self.get_address_from_path(path)
# we need to retain indices
self._imported[path[1]][path[2]] = (b'', -1)
del self._script_map[script]
del self._addr_map[address]
self._delete_cache_for_path(path)
def _cache_imported_key(self, mixdepth, privkey, key_type, index):
path = (self._IMPORTED_ROOT_PATH, mixdepth, index)
@ -1916,7 +1977,7 @@ class ImportWalletMixin(object):
def is_standard_wallet_script(self, path):
if self._is_imported_path(path):
engine = self._get_key_from_path(path)[1]
engine = self._get_pubkey_from_path(path)[1]
return engine == self._ENGINE
return super().is_standard_wallet_script(path)
@ -2164,7 +2225,6 @@ class BIP32Wallet(BaseWallet):
assert isinstance(index, Integral)
if address_type is None:
raise Exception("address_type must be set if index is set")
assert index <= self._index_cache[mixdepth][address_type]
assert index < self.BIP32_MAX_PATH_LEVEL
return tuple(chain(self._get_bip32_export_path(mixdepth, address_type),
(index,)))
@ -2216,9 +2276,32 @@ class BIP32Wallet(BaseWallet):
def _get_key_from_path(self, path):
if not self._is_my_bip32_path(path):
raise WalletError("Invalid path, unknown root: {}".format(path))
return self._ENGINE.derive_bip32_privkey(self._master_key, path), \
self._ENGINE
cache = self._get_cache_for_path(path)
privkey = cache.get(b'p')
if privkey is None:
privkey = self._ENGINE.derive_bip32_privkey(self._master_key, path)
cache[b'p'] = privkey
return privkey, self._ENGINE
def _get_keypair_from_path(self, path):
if not self._is_my_bip32_path(path):
return super()._get_keypair_from_path(path)
cache = self._get_cache_for_path(path)
privkey = cache.get(b'p')
if privkey is None:
privkey = self._ENGINE.derive_bip32_privkey(self._master_key, path)
cache[b'p'] = privkey
pubkey = cache.get(b'P')
if pubkey is None:
pubkey = self._ENGINE.privkey_to_pubkey(privkey)
cache[b'P'] = pubkey
return privkey, pubkey, self._ENGINE
def _get_cache_keys_for_path(self, path):
if not self._is_my_bip32_path(path):
return super()._get_cache_keys_for_path(path)
return path[:1] + tuple([self._path_level_to_repr(lvl).encode('ascii')
for lvl in path[1:]])
def _is_my_bip32_path(self, path):
return len(path) > 0 and path[0] == self._key_ident
@ -2431,8 +2514,7 @@ class FidelityBondMixin(object):
def _get_key_ident(self):
first_path = self.get_path(0, BIP32Wallet.BIP32_EXT_ID)
priv, engine = self._get_key_from_path(first_path)
pub = engine.privkey_to_pubkey(priv)
pub = self._get_pubkey_from_path(first_path)[0]
return sha256(sha256(pub).digest()).digest()[:3]
def is_standard_wallet_script(self, path):
@ -2483,11 +2565,37 @@ class FidelityBondMixin(object):
key_path = path[:-1]
locktime = path[-1]
engine = self._TIMELOCK_ENGINE
privkey = engine.derive_bip32_privkey(self._master_key, key_path)
cache = super()._get_cache_for_path(key_path)
privkey = cache.get(b'p')
if privkey is None:
privkey = engine.derive_bip32_privkey(self._master_key, key_path)
cache[b'p'] = privkey
return (privkey, locktime), engine
else:
return super()._get_key_from_path(path)
def _get_keypair_from_path(self, path):
if not self.is_timelocked_path(path):
return super()._get_keypair_from_path(path)
key_path = path[:-1]
locktime = path[-1]
engine = self._TIMELOCK_ENGINE
cache = super()._get_cache_for_path(key_path)
privkey = cache.get(b'p')
if privkey is None:
privkey = engine.derive_bip32_privkey(self._master_key, key_path)
cache[b'p'] = privkey
pubkey = cache.get(b'P')
if pubkey is None:
pubkey = engine.privkey_to_pubkey(privkey)
cache[b'P'] = pubkey
return (privkey, locktime), (pubkey, locktime), engine
def _get_cache_for_path(self, path):
if self.is_timelocked_path(path):
path = path[:-1]
return super()._get_cache_for_path(path)
def get_path(self, mixdepth=None, address_type=None, index=None):
if address_type == None or address_type in (self.BIP32_EXT_ID, self.BIP32_INT_ID,
self.BIP32_BURN_ID) or index == None:
@ -2632,6 +2740,32 @@ class FidelityBondWatchonlyWallet(FidelityBondMixin, BIP84Wallet):
path = super()._get_bip32_export_path(mixdepth, address_type)
return path
def _get_key_from_path(self, path):
raise WalletError("Cannot get a private key from a watch-only wallet")
def _get_keypair_from_path(self, path):
raise WalletError("Cannot get a private key from a watch-only wallet")
def _get_pubkey_from_path(self, path):
if not self._is_my_bip32_path(path):
return super()._get_pubkey_from_path(path)
if self.is_timelocked_path(path):
key_path = path[:-1]
locktime = path[-1]
cache = self._get_cache_for_path(key_path)
pubkey = cache.get(b'P')
if pubkey is None:
pubkey = self._TIMELOCK_ENGINE.derive_bip32_privkey(
self._master_key, key_path)
cache[b'P'] = pubkey
return (pubkey, locktime), self._TIMELOCK_ENGINE
cache = self._get_cache_for_path(path)
pubkey = cache.get(b'P')
if pubkey is None:
pubkey = self._ENGINE.derive_bip32_privkey(self._master_key, path)
cache[b'P'] = pubkey
return pubkey, self._ENGINE
WALLET_IMPLEMENTATIONS = {
LegacyWallet.TYPE: LegacyWallet,

9
test/jmclient/test_taker.py

@ -121,6 +121,11 @@ class DummyWallet(LegacyWallet):
"""
return 'p2wpkh'
def _get_key_from_path(self, path):
if path[0] == b'dummy':
return struct.pack(b'B', path[2] + 1)*32 + b'\x01', self._ENGINE
raise NotImplementedError()
def get_key_from_addr(self, addr):
"""usable addresses: privkey all 1s, 2s, 3s, ... :"""
privs = [x*32 + b"\x01" for x in [struct.pack(b'B', y) for y in range(1,6)]]
@ -139,8 +144,8 @@ class DummyWallet(LegacyWallet):
return p
raise ValueError("No such keypair")
def _is_my_bip32_path(self, path):
return True
def get_path_repr(self, path):
return '/'.join(map(str, path))
def is_standard_wallet_script(self, path):
if path[0] == "nonstandard_path":

Loading…
Cancel
Save