Browse Source

Merge Joinmarket-Org/joinmarket-clientserver#1594: Improve performance with better algorithms and caching of expensive computed values

c3c10f1615 wallet: implement optional cache validation (Matt Whitlock)
5bc7eb4b8e wallet: add persistent cache, mapping path->(priv, pub, script, addr) (Matt Whitlock)
01ec2a4181 wallet: add _addr_map, paralleling _script_map (Matt Whitlock)
64f18bce18 get_imported_privkey_branch: use O(m+n) algorithm instead of O(m*n) (Matt Whitlock)
77f0194a37 wallet_utils: use new get_utxos_at_mixdepth method (Matt Whitlock)
184d76f7f7 wallet: add get_{balance,utxos}_at_mixdepth methods (Matt Whitlock)
fc1e00058b wallet_showutxos: use O(1) check for frozen instead of O(n) (Matt Whitlock)
b58ac679cb wallet: drop _get_addr_int_ext; replace with calls to get_new_addr (Matt Whitlock)
2c38a813fc wallet: delete redundant get_script and get_addr methods (Matt Whitlock)
574c29e899 wallet: hoist get_script_from_path default impl into BaseWallet (Matt Whitlock)
8245271d7f wallet: avoid IndexError in _is_my_bip32_path (Matt Whitlock)
48aec83d76 wallet: remove a dead store in get_index_cache_and_increment (Matt Whitlock)

Pull request description:

  **Note:** Reviewing each commit individually will make more sense than trying to review the combined diff.

  This PR implements several performance enhancements that take the CPU time to run `wallet-tool.py display` on my wallet down from ~44 minutes to ~11 seconds.

  The most significant gains come from replacing an **O**(_m_*_n_) algorithm in `get_imported_privkey_branch` with a semantically equivalent **O**(_m_+_n_) algorithm and from adding a persistent cache for computed private keys, public keys, scripts, and addresses.

  Below are some actual benchmarks on my wallet, which has 5 mixdepths, each having path indices reaching into the 4000s, and almost 700 imported private keys.

  * 673fbfb9a5 `origin/master` (baseline)
      ```
      user    44m3.618s
      sys     0m6.375s
      ```
  * 48aec83d76 `wallet`: remove a dead store in `get_index_cache_and_increment`
  * fbb681a207be465fb53b43ac18a2b52c8a4a6323 `wallet`: add `get_{balance,utxos}_at_mixdepth` methods
  * 75a970378579bb04f189e8d9eca22e5e2aadb0b4 `wallet_utils`: use new `get_utxos_at_mixdepth` method
      ```
      user    42m14.464s
      sys     0m3.355s
      ```
  * 84966e628d510ddf0cadba170346ea926dc06000 `wallet_showutxos`: use **O**(1) check for frozen instead of **O**(_n_)
  * 75c5a75468a6de88e64c4af7a8226c633d358fd5 `get_imported_privkey_branch`: use **O**(_m_+_n_) algorithm instead of **O**(_m_*_n_)
      ```
      user    5m0.045s
      sys     0m0.453s
      ```
  * da8daf048369081d882fb591d50583559a2284f0 `wallet`: add `_addr_map`, paralleling `_script_map`
      ```
      user    4m56.175s
      sys     0m0.423s
      ```
  * d8aa1afe6f0ec596bb133f594ae88cc2fffb6ad2 `wallet`: add persistent cache, mapping path->(priv, pub, script, addr)
      ```
      user    1m42.272s
      sys     0m0.471s
      ```
  * After running another command to modify the wallet file so as to persist the cache, `wallet-tool.py display` now runs in:
      ```
      user    0m11.141s
      sys     0m0.225s
      ```

ACKs for top commit:
  AdamISZ:
    tACK c3c10f1615

Tree-SHA512: fdd20d436d8f16a1e4270011ad1ba4bf6393f876eb7413da30f75d5830249134911d5d93cab8051c0bf107c213d4cd46ba9614ae23eef4566f867ff1b912fc9b
master
Adam Gibson 2 years ago
parent
commit
2d3e90a079
No known key found for this signature in database
GPG Key ID: 141001A1AF77F20B
  1. 541
      src/jmclient/wallet.py
  2. 23
      src/jmclient/wallet_utils.py
  3. 16
      test/jmclient/test_taker.py
  4. 24
      test/jmclient/test_utxomanager.py
  5. 14
      test/jmclient/test_wallet.py

541
src/jmclient/wallet.py

@ -19,6 +19,7 @@ from itertools import chain
from decimal import Decimal
from numbers import Integral
from math import exp
from typing import Any, Dict, Optional, Tuple
from .configure import jm_single
@ -280,32 +281,28 @@ class UTXOManager(object):
'value': utxos[s['utxo']][1]}
for s in selected}
def get_balance_by_mixdepth(self, max_mixdepth=float('Inf'),
include_disabled=True, maxheight=None):
""" By default this returns a dict of aggregated bitcoin
balance per mixdepth: {0: N sats, 1: M sats, ...} for all
currently available mixdepths.
If max_mixdepth is set it will return balances only up
to that mixdepth.
def get_balance_at_mixdepth(self, mixdepth: int,
include_disabled: bool = True,
maxheight: Optional[int] = None) -> int:
""" By default this returns aggregated bitcoin balance at mixdepth.
To get only enabled balance, set include_disabled=False.
To get balances only with a certain number of confs, use maxheight.
"""
balance_dict = collections.defaultdict(int)
for mixdepth, utxomap in self._utxo.items():
if mixdepth > max_mixdepth:
continue
if not include_disabled:
utxomap = {k: v for k, v in utxomap.items(
) if not self.is_disabled(*k)}
if maxheight is not None:
utxomap = {k: v for k, v in utxomap.items(
) if v[2] <= maxheight}
value = sum(x[1] for x in utxomap.values())
balance_dict[mixdepth] = value
return balance_dict
def get_utxos_by_mixdepth(self):
return deepcopy(self._utxo)
utxomap = self._utxo.get(mixdepth)
if not utxomap:
return 0
if not include_disabled:
utxomap = {k: v for k, v in utxomap.items(
) if not self.is_disabled(*k)}
if maxheight is not None:
utxomap = {k: v for k, v in utxomap.items(
) if v[2] <= maxheight}
return sum(x[1] for x in utxomap.values())
def get_utxos_at_mixdepth(self, mixdepth: int) -> \
Dict[Tuple[bytes, int], Tuple[Tuple, int, int]]:
utxomap = self._utxo.get(mixdepth)
return deepcopy(utxomap) if utxomap else {}
def __eq__(self, o):
return self._utxo == o._utxo and \
@ -377,6 +374,7 @@ class BaseWallet(object):
self._storage = storage
self._utxos = None
self._addr_labels = None
self._cache = None
# highest mixdepth ever used in wallet, important for synching
self.max_mixdepth = None
# effective maximum mixdepth to be used by joinmarket
@ -385,10 +383,13 @@ class BaseWallet(object):
# {script: path}, should always hold mappings for all "known" keys
self._script_map = {}
# {address: path}, should always hold mappings for all "known" keys
self._addr_map = {}
self._load_storage()
assert self._utxos is not None
assert self._cache is not None
assert self.max_mixdepth is not None
assert self.max_mixdepth >= 0
assert self.network in ('mainnet', 'testnet', 'signet')
@ -425,6 +426,7 @@ class BaseWallet(object):
self.network = self._storage.data[b'network'].decode('ascii')
self._utxos = UTXOManager(self._storage, self.merge_algorithm)
self._addr_labels = AddressLabelsManager(self._storage)
self._cache = self._storage.data.setdefault(b'cache', {})
def get_storage_location(self):
""" Return the location of the
@ -538,34 +540,24 @@ class BaseWallet(object):
"""
There should be no reason for code outside the wallet to need a privkey.
"""
script = self._ENGINE.address_to_script(addr)
path = self.script_to_path(script)
path = self.addr_to_path(addr)
privkey = self._get_key_from_path(path)[0]
return privkey
def _get_addr_int_ext(self, address_type, mixdepth):
if address_type == self.ADDRESS_TYPE_EXTERNAL:
script = self.get_external_script(mixdepth)
elif address_type == self.ADDRESS_TYPE_INTERNAL:
script = self.get_internal_script(mixdepth)
else:
assert 0
return self.script_to_addr(script)
def get_external_addr(self, mixdepth):
"""
Return an address suitable for external distribution, including funding
the wallet from other sources, or receiving payments or donations.
JoinMarket will never generate these addresses for internal use.
"""
return self._get_addr_int_ext(self.ADDRESS_TYPE_EXTERNAL, mixdepth)
return self.get_new_addr(mixdepth, self.ADDRESS_TYPE_EXTERNAL)
def get_internal_addr(self, mixdepth):
"""
Return an address for internal usage, as change addresses and when
participating in transactions initiated by other parties.
"""
return self._get_addr_int_ext(self.ADDRESS_TYPE_INTERNAL, mixdepth)
return self.get_new_addr(mixdepth, self.ADDRESS_TYPE_INTERNAL)
def get_external_script(self, mixdepth):
return self.get_new_script(mixdepth, self.ADDRESS_TYPE_EXTERNAL)
@ -575,21 +567,33 @@ class BaseWallet(object):
@classmethod
def addr_to_script(cls, addr):
"""
Try not to call this slow method. Instead, call addr_to_path,
followed by get_script_from_path, as those are cached.
"""
return cls._ENGINE.address_to_script(addr)
@classmethod
def pubkey_to_script(cls, pubkey):
"""
Try not to call this slow method. Instead, call
get_script_from_path if possible, as that is cached.
"""
return cls._ENGINE.pubkey_to_script(pubkey)
@classmethod
def pubkey_to_addr(cls, pubkey):
"""
Try not to call this slow method. Instead, call
get_address_from_path if possible, as that is cached.
"""
return cls._ENGINE.pubkey_to_address(pubkey)
def script_to_addr(self, script):
assert self.is_known_script(script)
def script_to_addr(self, script,
validate_cache: bool = False):
path = self.script_to_path(script)
engine = self._get_key_from_path(path)[1]
return engine.script_to_address(script)
return self.get_address_from_path(path,
validate_cache=validate_cache)
def get_script_code(self, script):
"""
@ -600,8 +604,7 @@ class BaseWallet(object):
For non-segwit wallets, raises EngineError.
"""
path = self.script_to_path(script)
priv, engine = self._get_key_from_path(path)
pub = engine.privkey_to_pubkey(priv)
pub, engine = self._get_pubkey_from_path(path)
return engine.pubkey_to_script_code(pub)
@classmethod
@ -616,22 +619,42 @@ class BaseWallet(object):
def get_key(self, mixdepth, address_type, index):
raise NotImplementedError()
def get_addr(self, mixdepth, address_type, index):
script = self.get_script(mixdepth, address_type, index)
return self.script_to_addr(script)
def get_address_from_path(self, path):
script = self.get_script_from_path(path)
return self.script_to_addr(script)
def get_new_addr(self, mixdepth, address_type):
def get_addr(self, mixdepth, address_type, index,
validate_cache: bool = False):
path = self.get_path(mixdepth, address_type, index)
return self.get_address_from_path(path,
validate_cache=validate_cache)
def get_address_from_path(self, path,
validate_cache: bool = False):
cache = self._get_cache_for_path(path)
addr = cache.get(b'A')
if addr is not None:
addr = addr.decode('ascii')
if addr is None or validate_cache:
engine = self._get_pubkey_from_path(path)[1]
script = self.get_script_from_path(path,
validate_cache=validate_cache)
new_addr = engine.script_to_address(script)
if addr is None:
addr = new_addr
cache[b'A'] = addr.encode('ascii')
elif addr != new_addr:
raise WalletError("Wallet cache validation failed")
return addr
def get_new_addr(self, mixdepth, address_type,
validate_cache: bool = True):
"""
use get_external_addr/get_internal_addr
"""
script = self.get_new_script(mixdepth, address_type)
return self.script_to_addr(script)
script = self.get_new_script(mixdepth, address_type,
validate_cache=validate_cache)
return self.script_to_addr(script,
validate_cache=validate_cache)
def get_new_script(self, mixdepth, address_type):
def get_new_script(self, mixdepth, address_type,
validate_cache: bool = True):
raise NotImplementedError()
def get_wif(self, mixdepth, address_type, index):
@ -845,10 +868,19 @@ class BaseWallet(object):
confirmations, set maxheight to max acceptable blockheight.
returns: {mixdepth: value}
"""
balances = collections.defaultdict(int)
for md in range(self.mixdepth + 1):
balances[md] = self.get_balance_at_mixdepth(md, verbose=verbose,
include_disabled=include_disabled, maxheight=maxheight)
return balances
def get_balance_at_mixdepth(self, mixdepth,
verbose: bool = True,
include_disabled: bool = False,
maxheight: Optional[int] = None) -> int:
# TODO: verbose
return self._utxos.get_balance_by_mixdepth(max_mixdepth=self.mixdepth,
include_disabled=include_disabled,
maxheight=maxheight)
return self._utxos.get_balance_at_mixdepth(mixdepth,
include_disabled=include_disabled, maxheight=maxheight)
def get_utxos_by_mixdepth(self, include_disabled=False, includeheight=False):
"""
@ -859,25 +891,35 @@ class BaseWallet(object):
{'script': bytes, 'path': tuple, 'value': int}}}
(if `includeheight` is True, adds key 'height': int)
"""
mix_utxos = self._utxos.get_utxos_by_mixdepth()
script_utxos = collections.defaultdict(dict)
for md, data in mix_utxos.items():
if md > self.mixdepth:
continue
for md in range(self.mixdepth + 1):
script_utxos[md] = self.get_utxos_at_mixdepth(md,
include_disabled=include_disabled, includeheight=includeheight)
return script_utxos
def get_utxos_at_mixdepth(self, mixdepth: int,
include_disabled: bool = False,
includeheight: bool = False) -> \
Dict[Tuple[bytes, int], Dict[str, Any]]:
script_utxos = {}
if 0 <= mixdepth <= self.mixdepth:
data = self._utxos.get_utxos_at_mixdepth(mixdepth)
for utxo, (path, value, height) in data.items():
if not include_disabled and self._utxos.is_disabled(*utxo):
continue
script = self.get_script_from_path(path)
addr = self.get_address_from_path(path)
label = self.get_address_label(addr)
script_utxos[md][utxo] = {'script': script,
'path': path,
'value': value,
'address': addr,
'label': label}
script_utxo = {
'script': script,
'path': path,
'value': value,
'address': addr,
'label': label,
}
if includeheight:
script_utxos[md][utxo]['height'] = height
script_utxo['height'] = height
script_utxos[utxo] = script_utxo
return script_utxos
@ -910,7 +952,8 @@ class BaseWallet(object):
def _get_mixdepth_from_path(self, path):
raise NotImplementedError()
def get_script_from_path(self, path):
def get_script_from_path(self, path,
validate_cache: bool = False):
"""
internal note: This is the final sink for all operations that somehow
need to derive a script. If anything goes wrong when deriving a
@ -921,15 +964,72 @@ class BaseWallet(object):
returns:
script
"""
raise NotImplementedError()
cache = self._get_cache_for_path(path)
script = cache.get(b'S')
if script is None or validate_cache:
pubkey, engine = self._get_pubkey_from_path(path,
validate_cache=validate_cache)
new_script = engine.pubkey_to_script(pubkey)
if script is None:
cache[b'S'] = script = new_script
elif script != new_script:
raise WalletError("Wallet cache validation failed")
return script
def get_script(self, mixdepth, address_type, index):
def get_script(self, mixdepth, address_type, index,
validate_cache: bool = False):
path = self.get_path(mixdepth, address_type, index)
return self.get_script_from_path(path)
return self.get_script_from_path(path, validate_cache=validate_cache)
def _get_key_from_path(self, path):
def _get_key_from_path(self, path,
validate_cache: bool = False):
raise NotImplementedError()
def _get_keypair_from_path(self, path,
validate_cache: bool = False):
privkey, engine = self._get_key_from_path(path,
validate_cache=validate_cache)
cache = self._get_cache_for_path(path)
pubkey = cache.get(b'P')
if pubkey is None or validate_cache:
new_pubkey = engine.privkey_to_pubkey(privkey)
if pubkey is None:
cache[b'P'] = pubkey = new_pubkey
elif pubkey != new_pubkey:
raise WalletError("Wallet cache validation failed")
return privkey, pubkey, engine
def _get_pubkey_from_path(self, path,
validate_cache: bool = False):
privkey, pubkey, engine = self._get_keypair_from_path(path,
validate_cache=validate_cache)
return pubkey, engine
def _get_cache_keys_for_path(self, path):
return path[:1] + tuple(map(_int_to_bytestr, path[1:]))
def _get_cache_for_path(self, path):
assert len(path) > 0
cache = self._cache
for k in self._get_cache_keys_for_path(path):
cache = cache.setdefault(k, {})
return cache
def _delete_cache_for_path(self, path) -> bool:
assert len(path) > 0
def recurse(cache, itr):
k = next(itr, None)
if k is None:
cache.clear()
else:
child = cache.get(k)
if not child or not recurse(child, itr):
return False
if not child:
del cache[k]
return True
return recurse(self._cache, iter(self._get_cache_keys_for_path(path)))
def get_path_repr(self, path):
"""
Get a human-readable representation of the wallet path.
@ -984,7 +1084,7 @@ class BaseWallet(object):
signature as base64-encoded string
"""
priv, engine = self._get_key_from_path(path)
addr = engine.privkey_to_address(priv)
addr = self.get_address_from_path(path)
return addr, engine.sign_message(priv, message)
def get_wallet_name(self):
@ -1038,8 +1138,8 @@ class BaseWallet(object):
returns:
bool
"""
script = self.addr_to_script(addr)
return script in self._script_map
assert isinstance(addr, str)
return addr in self._addr_map
def is_known_script(self, script):
"""
@ -1054,8 +1154,8 @@ class BaseWallet(object):
return script in self._script_map
def get_addr_mixdepth(self, addr):
script = self.addr_to_script(addr)
return self.get_script_mixdepth(script)
path = self.addr_to_path(addr)
return self._get_mixdepth_from_path(path)
def get_script_mixdepth(self, script):
path = self.script_to_path(script)
@ -1068,16 +1168,26 @@ class BaseWallet(object):
returns:
path generator
"""
for s in self._script_map.values():
yield s
for md in range(self.max_mixdepth + 1):
for path in self.yield_imported_paths(md):
yield path
def _populate_maps(self, paths):
for path in paths:
self._script_map[self.get_script_from_path(path)] = path
self._addr_map[self.get_address_from_path(path)] = path
def addr_to_path(self, addr):
script = self.addr_to_script(addr)
return self.script_to_path(script)
assert isinstance(addr, str)
path = self._addr_map.get(addr)
assert path is not None
return path
def script_to_path(self, script):
assert script in self._script_map
return self._script_map[script]
assert isinstance(script, bytes)
path = self._script_map.get(script)
assert path is not None
return path
def set_next_index(self, mixdepth, address_type, index, force=False):
"""
@ -1379,9 +1489,8 @@ class PSBTWalletMixin(object):
# this happens when an input is provided but it's not in
# this wallet; in this case, we cannot set the redeem script.
continue
privkey, _ = self._get_key_from_path(path)
txinput.redeem_script = btc.pubkey_to_p2wpkh_script(
btc.privkey_to_pubkey(privkey))
pubkey = self._get_pubkey_from_path(path)[0]
txinput.redeem_script = btc.pubkey_to_p2wpkh_script(pubkey)
return new_psbt
def sign_psbt(self, in_psbt, with_sign_result=False):
@ -1451,9 +1560,8 @@ class PSBTWalletMixin(object):
# this happens when an input is provided but it's not in
# this wallet; in this case, we cannot set the redeem script.
continue
privkey, _ = self._get_key_from_path(path)
txinput.redeem_script = btc.pubkey_to_p2wpkh_script(
btc.privkey_to_pubkey(privkey))
pubkey = self._get_pubkey_from_path(path)[0]
txinput.redeem_script = btc.pubkey_to_p2wpkh_script(pubkey)
# no else branch; any other form of scriptPubKey will just be
# ignored.
try:
@ -1767,12 +1875,7 @@ class ImportWalletMixin(object):
for md, keys in self._storage.data[self._IMPORTED_STORAGE_KEY].items():
md = int(md)
self._imported[md] = keys
for index, (key, key_type) in enumerate(keys):
if not key:
# imported key was removed
continue
assert key_type in self._ENGINES
self._cache_imported_key(md, key, key_type, index)
self._populate_maps(self.yield_imported_paths(md))
def save(self):
import_data = {}
@ -1841,8 +1944,8 @@ class ImportWalletMixin(object):
raise Exception("Only one of script|address|path may be given.")
if address:
script = self.addr_to_script(address)
if script:
path = self.addr_to_path(address)
elif script:
path = self.script_to_path(script)
if not path:
@ -1855,18 +1958,19 @@ class ImportWalletMixin(object):
if not script:
script = self.get_script_from_path(path)
if not address:
address = self.get_address_from_path(path)
# we need to retain indices
self._imported[path[1]][path[2]] = (b'', -1)
del self._script_map[script]
del self._addr_map[address]
self._delete_cache_for_path(path)
def _cache_imported_key(self, mixdepth, privkey, key_type, index):
engine = self._ENGINES[key_type]
path = (self._IMPORTED_ROOT_PATH, mixdepth, index)
self._script_map[engine.key_to_script(privkey)] = path
self._populate_maps((path,))
return path
def _get_mixdepth_from_path(self, path):
@ -1876,9 +1980,11 @@ class ImportWalletMixin(object):
assert len(path) == 3
return path[1]
def _get_key_from_path(self, path):
def _get_key_from_path(self, path,
validate_cache: bool = False):
if not self._is_imported_path(path):
return super()._get_key_from_path(path)
return super()._get_key_from_path(path,
validate_cache=validate_cache)
assert len(path) == 3
md, i = path[1], path[2]
@ -1901,7 +2007,7 @@ class ImportWalletMixin(object):
def is_standard_wallet_script(self, path):
if self._is_imported_path(path):
engine = self._get_key_from_path(path)[1]
engine = self._get_pubkey_from_path(path)[1]
return engine == self._ENGINE
return super().is_standard_wallet_script(path)
@ -1932,13 +2038,6 @@ class ImportWalletMixin(object):
return super().get_details(path)
return path[1], 'imported', path[2]
def get_script_from_path(self, path):
if not self._is_imported_path(path):
return super().get_script_from_path(path)
priv, engine = self._get_key_from_path(path)
return engine.key_to_script(priv)
class BIP39WalletMixin(object):
"""
@ -2009,6 +2108,7 @@ class BIP32Wallet(BaseWallet):
def __init__(self, storage, **kwargs):
self._entropy = None
self._key_ident = None
# {mixdepth: {type: index}} with type being 0/1 corresponding
# to external/internal addresses
self._index_cache = None
@ -2027,7 +2127,7 @@ class BIP32Wallet(BaseWallet):
# used to verify paths for sanity checking and for wallet id creation
self._key_ident = b'' # otherwise get_bip32_* won't work
self._key_ident = self._get_key_ident()
self._populate_script_map()
self._populate_maps(self.yield_known_bip32_paths())
self.disable_new_scripts = False
@classmethod
@ -2073,13 +2173,14 @@ class BIP32Wallet(BaseWallet):
self.get_bip32_priv_export(0, self.BIP32_EXT_ID).encode('ascii')).digest())\
.digest()[:3]
def _populate_script_map(self):
def yield_known_paths(self):
return chain(super().yield_known_paths(), self.yield_known_bip32_paths())
def yield_known_bip32_paths(self):
for md in self._index_cache:
for address_type in (self.BIP32_EXT_ID, self.BIP32_INT_ID):
for i in range(self._index_cache[md][address_type]):
path = self.get_path(md, address_type, i)
script = self.get_script_from_path(path)
self._script_map[script] = path
yield self.get_path(md, address_type, i)
def save(self):
for md, data in self._index_cache.items():
@ -2114,10 +2215,7 @@ class BIP32Wallet(BaseWallet):
def _get_supported_address_types(cls):
return (cls.BIP32_EXT_ID, cls.BIP32_INT_ID)
def get_script_from_path(self, path):
if not self._is_my_bip32_path(path):
raise WalletError("unable to get script for unknown key path")
def _check_path(self, path):
md, address_type, index = self.get_details(path)
if not 0 <= md <= self.max_mixdepth:
@ -2130,12 +2228,22 @@ class BIP32Wallet(BaseWallet):
and address_type != FidelityBondMixin.BIP32_TIMELOCK_ID:
#special case for timelocked addresses because for them the
#concept of a "next address" cant be used
return self.get_new_script_override_disable(md, address_type)
priv, engine = self._get_key_from_path(path)
script = engine.key_to_script(priv)
return script
self._set_index_cache(md, address_type, current_index + 1)
self._populate_maps((path,))
def get_script_from_path(self, path,
validate_cache: bool = False):
if self._is_my_bip32_path(path):
self._check_path(path)
return super().get_script_from_path(path,
validate_cache=validate_cache)
def get_address_from_path(self, path,
validate_cache: bool = False):
if self._is_my_bip32_path(path):
self._check_path(path)
return super().get_address_from_path(path,
validate_cache=validate_cache)
def get_path(self, mixdepth=None, address_type=None, index=None):
if mixdepth is not None:
@ -2151,7 +2259,6 @@ class BIP32Wallet(BaseWallet):
assert isinstance(index, Integral)
if address_type is None:
raise Exception("address_type must be set if index is set")
assert index <= self._index_cache[mixdepth][address_type]
assert index < self.BIP32_MAX_PATH_LEVEL
return tuple(chain(self._get_bip32_export_path(mixdepth, address_type),
(index,)))
@ -2200,30 +2307,62 @@ class BIP32Wallet(BaseWallet):
return path[len(self._get_bip32_base_path())]
def _get_key_from_path(self, path):
def _get_key_from_path(self, path,
validate_cache: bool = False):
if not self._is_my_bip32_path(path):
raise WalletError("Invalid path, unknown root: {}".format(path))
return self._ENGINE.derive_bip32_privkey(self._master_key, path), \
self._ENGINE
cache = self._get_cache_for_path(path)
privkey = cache.get(b'p')
if privkey is None or validate_cache:
new_privkey = self._ENGINE.derive_bip32_privkey(self._master_key, path)
if privkey is None:
cache[b'p'] = privkey = new_privkey
elif privkey != new_privkey:
raise WalletError("Wallet cache validation failed")
return privkey, self._ENGINE
def _get_keypair_from_path(self, path,
validate_cache: bool = False):
if not self._is_my_bip32_path(path):
return super()._get_keypair_from_path(path,
validate_cache=validate_cache)
cache = self._get_cache_for_path(path)
privkey = cache.get(b'p')
if privkey is None or validate_cache:
new_privkey = self._ENGINE.derive_bip32_privkey(self._master_key, path)
if privkey is None:
cache[b'p'] = privkey = new_privkey
elif privkey != new_privkey:
raise WalletError("Wallet cache validation failed")
pubkey = cache.get(b'P')
if pubkey is None or validate_cache:
new_pubkey = self._ENGINE.privkey_to_pubkey(privkey)
if pubkey is None:
cache[b'P'] = pubkey = new_pubkey
elif pubkey != new_pubkey:
raise WalletError("Wallet cache validation failed")
return privkey, pubkey, self._ENGINE
def _get_cache_keys_for_path(self, path):
if not self._is_my_bip32_path(path):
return super()._get_cache_keys_for_path(path)
return path[:1] + tuple([self._path_level_to_repr(lvl).encode('ascii')
for lvl in path[1:]])
def _is_my_bip32_path(self, path):
return path[0] == self._key_ident
return len(path) > 0 and path[0] == self._key_ident
def is_standard_wallet_script(self, path):
return self._is_my_bip32_path(path)
def get_new_script(self, mixdepth, address_type):
def get_new_script(self, mixdepth, address_type,
validate_cache: bool = True):
if self.disable_new_scripts:
raise RuntimeError("Obtaining new wallet addresses "
+ "disabled, due to nohistory mode")
return self.get_new_script_override_disable(mixdepth, address_type)
def get_new_script_override_disable(self, mixdepth, address_type):
# This is called by get_script_from_path and calls back there. We need to
# ensure all conditions match to avoid endless recursion.
index = self.get_index_cache_and_increment(mixdepth, address_type)
return self.get_script_and_update_map(mixdepth, address_type, index)
index = self._index_cache[mixdepth][address_type]
return self.get_script(mixdepth, address_type, index,
validate_cache=validate_cache)
def _set_index_cache(self, mixdepth, address_type, index):
""" Ensures that any update to index_cache dict only applies
@ -2232,22 +2371,6 @@ class BIP32Wallet(BaseWallet):
assert address_type in self._get_supported_address_types()
self._index_cache[mixdepth][address_type] = index
def get_index_cache_and_increment(self, mixdepth, address_type):
index = self._index_cache[mixdepth][address_type]
cur_index = self._index_cache[mixdepth][address_type]
self._set_index_cache(mixdepth, address_type, cur_index + 1)
return cur_index
def get_script_and_update_map(self, *args):
path = self.get_path(*args)
script = self.get_script_from_path(path)
self._script_map[script] = path
return script
def get_script(self, mixdepth, address_type, index):
path = self.get_path(mixdepth, address_type, index)
return self.get_script_from_path(path)
@deprecated
def get_key(self, mixdepth, address_type, index):
path = self.get_path(mixdepth, address_type, index)
@ -2392,6 +2515,10 @@ class FidelityBondMixin(object):
_BIP32_PUBKEY_PREFIX = "fbonds-mpk-"
def __init__(self, storage, **kwargs):
super().__init__(storage, **kwargs)
self._populate_maps(self.yield_fidelity_bond_paths())
@classmethod
def _time_number_to_timestamp(cls, timenumber):
"""
@ -2435,8 +2562,7 @@ class FidelityBondMixin(object):
def _get_key_ident(self):
first_path = self.get_path(0, BIP32Wallet.BIP32_EXT_ID)
priv, engine = self._get_key_from_path(first_path)
pub = engine.privkey_to_pubkey(priv)
pub = self._get_pubkey_from_path(first_path)[0]
return sha256(sha256(pub).digest()).digest()[:3]
def is_standard_wallet_script(self, path):
@ -2451,14 +2577,14 @@ class FidelityBondMixin(object):
else:
return False
def _populate_script_map(self):
super()._populate_script_map()
def yield_known_paths(self):
return chain(super().yield_known_paths(), self.yield_fidelity_bond_paths())
def yield_fidelity_bond_paths(self):
md = self.FIDELITY_BOND_MIXDEPTH
address_type = self.BIP32_TIMELOCK_ID
for timenumber in range(self.TIMENUMBER_COUNT):
path = self.get_path(md, address_type, timenumber)
script = self.get_script_from_path(path)
self._script_map[script] = path
yield self.get_path(md, address_type, timenumber)
def add_utxo(self, txid, index, script, value, height=None):
super().add_utxo(txid, index, script, value, height)
@ -2482,16 +2608,54 @@ class FidelityBondMixin(object):
def _get_supported_address_types(cls):
return (cls.BIP32_EXT_ID, cls.BIP32_INT_ID, cls.BIP32_TIMELOCK_ID, cls.BIP32_BURN_ID)
def _get_key_from_path(self, path):
def _get_key_from_path(self, path,
validate_cache: bool = False):
if self.is_timelocked_path(path):
key_path = path[:-1]
locktime = path[-1]
engine = self._TIMELOCK_ENGINE
privkey = engine.derive_bip32_privkey(self._master_key, key_path)
cache = super()._get_cache_for_path(key_path)
privkey = cache.get(b'p')
if privkey is None or validate_cache:
new_privkey = engine.derive_bip32_privkey(self._master_key, key_path)
if privkey is None:
cache[b'p'] = privkey = new_privkey
elif privkey != new_privkey:
raise WalletError("Wallet cache validation failed")
return (privkey, locktime), engine
else:
return super()._get_key_from_path(path)
def _get_keypair_from_path(self, path,
validate_cache: bool = False):
if not self.is_timelocked_path(path):
return super()._get_keypair_from_path(path,
validate_cache=validate_cache)
key_path = path[:-1]
locktime = path[-1]
engine = self._TIMELOCK_ENGINE
cache = super()._get_cache_for_path(key_path)
privkey = cache.get(b'p')
if privkey is None or validate_cache:
new_privkey = engine.derive_bip32_privkey(self._master_key, key_path)
if privkey is None:
cache[b'p'] = privkey = new_privkey
elif privkey != new_privkey:
raise WalletError("Wallet cache validation failed")
pubkey = cache.get(b'P')
if pubkey is None or validate_cache:
new_pubkey = engine.privkey_to_pubkey(privkey)
if pubkey is None:
cache[b'P'] = pubkey = new_pubkey
elif pubkey != new_pubkey:
raise WalletError("Wallet cache validation failed")
return (privkey, locktime), (pubkey, locktime), engine
def _get_cache_for_path(self, path):
if self.is_timelocked_path(path):
path = path[:-1]
return super()._get_cache_for_path(path)
def get_path(self, mixdepth=None, address_type=None, index=None):
if address_type == None or address_type in (self.BIP32_EXT_ID, self.BIP32_INT_ID,
self.BIP32_BURN_ID) or index == None:
@ -2537,14 +2701,6 @@ class FidelityBondMixin(object):
def _get_default_used_indices(self):
return {x: [0, 0, 0, 0] for x in range(self.max_mixdepth + 1)}
def get_script(self, mixdepth, address_type, index):
path = self.get_path(mixdepth, address_type, index)
return self.get_script_from_path(path)
def get_addr(self, mixdepth, address_type, index):
script = self.get_script(mixdepth, address_type, index)
return self.script_to_addr(script)
def add_burner_output(self, path, txhex, block_height, merkle_branch,
block_index, write=True):
"""
@ -2644,6 +2800,43 @@ class FidelityBondWatchonlyWallet(FidelityBondMixin, BIP84Wallet):
path = super()._get_bip32_export_path(mixdepth, address_type)
return path
def _get_key_from_path(self, path,
validate_cache: bool = False):
raise WalletError("Cannot get a private key from a watch-only wallet")
def _get_keypair_from_path(self, path,
validate_cache: bool = False):
raise WalletError("Cannot get a private key from a watch-only wallet")
def _get_pubkey_from_path(self, path,
validate_cache: bool = False):
if not self._is_my_bip32_path(path):
return super()._get_pubkey_from_path(path,
validate_cache=validate_cache)
if self.is_timelocked_path(path):
key_path = path[:-1]
locktime = path[-1]
cache = self._get_cache_for_path(key_path)
pubkey = cache.get(b'P')
if pubkey is None or validate_cache:
new_pubkey = self._TIMELOCK_ENGINE.derive_bip32_privkey(
self._master_key, key_path)
if pubkey is None:
cache[b'P'] = pubkey = new_pubkey
elif pubkey != new_pubkey:
raise WalletError("Wallet cache validation failed")
return (pubkey, locktime), self._TIMELOCK_ENGINE
cache = self._get_cache_for_path(path)
pubkey = cache.get(b'P')
if pubkey is None or validate_cache:
new_pubkey = self._ENGINE.derive_bip32_privkey(
self._master_key, path)
if pubkey is None:
cache[b'P'] = pubkey = new_pubkey
elif pubkey != new_pubkey:
raise WalletError("Wallet cache validation failed")
return pubkey, self._ENGINE
WALLET_IMPLEMENTATIONS = {
LegacyWallet.TYPE: LegacyWallet,

23
src/jmclient/wallet_utils.py

@ -7,7 +7,7 @@ import sys
from datetime import datetime, timedelta
from optparse import OptionParser
from numbers import Integral
from collections import Counter
from collections import Counter, defaultdict
from itertools import islice, chain
from jmclient import (get_network, WALLET_IMPLEMENTATIONS, Storage, podle,
jm_single, WalletError, BaseWallet, VolatileStorage,
@ -403,15 +403,15 @@ def get_tx_info(txid, tx_cache=None):
def get_imported_privkey_branch(wallet_service, m, showprivkey):
entries = []
balance_by_script = defaultdict(int)
for data in wallet_service.get_utxos_at_mixdepth(m,
include_disabled=True).values():
balance_by_script[data['script']] += data['value']
for path in wallet_service.yield_imported_paths(m):
addr = wallet_service.get_address_from_path(path)
script = wallet_service.get_script_from_path(path)
balance = 0.0
for data in wallet_service.get_utxos_by_mixdepth(
include_disabled=True)[m].values():
if script == data['script']:
balance += data['value']
status = ('used' if balance > 0.0 else 'empty')
balance = balance_by_script.get(script, 0)
status = ('used' if balance else 'empty')
if showprivkey:
wip_privkey = wallet_service.get_wif_path(path)
else:
@ -431,9 +431,6 @@ def wallet_showutxos(wallet_service, showprivkey):
includeconfs=True)
for md in utxos:
(enabled, disabled) = get_utxos_enabled_disabled(wallet_service, md)
utxo_d = []
for k, v in disabled.items():
utxo_d.append(k)
for u, av in utxos[md].items():
success, us = utxo_to_utxostr(u)
assert success
@ -453,7 +450,7 @@ def wallet_showutxos(wallet_service, showprivkey):
'external': False,
'mixdepth': mixdepth,
'confirmations': av['confs'],
'frozen': True if u in utxo_d else False}
'frozen': u in disabled}
if showprivkey:
unsp[us]['privkey'] = wallet_service.get_wif_path(av['path'])
if locktime:
@ -1279,8 +1276,8 @@ def display_utxos_for_disable_choice_default(wallet_service, utxos_enabled,
def get_utxos_enabled_disabled(wallet_service, md):
""" Returns dicts for enabled and disabled separately
"""
utxos_enabled = wallet_service.get_utxos_by_mixdepth()[md]
utxos_all = wallet_service.get_utxos_by_mixdepth(include_disabled=True)[md]
utxos_enabled = wallet_service.get_utxos_at_mixdepth(md)
utxos_all = wallet_service.get_utxos_at_mixdepth(md, include_disabled=True)
utxos_disabled_keyset = set(utxos_all).difference(set(utxos_enabled))
utxos_disabled = {}
for u in utxos_disabled_keyset:

16
test/jmclient/test_taker.py

@ -121,6 +121,12 @@ class DummyWallet(LegacyWallet):
"""
return 'p2wpkh'
def _get_key_from_path(self, path,
validate_cache: bool = False):
if path[0] == b'dummy':
return struct.pack(b'B', path[2] + 1)*32 + b'\x01', self._ENGINE
raise NotImplementedError()
def get_key_from_addr(self, addr):
"""usable addresses: privkey all 1s, 2s, 3s, ... :"""
privs = [x*32 + b"\x01" for x in [struct.pack(b'B', y) for y in range(1,6)]]
@ -139,18 +145,20 @@ class DummyWallet(LegacyWallet):
return p
raise ValueError("No such keypair")
def _is_my_bip32_path(self, path):
return True
def get_path_repr(self, path):
return '/'.join(map(str, path))
def is_standard_wallet_script(self, path):
if path[0] == "nonstandard_path":
return False
return True
def script_to_addr(self, script):
def script_to_addr(self, script,
validate_cache: bool = False):
if self.script_to_path(script)[0] == "nonstandard_path":
return "dummyaddr"
return super().script_to_addr(script)
return super().script_to_addr(script,
validate_cache=validate_cache)
def dummy_order_chooser():

24
test/jmclient/test_utxomanager.py

@ -56,14 +56,12 @@ def test_utxomanager_persist(setup_env_nodeps):
assert not um.is_disabled(txid, index+2)
um.disable_utxo(txid, index+2)
utxos = um.get_utxos_by_mixdepth()
assert len(utxos[mixdepth]) == 1
assert len(utxos[mixdepth+1]) == 2
assert len(utxos[mixdepth+2]) == 0
assert len(um.get_utxos_at_mixdepth(mixdepth)) == 1
assert len(um.get_utxos_at_mixdepth(mixdepth+1)) == 2
assert len(um.get_utxos_at_mixdepth(mixdepth+2)) == 0
balances = um.get_balance_by_mixdepth()
assert balances[mixdepth] == value
assert balances[mixdepth+1] == value * 2
assert um.get_balance_at_mixdepth(mixdepth) == value
assert um.get_balance_at_mixdepth(mixdepth+1) == value * 2
um.remove_utxo(txid, index, mixdepth)
assert um.have_utxo(txid, index) == False
@ -79,14 +77,12 @@ def test_utxomanager_persist(setup_env_nodeps):
assert um.have_utxo(txid, index) == False
assert um.have_utxo(txid, index+1) == mixdepth + 1
utxos = um.get_utxos_by_mixdepth()
assert len(utxos[mixdepth]) == 0
assert len(utxos[mixdepth+1]) == 1
assert len(um.get_utxos_at_mixdepth(mixdepth)) == 0
assert len(um.get_utxos_at_mixdepth(mixdepth+1)) == 1
balances = um.get_balance_by_mixdepth()
assert balances[mixdepth] == 0
assert balances[mixdepth+1] == value
assert balances[mixdepth+2] == 0
assert um.get_balance_at_mixdepth(mixdepth) == 0
assert um.get_balance_at_mixdepth(mixdepth+1) == value
assert um.get_balance_at_mixdepth(mixdepth+2) == 0
def test_utxomanager_select(setup_env_nodeps):

14
test/jmclient/test_wallet.py

@ -17,7 +17,6 @@ from jmclient import load_test_config, jm_single, BaseWallet, \
wallet_gettimelockaddress, UnknownAddressForLabel
from test_blockchaininterface import sync_test_wallet
from freezegun import freeze_time
from bitcointx.wallet import CCoinAddressError
pytestmark = pytest.mark.usefixtures("setup_regtest_bitcoind")
@ -264,9 +263,6 @@ def test_bip32_timelocked_addresses(setup_wallet, timenumber, address, wif):
mixdepth = FidelityBondMixin.FIDELITY_BOND_MIXDEPTH
address_type = FidelityBondMixin.BIP32_TIMELOCK_ID
#wallet needs to know about the script beforehand
wallet.get_script_and_update_map(mixdepth, address_type, timenumber)
assert address == wallet.get_addr(mixdepth, address_type, timenumber)
assert wif == wallet.get_wif_path(wallet.get_path(mixdepth, address_type, timenumber))
@ -287,7 +283,7 @@ def test_gettimelockaddress_method(setup_wallet, timenumber, locktime_string):
m = FidelityBondMixin.FIDELITY_BOND_MIXDEPTH
address_type = FidelityBondMixin.BIP32_TIMELOCK_ID
script = wallet.get_script_and_update_map(m, address_type, timenumber)
script = wallet.get_script(m, address_type, timenumber)
addr = wallet.script_to_addr(script)
addr_from_method = wallet_gettimelockaddress(wallet, locktime_string)
@ -456,7 +452,7 @@ def test_timelocked_output_signing(setup_wallet):
wallet = SegwitWalletFidelityBonds(storage)
timenumber = 0
script = wallet.get_script_and_update_map(
script = wallet.get_script(
FidelityBondMixin.FIDELITY_BOND_MIXDEPTH,
FidelityBondMixin.BIP32_TIMELOCK_ID, timenumber)
utxo = fund_wallet_addr(wallet, wallet.script_to_addr(script))
@ -477,7 +473,7 @@ def test_get_bbm(setup_wallet):
wallet = get_populated_wallet(amount, num_tx)
# disable a utxo and check we can correctly report
# balance with the disabled flag off:
utxo_1 = list(wallet._utxos.get_utxos_by_mixdepth()[0].keys())[0]
utxo_1 = list(wallet._utxos.get_utxos_at_mixdepth(0).keys())[0]
wallet.disable_utxo(*utxo_1)
balances = wallet.get_balance_by_mixdepth(include_disabled=True)
assert balances[0] == num_tx * amount
@ -610,7 +606,9 @@ def test_address_labels(setup_wallet):
wallet.get_address_label("2MzY5yyonUY7zpHspg7jB7WQs1uJxKafQe4")
wallet.set_address_label("2MzY5yyonUY7zpHspg7jB7WQs1uJxKafQe4",
"test")
with pytest.raises(CCoinAddressError):
# we no longer decode addresses just to see if we know about them,
# so we won't get a CCoinAddressError for invalid addresses
#with pytest.raises(CCoinAddressError):
wallet.get_address_label("badaddress")
wallet.set_address_label("badaddress", "test")

Loading…
Cancel
Save