Browse Source

Merge Joinmarket-Org/joinmarket-clientserver#1594: Improve performance with better algorithms and caching of expensive computed values

c3c10f1615 wallet: implement optional cache validation (Matt Whitlock)
5bc7eb4b8e wallet: add persistent cache, mapping path->(priv, pub, script, addr) (Matt Whitlock)
01ec2a4181 wallet: add _addr_map, paralleling _script_map (Matt Whitlock)
64f18bce18 get_imported_privkey_branch: use O(m+n) algorithm instead of O(m*n) (Matt Whitlock)
77f0194a37 wallet_utils: use new get_utxos_at_mixdepth method (Matt Whitlock)
184d76f7f7 wallet: add get_{balance,utxos}_at_mixdepth methods (Matt Whitlock)
fc1e00058b wallet_showutxos: use O(1) check for frozen instead of O(n) (Matt Whitlock)
b58ac679cb wallet: drop _get_addr_int_ext; replace with calls to get_new_addr (Matt Whitlock)
2c38a813fc wallet: delete redundant get_script and get_addr methods (Matt Whitlock)
574c29e899 wallet: hoist get_script_from_path default impl into BaseWallet (Matt Whitlock)
8245271d7f wallet: avoid IndexError in _is_my_bip32_path (Matt Whitlock)
48aec83d76 wallet: remove a dead store in get_index_cache_and_increment (Matt Whitlock)

Pull request description:

  **Note:** Reviewing each commit individually will make more sense than trying to review the combined diff.

  This PR implements several performance enhancements that take the CPU time to run `wallet-tool.py display` on my wallet down from ~44 minutes to ~11 seconds.

  The most significant gains come from replacing an **O**(_m_*_n_) algorithm in `get_imported_privkey_branch` with a semantically equivalent **O**(_m_+_n_) algorithm and from adding a persistent cache for computed private keys, public keys, scripts, and addresses.

  Below are some actual benchmarks on my wallet, which has 5 mixdepths, each having path indices reaching into the 4000s, and almost 700 imported private keys.

  * 673fbfb9a5 `origin/master` (baseline)
      ```
      user    44m3.618s
      sys     0m6.375s
      ```
  * 48aec83d76 `wallet`: remove a dead store in `get_index_cache_and_increment`
  * fbb681a207be465fb53b43ac18a2b52c8a4a6323 `wallet`: add `get_{balance,utxos}_at_mixdepth` methods
  * 75a970378579bb04f189e8d9eca22e5e2aadb0b4 `wallet_utils`: use new `get_utxos_at_mixdepth` method
      ```
      user    42m14.464s
      sys     0m3.355s
      ```
  * 84966e628d510ddf0cadba170346ea926dc06000 `wallet_showutxos`: use **O**(1) check for frozen instead of **O**(_n_)
  * 75c5a75468a6de88e64c4af7a8226c633d358fd5 `get_imported_privkey_branch`: use **O**(_m_+_n_) algorithm instead of **O**(_m_*_n_)
      ```
      user    5m0.045s
      sys     0m0.453s
      ```
  * da8daf048369081d882fb591d50583559a2284f0 `wallet`: add `_addr_map`, paralleling `_script_map`
      ```
      user    4m56.175s
      sys     0m0.423s
      ```
  * d8aa1afe6f0ec596bb133f594ae88cc2fffb6ad2 `wallet`: add persistent cache, mapping path->(priv, pub, script, addr)
      ```
      user    1m42.272s
      sys     0m0.471s
      ```
  * After running another command to modify the wallet file so as to persist the cache, `wallet-tool.py display` now runs in:
      ```
      user    0m11.141s
      sys     0m0.225s
      ```

ACKs for top commit:
  AdamISZ:
    tACK c3c10f1615

Tree-SHA512: fdd20d436d8f16a1e4270011ad1ba4bf6393f876eb7413da30f75d5830249134911d5d93cab8051c0bf107c213d4cd46ba9614ae23eef4566f867ff1b912fc9b
master
Adam Gibson 2 years ago
parent
commit
2d3e90a079
No known key found for this signature in database
GPG Key ID: 141001A1AF77F20B
  1. 541
      src/jmclient/wallet.py
  2. 23
      src/jmclient/wallet_utils.py
  3. 16
      test/jmclient/test_taker.py
  4. 24
      test/jmclient/test_utxomanager.py
  5. 14
      test/jmclient/test_wallet.py

541
src/jmclient/wallet.py

@ -19,6 +19,7 @@ from itertools import chain
from decimal import Decimal from decimal import Decimal
from numbers import Integral from numbers import Integral
from math import exp from math import exp
from typing import Any, Dict, Optional, Tuple
from .configure import jm_single from .configure import jm_single
@ -280,32 +281,28 @@ class UTXOManager(object):
'value': utxos[s['utxo']][1]} 'value': utxos[s['utxo']][1]}
for s in selected} for s in selected}
def get_balance_by_mixdepth(self, max_mixdepth=float('Inf'), def get_balance_at_mixdepth(self, mixdepth: int,
include_disabled=True, maxheight=None): include_disabled: bool = True,
""" By default this returns a dict of aggregated bitcoin maxheight: Optional[int] = None) -> int:
balance per mixdepth: {0: N sats, 1: M sats, ...} for all """ By default this returns aggregated bitcoin balance at mixdepth.
currently available mixdepths.
If max_mixdepth is set it will return balances only up
to that mixdepth.
To get only enabled balance, set include_disabled=False. To get only enabled balance, set include_disabled=False.
To get balances only with a certain number of confs, use maxheight. To get balances only with a certain number of confs, use maxheight.
""" """
balance_dict = collections.defaultdict(int) utxomap = self._utxo.get(mixdepth)
for mixdepth, utxomap in self._utxo.items(): if not utxomap:
if mixdepth > max_mixdepth: return 0
continue if not include_disabled:
if not include_disabled: utxomap = {k: v for k, v in utxomap.items(
utxomap = {k: v for k, v in utxomap.items( ) if not self.is_disabled(*k)}
) if not self.is_disabled(*k)} if maxheight is not None:
if maxheight is not None: utxomap = {k: v for k, v in utxomap.items(
utxomap = {k: v for k, v in utxomap.items( ) if v[2] <= maxheight}
) if v[2] <= maxheight} return sum(x[1] for x in utxomap.values())
value = sum(x[1] for x in utxomap.values())
balance_dict[mixdepth] = value def get_utxos_at_mixdepth(self, mixdepth: int) -> \
return balance_dict Dict[Tuple[bytes, int], Tuple[Tuple, int, int]]:
utxomap = self._utxo.get(mixdepth)
def get_utxos_by_mixdepth(self): return deepcopy(utxomap) if utxomap else {}
return deepcopy(self._utxo)
def __eq__(self, o): def __eq__(self, o):
return self._utxo == o._utxo and \ return self._utxo == o._utxo and \
@ -377,6 +374,7 @@ class BaseWallet(object):
self._storage = storage self._storage = storage
self._utxos = None self._utxos = None
self._addr_labels = None self._addr_labels = None
self._cache = None
# highest mixdepth ever used in wallet, important for synching # highest mixdepth ever used in wallet, important for synching
self.max_mixdepth = None self.max_mixdepth = None
# effective maximum mixdepth to be used by joinmarket # effective maximum mixdepth to be used by joinmarket
@ -385,10 +383,13 @@ class BaseWallet(object):
# {script: path}, should always hold mappings for all "known" keys # {script: path}, should always hold mappings for all "known" keys
self._script_map = {} self._script_map = {}
# {address: path}, should always hold mappings for all "known" keys
self._addr_map = {}
self._load_storage() self._load_storage()
assert self._utxos is not None assert self._utxos is not None
assert self._cache is not None
assert self.max_mixdepth is not None assert self.max_mixdepth is not None
assert self.max_mixdepth >= 0 assert self.max_mixdepth >= 0
assert self.network in ('mainnet', 'testnet', 'signet') assert self.network in ('mainnet', 'testnet', 'signet')
@ -425,6 +426,7 @@ class BaseWallet(object):
self.network = self._storage.data[b'network'].decode('ascii') self.network = self._storage.data[b'network'].decode('ascii')
self._utxos = UTXOManager(self._storage, self.merge_algorithm) self._utxos = UTXOManager(self._storage, self.merge_algorithm)
self._addr_labels = AddressLabelsManager(self._storage) self._addr_labels = AddressLabelsManager(self._storage)
self._cache = self._storage.data.setdefault(b'cache', {})
def get_storage_location(self): def get_storage_location(self):
""" Return the location of the """ Return the location of the
@ -538,34 +540,24 @@ class BaseWallet(object):
""" """
There should be no reason for code outside the wallet to need a privkey. There should be no reason for code outside the wallet to need a privkey.
""" """
script = self._ENGINE.address_to_script(addr) path = self.addr_to_path(addr)
path = self.script_to_path(script)
privkey = self._get_key_from_path(path)[0] privkey = self._get_key_from_path(path)[0]
return privkey return privkey
def _get_addr_int_ext(self, address_type, mixdepth):
if address_type == self.ADDRESS_TYPE_EXTERNAL:
script = self.get_external_script(mixdepth)
elif address_type == self.ADDRESS_TYPE_INTERNAL:
script = self.get_internal_script(mixdepth)
else:
assert 0
return self.script_to_addr(script)
def get_external_addr(self, mixdepth): def get_external_addr(self, mixdepth):
""" """
Return an address suitable for external distribution, including funding Return an address suitable for external distribution, including funding
the wallet from other sources, or receiving payments or donations. the wallet from other sources, or receiving payments or donations.
JoinMarket will never generate these addresses for internal use. JoinMarket will never generate these addresses for internal use.
""" """
return self._get_addr_int_ext(self.ADDRESS_TYPE_EXTERNAL, mixdepth) return self.get_new_addr(mixdepth, self.ADDRESS_TYPE_EXTERNAL)
def get_internal_addr(self, mixdepth): def get_internal_addr(self, mixdepth):
""" """
Return an address for internal usage, as change addresses and when Return an address for internal usage, as change addresses and when
participating in transactions initiated by other parties. participating in transactions initiated by other parties.
""" """
return self._get_addr_int_ext(self.ADDRESS_TYPE_INTERNAL, mixdepth) return self.get_new_addr(mixdepth, self.ADDRESS_TYPE_INTERNAL)
def get_external_script(self, mixdepth): def get_external_script(self, mixdepth):
return self.get_new_script(mixdepth, self.ADDRESS_TYPE_EXTERNAL) return self.get_new_script(mixdepth, self.ADDRESS_TYPE_EXTERNAL)
@ -575,21 +567,33 @@ class BaseWallet(object):
@classmethod @classmethod
def addr_to_script(cls, addr): def addr_to_script(cls, addr):
"""
Try not to call this slow method. Instead, call addr_to_path,
followed by get_script_from_path, as those are cached.
"""
return cls._ENGINE.address_to_script(addr) return cls._ENGINE.address_to_script(addr)
@classmethod @classmethod
def pubkey_to_script(cls, pubkey): def pubkey_to_script(cls, pubkey):
"""
Try not to call this slow method. Instead, call
get_script_from_path if possible, as that is cached.
"""
return cls._ENGINE.pubkey_to_script(pubkey) return cls._ENGINE.pubkey_to_script(pubkey)
@classmethod @classmethod
def pubkey_to_addr(cls, pubkey): def pubkey_to_addr(cls, pubkey):
"""
Try not to call this slow method. Instead, call
get_address_from_path if possible, as that is cached.
"""
return cls._ENGINE.pubkey_to_address(pubkey) return cls._ENGINE.pubkey_to_address(pubkey)
def script_to_addr(self, script): def script_to_addr(self, script,
assert self.is_known_script(script) validate_cache: bool = False):
path = self.script_to_path(script) path = self.script_to_path(script)
engine = self._get_key_from_path(path)[1] return self.get_address_from_path(path,
return engine.script_to_address(script) validate_cache=validate_cache)
def get_script_code(self, script): def get_script_code(self, script):
""" """
@ -600,8 +604,7 @@ class BaseWallet(object):
For non-segwit wallets, raises EngineError. For non-segwit wallets, raises EngineError.
""" """
path = self.script_to_path(script) path = self.script_to_path(script)
priv, engine = self._get_key_from_path(path) pub, engine = self._get_pubkey_from_path(path)
pub = engine.privkey_to_pubkey(priv)
return engine.pubkey_to_script_code(pub) return engine.pubkey_to_script_code(pub)
@classmethod @classmethod
@ -616,22 +619,42 @@ class BaseWallet(object):
def get_key(self, mixdepth, address_type, index): def get_key(self, mixdepth, address_type, index):
raise NotImplementedError() raise NotImplementedError()
def get_addr(self, mixdepth, address_type, index): def get_addr(self, mixdepth, address_type, index,
script = self.get_script(mixdepth, address_type, index) validate_cache: bool = False):
return self.script_to_addr(script) path = self.get_path(mixdepth, address_type, index)
return self.get_address_from_path(path,
def get_address_from_path(self, path): validate_cache=validate_cache)
script = self.get_script_from_path(path)
return self.script_to_addr(script) def get_address_from_path(self, path,
validate_cache: bool = False):
def get_new_addr(self, mixdepth, address_type): cache = self._get_cache_for_path(path)
addr = cache.get(b'A')
if addr is not None:
addr = addr.decode('ascii')
if addr is None or validate_cache:
engine = self._get_pubkey_from_path(path)[1]
script = self.get_script_from_path(path,
validate_cache=validate_cache)
new_addr = engine.script_to_address(script)
if addr is None:
addr = new_addr
cache[b'A'] = addr.encode('ascii')
elif addr != new_addr:
raise WalletError("Wallet cache validation failed")
return addr
def get_new_addr(self, mixdepth, address_type,
validate_cache: bool = True):
""" """
use get_external_addr/get_internal_addr use get_external_addr/get_internal_addr
""" """
script = self.get_new_script(mixdepth, address_type) script = self.get_new_script(mixdepth, address_type,
return self.script_to_addr(script) validate_cache=validate_cache)
return self.script_to_addr(script,
validate_cache=validate_cache)
def get_new_script(self, mixdepth, address_type): def get_new_script(self, mixdepth, address_type,
validate_cache: bool = True):
raise NotImplementedError() raise NotImplementedError()
def get_wif(self, mixdepth, address_type, index): def get_wif(self, mixdepth, address_type, index):
@ -845,10 +868,19 @@ class BaseWallet(object):
confirmations, set maxheight to max acceptable blockheight. confirmations, set maxheight to max acceptable blockheight.
returns: {mixdepth: value} returns: {mixdepth: value}
""" """
balances = collections.defaultdict(int)
for md in range(self.mixdepth + 1):
balances[md] = self.get_balance_at_mixdepth(md, verbose=verbose,
include_disabled=include_disabled, maxheight=maxheight)
return balances
def get_balance_at_mixdepth(self, mixdepth,
verbose: bool = True,
include_disabled: bool = False,
maxheight: Optional[int] = None) -> int:
# TODO: verbose # TODO: verbose
return self._utxos.get_balance_by_mixdepth(max_mixdepth=self.mixdepth, return self._utxos.get_balance_at_mixdepth(mixdepth,
include_disabled=include_disabled, include_disabled=include_disabled, maxheight=maxheight)
maxheight=maxheight)
def get_utxos_by_mixdepth(self, include_disabled=False, includeheight=False): def get_utxos_by_mixdepth(self, include_disabled=False, includeheight=False):
""" """
@ -859,25 +891,35 @@ class BaseWallet(object):
{'script': bytes, 'path': tuple, 'value': int}}} {'script': bytes, 'path': tuple, 'value': int}}}
(if `includeheight` is True, adds key 'height': int) (if `includeheight` is True, adds key 'height': int)
""" """
mix_utxos = self._utxos.get_utxos_by_mixdepth()
script_utxos = collections.defaultdict(dict) script_utxos = collections.defaultdict(dict)
for md, data in mix_utxos.items(): for md in range(self.mixdepth + 1):
if md > self.mixdepth: script_utxos[md] = self.get_utxos_at_mixdepth(md,
continue include_disabled=include_disabled, includeheight=includeheight)
return script_utxos
def get_utxos_at_mixdepth(self, mixdepth: int,
include_disabled: bool = False,
includeheight: bool = False) -> \
Dict[Tuple[bytes, int], Dict[str, Any]]:
script_utxos = {}
if 0 <= mixdepth <= self.mixdepth:
data = self._utxos.get_utxos_at_mixdepth(mixdepth)
for utxo, (path, value, height) in data.items(): for utxo, (path, value, height) in data.items():
if not include_disabled and self._utxos.is_disabled(*utxo): if not include_disabled and self._utxos.is_disabled(*utxo):
continue continue
script = self.get_script_from_path(path) script = self.get_script_from_path(path)
addr = self.get_address_from_path(path) addr = self.get_address_from_path(path)
label = self.get_address_label(addr) label = self.get_address_label(addr)
script_utxos[md][utxo] = {'script': script, script_utxo = {
'path': path, 'script': script,
'value': value, 'path': path,
'address': addr, 'value': value,
'label': label} 'address': addr,
'label': label,
}
if includeheight: if includeheight:
script_utxos[md][utxo]['height'] = height script_utxo['height'] = height
script_utxos[utxo] = script_utxo
return script_utxos return script_utxos
@ -910,7 +952,8 @@ class BaseWallet(object):
def _get_mixdepth_from_path(self, path): def _get_mixdepth_from_path(self, path):
raise NotImplementedError() raise NotImplementedError()
def get_script_from_path(self, path): def get_script_from_path(self, path,
validate_cache: bool = False):
""" """
internal note: This is the final sink for all operations that somehow internal note: This is the final sink for all operations that somehow
need to derive a script. If anything goes wrong when deriving a need to derive a script. If anything goes wrong when deriving a
@ -921,15 +964,72 @@ class BaseWallet(object):
returns: returns:
script script
""" """
raise NotImplementedError() cache = self._get_cache_for_path(path)
script = cache.get(b'S')
if script is None or validate_cache:
pubkey, engine = self._get_pubkey_from_path(path,
validate_cache=validate_cache)
new_script = engine.pubkey_to_script(pubkey)
if script is None:
cache[b'S'] = script = new_script
elif script != new_script:
raise WalletError("Wallet cache validation failed")
return script
def get_script(self, mixdepth, address_type, index): def get_script(self, mixdepth, address_type, index,
validate_cache: bool = False):
path = self.get_path(mixdepth, address_type, index) path = self.get_path(mixdepth, address_type, index)
return self.get_script_from_path(path) return self.get_script_from_path(path, validate_cache=validate_cache)
def _get_key_from_path(self, path): def _get_key_from_path(self, path,
validate_cache: bool = False):
raise NotImplementedError() raise NotImplementedError()
def _get_keypair_from_path(self, path,
validate_cache: bool = False):
privkey, engine = self._get_key_from_path(path,
validate_cache=validate_cache)
cache = self._get_cache_for_path(path)
pubkey = cache.get(b'P')
if pubkey is None or validate_cache:
new_pubkey = engine.privkey_to_pubkey(privkey)
if pubkey is None:
cache[b'P'] = pubkey = new_pubkey
elif pubkey != new_pubkey:
raise WalletError("Wallet cache validation failed")
return privkey, pubkey, engine
def _get_pubkey_from_path(self, path,
validate_cache: bool = False):
privkey, pubkey, engine = self._get_keypair_from_path(path,
validate_cache=validate_cache)
return pubkey, engine
def _get_cache_keys_for_path(self, path):
return path[:1] + tuple(map(_int_to_bytestr, path[1:]))
def _get_cache_for_path(self, path):
assert len(path) > 0
cache = self._cache
for k in self._get_cache_keys_for_path(path):
cache = cache.setdefault(k, {})
return cache
def _delete_cache_for_path(self, path) -> bool:
assert len(path) > 0
def recurse(cache, itr):
k = next(itr, None)
if k is None:
cache.clear()
else:
child = cache.get(k)
if not child or not recurse(child, itr):
return False
if not child:
del cache[k]
return True
return recurse(self._cache, iter(self._get_cache_keys_for_path(path)))
def get_path_repr(self, path): def get_path_repr(self, path):
""" """
Get a human-readable representation of the wallet path. Get a human-readable representation of the wallet path.
@ -984,7 +1084,7 @@ class BaseWallet(object):
signature as base64-encoded string signature as base64-encoded string
""" """
priv, engine = self._get_key_from_path(path) priv, engine = self._get_key_from_path(path)
addr = engine.privkey_to_address(priv) addr = self.get_address_from_path(path)
return addr, engine.sign_message(priv, message) return addr, engine.sign_message(priv, message)
def get_wallet_name(self): def get_wallet_name(self):
@ -1038,8 +1138,8 @@ class BaseWallet(object):
returns: returns:
bool bool
""" """
script = self.addr_to_script(addr) assert isinstance(addr, str)
return script in self._script_map return addr in self._addr_map
def is_known_script(self, script): def is_known_script(self, script):
""" """
@ -1054,8 +1154,8 @@ class BaseWallet(object):
return script in self._script_map return script in self._script_map
def get_addr_mixdepth(self, addr): def get_addr_mixdepth(self, addr):
script = self.addr_to_script(addr) path = self.addr_to_path(addr)
return self.get_script_mixdepth(script) return self._get_mixdepth_from_path(path)
def get_script_mixdepth(self, script): def get_script_mixdepth(self, script):
path = self.script_to_path(script) path = self.script_to_path(script)
@ -1068,16 +1168,26 @@ class BaseWallet(object):
returns: returns:
path generator path generator
""" """
for s in self._script_map.values(): for md in range(self.max_mixdepth + 1):
yield s for path in self.yield_imported_paths(md):
yield path
def _populate_maps(self, paths):
for path in paths:
self._script_map[self.get_script_from_path(path)] = path
self._addr_map[self.get_address_from_path(path)] = path
def addr_to_path(self, addr): def addr_to_path(self, addr):
script = self.addr_to_script(addr) assert isinstance(addr, str)
return self.script_to_path(script) path = self._addr_map.get(addr)
assert path is not None
return path
def script_to_path(self, script): def script_to_path(self, script):
assert script in self._script_map assert isinstance(script, bytes)
return self._script_map[script] path = self._script_map.get(script)
assert path is not None
return path
def set_next_index(self, mixdepth, address_type, index, force=False): def set_next_index(self, mixdepth, address_type, index, force=False):
""" """
@ -1379,9 +1489,8 @@ class PSBTWalletMixin(object):
# this happens when an input is provided but it's not in # this happens when an input is provided but it's not in
# this wallet; in this case, we cannot set the redeem script. # this wallet; in this case, we cannot set the redeem script.
continue continue
privkey, _ = self._get_key_from_path(path) pubkey = self._get_pubkey_from_path(path)[0]
txinput.redeem_script = btc.pubkey_to_p2wpkh_script( txinput.redeem_script = btc.pubkey_to_p2wpkh_script(pubkey)
btc.privkey_to_pubkey(privkey))
return new_psbt return new_psbt
def sign_psbt(self, in_psbt, with_sign_result=False): def sign_psbt(self, in_psbt, with_sign_result=False):
@ -1451,9 +1560,8 @@ class PSBTWalletMixin(object):
# this happens when an input is provided but it's not in # this happens when an input is provided but it's not in
# this wallet; in this case, we cannot set the redeem script. # this wallet; in this case, we cannot set the redeem script.
continue continue
privkey, _ = self._get_key_from_path(path) pubkey = self._get_pubkey_from_path(path)[0]
txinput.redeem_script = btc.pubkey_to_p2wpkh_script( txinput.redeem_script = btc.pubkey_to_p2wpkh_script(pubkey)
btc.privkey_to_pubkey(privkey))
# no else branch; any other form of scriptPubKey will just be # no else branch; any other form of scriptPubKey will just be
# ignored. # ignored.
try: try:
@ -1767,12 +1875,7 @@ class ImportWalletMixin(object):
for md, keys in self._storage.data[self._IMPORTED_STORAGE_KEY].items(): for md, keys in self._storage.data[self._IMPORTED_STORAGE_KEY].items():
md = int(md) md = int(md)
self._imported[md] = keys self._imported[md] = keys
for index, (key, key_type) in enumerate(keys): self._populate_maps(self.yield_imported_paths(md))
if not key:
# imported key was removed
continue
assert key_type in self._ENGINES
self._cache_imported_key(md, key, key_type, index)
def save(self): def save(self):
import_data = {} import_data = {}
@ -1841,8 +1944,8 @@ class ImportWalletMixin(object):
raise Exception("Only one of script|address|path may be given.") raise Exception("Only one of script|address|path may be given.")
if address: if address:
script = self.addr_to_script(address) path = self.addr_to_path(address)
if script: elif script:
path = self.script_to_path(script) path = self.script_to_path(script)
if not path: if not path:
@ -1855,18 +1958,19 @@ class ImportWalletMixin(object):
if not script: if not script:
script = self.get_script_from_path(path) script = self.get_script_from_path(path)
if not address:
address = self.get_address_from_path(path)
# we need to retain indices # we need to retain indices
self._imported[path[1]][path[2]] = (b'', -1) self._imported[path[1]][path[2]] = (b'', -1)
del self._script_map[script] del self._script_map[script]
del self._addr_map[address]
self._delete_cache_for_path(path)
def _cache_imported_key(self, mixdepth, privkey, key_type, index): def _cache_imported_key(self, mixdepth, privkey, key_type, index):
engine = self._ENGINES[key_type]
path = (self._IMPORTED_ROOT_PATH, mixdepth, index) path = (self._IMPORTED_ROOT_PATH, mixdepth, index)
self._populate_maps((path,))
self._script_map[engine.key_to_script(privkey)] = path
return path return path
def _get_mixdepth_from_path(self, path): def _get_mixdepth_from_path(self, path):
@ -1876,9 +1980,11 @@ class ImportWalletMixin(object):
assert len(path) == 3 assert len(path) == 3
return path[1] return path[1]
def _get_key_from_path(self, path): def _get_key_from_path(self, path,
validate_cache: bool = False):
if not self._is_imported_path(path): if not self._is_imported_path(path):
return super()._get_key_from_path(path) return super()._get_key_from_path(path,
validate_cache=validate_cache)
assert len(path) == 3 assert len(path) == 3
md, i = path[1], path[2] md, i = path[1], path[2]
@ -1901,7 +2007,7 @@ class ImportWalletMixin(object):
def is_standard_wallet_script(self, path): def is_standard_wallet_script(self, path):
if self._is_imported_path(path): if self._is_imported_path(path):
engine = self._get_key_from_path(path)[1] engine = self._get_pubkey_from_path(path)[1]
return engine == self._ENGINE return engine == self._ENGINE
return super().is_standard_wallet_script(path) return super().is_standard_wallet_script(path)
@ -1932,13 +2038,6 @@ class ImportWalletMixin(object):
return super().get_details(path) return super().get_details(path)
return path[1], 'imported', path[2] return path[1], 'imported', path[2]
def get_script_from_path(self, path):
if not self._is_imported_path(path):
return super().get_script_from_path(path)
priv, engine = self._get_key_from_path(path)
return engine.key_to_script(priv)
class BIP39WalletMixin(object): class BIP39WalletMixin(object):
""" """
@ -2009,6 +2108,7 @@ class BIP32Wallet(BaseWallet):
def __init__(self, storage, **kwargs): def __init__(self, storage, **kwargs):
self._entropy = None self._entropy = None
self._key_ident = None
# {mixdepth: {type: index}} with type being 0/1 corresponding # {mixdepth: {type: index}} with type being 0/1 corresponding
# to external/internal addresses # to external/internal addresses
self._index_cache = None self._index_cache = None
@ -2027,7 +2127,7 @@ class BIP32Wallet(BaseWallet):
# used to verify paths for sanity checking and for wallet id creation # used to verify paths for sanity checking and for wallet id creation
self._key_ident = b'' # otherwise get_bip32_* won't work self._key_ident = b'' # otherwise get_bip32_* won't work
self._key_ident = self._get_key_ident() self._key_ident = self._get_key_ident()
self._populate_script_map() self._populate_maps(self.yield_known_bip32_paths())
self.disable_new_scripts = False self.disable_new_scripts = False
@classmethod @classmethod
@ -2073,13 +2173,14 @@ class BIP32Wallet(BaseWallet):
self.get_bip32_priv_export(0, self.BIP32_EXT_ID).encode('ascii')).digest())\ self.get_bip32_priv_export(0, self.BIP32_EXT_ID).encode('ascii')).digest())\
.digest()[:3] .digest()[:3]
def _populate_script_map(self): def yield_known_paths(self):
return chain(super().yield_known_paths(), self.yield_known_bip32_paths())
def yield_known_bip32_paths(self):
for md in self._index_cache: for md in self._index_cache:
for address_type in (self.BIP32_EXT_ID, self.BIP32_INT_ID): for address_type in (self.BIP32_EXT_ID, self.BIP32_INT_ID):
for i in range(self._index_cache[md][address_type]): for i in range(self._index_cache[md][address_type]):
path = self.get_path(md, address_type, i) yield self.get_path(md, address_type, i)
script = self.get_script_from_path(path)
self._script_map[script] = path
def save(self): def save(self):
for md, data in self._index_cache.items(): for md, data in self._index_cache.items():
@ -2114,10 +2215,7 @@ class BIP32Wallet(BaseWallet):
def _get_supported_address_types(cls): def _get_supported_address_types(cls):
return (cls.BIP32_EXT_ID, cls.BIP32_INT_ID) return (cls.BIP32_EXT_ID, cls.BIP32_INT_ID)
def get_script_from_path(self, path): def _check_path(self, path):
if not self._is_my_bip32_path(path):
raise WalletError("unable to get script for unknown key path")
md, address_type, index = self.get_details(path) md, address_type, index = self.get_details(path)
if not 0 <= md <= self.max_mixdepth: if not 0 <= md <= self.max_mixdepth:
@ -2130,12 +2228,22 @@ class BIP32Wallet(BaseWallet):
and address_type != FidelityBondMixin.BIP32_TIMELOCK_ID: and address_type != FidelityBondMixin.BIP32_TIMELOCK_ID:
#special case for timelocked addresses because for them the #special case for timelocked addresses because for them the
#concept of a "next address" cant be used #concept of a "next address" cant be used
return self.get_new_script_override_disable(md, address_type) self._set_index_cache(md, address_type, current_index + 1)
self._populate_maps((path,))
priv, engine = self._get_key_from_path(path)
script = engine.key_to_script(priv) def get_script_from_path(self, path,
validate_cache: bool = False):
return script if self._is_my_bip32_path(path):
self._check_path(path)
return super().get_script_from_path(path,
validate_cache=validate_cache)
def get_address_from_path(self, path,
validate_cache: bool = False):
if self._is_my_bip32_path(path):
self._check_path(path)
return super().get_address_from_path(path,
validate_cache=validate_cache)
def get_path(self, mixdepth=None, address_type=None, index=None): def get_path(self, mixdepth=None, address_type=None, index=None):
if mixdepth is not None: if mixdepth is not None:
@ -2151,7 +2259,6 @@ class BIP32Wallet(BaseWallet):
assert isinstance(index, Integral) assert isinstance(index, Integral)
if address_type is None: if address_type is None:
raise Exception("address_type must be set if index is set") raise Exception("address_type must be set if index is set")
assert index <= self._index_cache[mixdepth][address_type]
assert index < self.BIP32_MAX_PATH_LEVEL assert index < self.BIP32_MAX_PATH_LEVEL
return tuple(chain(self._get_bip32_export_path(mixdepth, address_type), return tuple(chain(self._get_bip32_export_path(mixdepth, address_type),
(index,))) (index,)))
@ -2200,30 +2307,62 @@ class BIP32Wallet(BaseWallet):
return path[len(self._get_bip32_base_path())] return path[len(self._get_bip32_base_path())]
def _get_key_from_path(self, path): def _get_key_from_path(self, path,
validate_cache: bool = False):
if not self._is_my_bip32_path(path): if not self._is_my_bip32_path(path):
raise WalletError("Invalid path, unknown root: {}".format(path)) raise WalletError("Invalid path, unknown root: {}".format(path))
cache = self._get_cache_for_path(path)
return self._ENGINE.derive_bip32_privkey(self._master_key, path), \ privkey = cache.get(b'p')
self._ENGINE if privkey is None or validate_cache:
new_privkey = self._ENGINE.derive_bip32_privkey(self._master_key, path)
if privkey is None:
cache[b'p'] = privkey = new_privkey
elif privkey != new_privkey:
raise WalletError("Wallet cache validation failed")
return privkey, self._ENGINE
def _get_keypair_from_path(self, path,
validate_cache: bool = False):
if not self._is_my_bip32_path(path):
return super()._get_keypair_from_path(path,
validate_cache=validate_cache)
cache = self._get_cache_for_path(path)
privkey = cache.get(b'p')
if privkey is None or validate_cache:
new_privkey = self._ENGINE.derive_bip32_privkey(self._master_key, path)
if privkey is None:
cache[b'p'] = privkey = new_privkey
elif privkey != new_privkey:
raise WalletError("Wallet cache validation failed")
pubkey = cache.get(b'P')
if pubkey is None or validate_cache:
new_pubkey = self._ENGINE.privkey_to_pubkey(privkey)
if pubkey is None:
cache[b'P'] = pubkey = new_pubkey
elif pubkey != new_pubkey:
raise WalletError("Wallet cache validation failed")
return privkey, pubkey, self._ENGINE
def _get_cache_keys_for_path(self, path):
if not self._is_my_bip32_path(path):
return super()._get_cache_keys_for_path(path)
return path[:1] + tuple([self._path_level_to_repr(lvl).encode('ascii')
for lvl in path[1:]])
def _is_my_bip32_path(self, path): def _is_my_bip32_path(self, path):
return path[0] == self._key_ident return len(path) > 0 and path[0] == self._key_ident
def is_standard_wallet_script(self, path): def is_standard_wallet_script(self, path):
return self._is_my_bip32_path(path) return self._is_my_bip32_path(path)
def get_new_script(self, mixdepth, address_type): def get_new_script(self, mixdepth, address_type,
validate_cache: bool = True):
if self.disable_new_scripts: if self.disable_new_scripts:
raise RuntimeError("Obtaining new wallet addresses " raise RuntimeError("Obtaining new wallet addresses "
+ "disabled, due to nohistory mode") + "disabled, due to nohistory mode")
return self.get_new_script_override_disable(mixdepth, address_type) index = self._index_cache[mixdepth][address_type]
return self.get_script(mixdepth, address_type, index,
def get_new_script_override_disable(self, mixdepth, address_type): validate_cache=validate_cache)
# This is called by get_script_from_path and calls back there. We need to
# ensure all conditions match to avoid endless recursion.
index = self.get_index_cache_and_increment(mixdepth, address_type)
return self.get_script_and_update_map(mixdepth, address_type, index)
def _set_index_cache(self, mixdepth, address_type, index): def _set_index_cache(self, mixdepth, address_type, index):
""" Ensures that any update to index_cache dict only applies """ Ensures that any update to index_cache dict only applies
@ -2232,22 +2371,6 @@ class BIP32Wallet(BaseWallet):
assert address_type in self._get_supported_address_types() assert address_type in self._get_supported_address_types()
self._index_cache[mixdepth][address_type] = index self._index_cache[mixdepth][address_type] = index
def get_index_cache_and_increment(self, mixdepth, address_type):
index = self._index_cache[mixdepth][address_type]
cur_index = self._index_cache[mixdepth][address_type]
self._set_index_cache(mixdepth, address_type, cur_index + 1)
return cur_index
def get_script_and_update_map(self, *args):
path = self.get_path(*args)
script = self.get_script_from_path(path)
self._script_map[script] = path
return script
def get_script(self, mixdepth, address_type, index):
path = self.get_path(mixdepth, address_type, index)
return self.get_script_from_path(path)
@deprecated @deprecated
def get_key(self, mixdepth, address_type, index): def get_key(self, mixdepth, address_type, index):
path = self.get_path(mixdepth, address_type, index) path = self.get_path(mixdepth, address_type, index)
@ -2392,6 +2515,10 @@ class FidelityBondMixin(object):
_BIP32_PUBKEY_PREFIX = "fbonds-mpk-" _BIP32_PUBKEY_PREFIX = "fbonds-mpk-"
def __init__(self, storage, **kwargs):
super().__init__(storage, **kwargs)
self._populate_maps(self.yield_fidelity_bond_paths())
@classmethod @classmethod
def _time_number_to_timestamp(cls, timenumber): def _time_number_to_timestamp(cls, timenumber):
""" """
@ -2435,8 +2562,7 @@ class FidelityBondMixin(object):
def _get_key_ident(self): def _get_key_ident(self):
first_path = self.get_path(0, BIP32Wallet.BIP32_EXT_ID) first_path = self.get_path(0, BIP32Wallet.BIP32_EXT_ID)
priv, engine = self._get_key_from_path(first_path) pub = self._get_pubkey_from_path(first_path)[0]
pub = engine.privkey_to_pubkey(priv)
return sha256(sha256(pub).digest()).digest()[:3] return sha256(sha256(pub).digest()).digest()[:3]
def is_standard_wallet_script(self, path): def is_standard_wallet_script(self, path):
@ -2451,14 +2577,14 @@ class FidelityBondMixin(object):
else: else:
return False return False
def _populate_script_map(self): def yield_known_paths(self):
super()._populate_script_map() return chain(super().yield_known_paths(), self.yield_fidelity_bond_paths())
def yield_fidelity_bond_paths(self):
md = self.FIDELITY_BOND_MIXDEPTH md = self.FIDELITY_BOND_MIXDEPTH
address_type = self.BIP32_TIMELOCK_ID address_type = self.BIP32_TIMELOCK_ID
for timenumber in range(self.TIMENUMBER_COUNT): for timenumber in range(self.TIMENUMBER_COUNT):
path = self.get_path(md, address_type, timenumber) yield self.get_path(md, address_type, timenumber)
script = self.get_script_from_path(path)
self._script_map[script] = path
def add_utxo(self, txid, index, script, value, height=None): def add_utxo(self, txid, index, script, value, height=None):
super().add_utxo(txid, index, script, value, height) super().add_utxo(txid, index, script, value, height)
@ -2482,16 +2608,54 @@ class FidelityBondMixin(object):
def _get_supported_address_types(cls): def _get_supported_address_types(cls):
return (cls.BIP32_EXT_ID, cls.BIP32_INT_ID, cls.BIP32_TIMELOCK_ID, cls.BIP32_BURN_ID) return (cls.BIP32_EXT_ID, cls.BIP32_INT_ID, cls.BIP32_TIMELOCK_ID, cls.BIP32_BURN_ID)
def _get_key_from_path(self, path): def _get_key_from_path(self, path,
validate_cache: bool = False):
if self.is_timelocked_path(path): if self.is_timelocked_path(path):
key_path = path[:-1] key_path = path[:-1]
locktime = path[-1] locktime = path[-1]
engine = self._TIMELOCK_ENGINE engine = self._TIMELOCK_ENGINE
privkey = engine.derive_bip32_privkey(self._master_key, key_path) cache = super()._get_cache_for_path(key_path)
privkey = cache.get(b'p')
if privkey is None or validate_cache:
new_privkey = engine.derive_bip32_privkey(self._master_key, key_path)
if privkey is None:
cache[b'p'] = privkey = new_privkey
elif privkey != new_privkey:
raise WalletError("Wallet cache validation failed")
return (privkey, locktime), engine return (privkey, locktime), engine
else: else:
return super()._get_key_from_path(path) return super()._get_key_from_path(path)
def _get_keypair_from_path(self, path,
validate_cache: bool = False):
if not self.is_timelocked_path(path):
return super()._get_keypair_from_path(path,
validate_cache=validate_cache)
key_path = path[:-1]
locktime = path[-1]
engine = self._TIMELOCK_ENGINE
cache = super()._get_cache_for_path(key_path)
privkey = cache.get(b'p')
if privkey is None or validate_cache:
new_privkey = engine.derive_bip32_privkey(self._master_key, key_path)
if privkey is None:
cache[b'p'] = privkey = new_privkey
elif privkey != new_privkey:
raise WalletError("Wallet cache validation failed")
pubkey = cache.get(b'P')
if pubkey is None or validate_cache:
new_pubkey = engine.privkey_to_pubkey(privkey)
if pubkey is None:
cache[b'P'] = pubkey = new_pubkey
elif pubkey != new_pubkey:
raise WalletError("Wallet cache validation failed")
return (privkey, locktime), (pubkey, locktime), engine
def _get_cache_for_path(self, path):
if self.is_timelocked_path(path):
path = path[:-1]
return super()._get_cache_for_path(path)
def get_path(self, mixdepth=None, address_type=None, index=None): def get_path(self, mixdepth=None, address_type=None, index=None):
if address_type == None or address_type in (self.BIP32_EXT_ID, self.BIP32_INT_ID, if address_type == None or address_type in (self.BIP32_EXT_ID, self.BIP32_INT_ID,
self.BIP32_BURN_ID) or index == None: self.BIP32_BURN_ID) or index == None:
@ -2537,14 +2701,6 @@ class FidelityBondMixin(object):
def _get_default_used_indices(self): def _get_default_used_indices(self):
return {x: [0, 0, 0, 0] for x in range(self.max_mixdepth + 1)} return {x: [0, 0, 0, 0] for x in range(self.max_mixdepth + 1)}
def get_script(self, mixdepth, address_type, index):
path = self.get_path(mixdepth, address_type, index)
return self.get_script_from_path(path)
def get_addr(self, mixdepth, address_type, index):
script = self.get_script(mixdepth, address_type, index)
return self.script_to_addr(script)
def add_burner_output(self, path, txhex, block_height, merkle_branch, def add_burner_output(self, path, txhex, block_height, merkle_branch,
block_index, write=True): block_index, write=True):
""" """
@ -2644,6 +2800,43 @@ class FidelityBondWatchonlyWallet(FidelityBondMixin, BIP84Wallet):
path = super()._get_bip32_export_path(mixdepth, address_type) path = super()._get_bip32_export_path(mixdepth, address_type)
return path return path
def _get_key_from_path(self, path,
validate_cache: bool = False):
raise WalletError("Cannot get a private key from a watch-only wallet")
def _get_keypair_from_path(self, path,
validate_cache: bool = False):
raise WalletError("Cannot get a private key from a watch-only wallet")
def _get_pubkey_from_path(self, path,
validate_cache: bool = False):
if not self._is_my_bip32_path(path):
return super()._get_pubkey_from_path(path,
validate_cache=validate_cache)
if self.is_timelocked_path(path):
key_path = path[:-1]
locktime = path[-1]
cache = self._get_cache_for_path(key_path)
pubkey = cache.get(b'P')
if pubkey is None or validate_cache:
new_pubkey = self._TIMELOCK_ENGINE.derive_bip32_privkey(
self._master_key, key_path)
if pubkey is None:
cache[b'P'] = pubkey = new_pubkey
elif pubkey != new_pubkey:
raise WalletError("Wallet cache validation failed")
return (pubkey, locktime), self._TIMELOCK_ENGINE
cache = self._get_cache_for_path(path)
pubkey = cache.get(b'P')
if pubkey is None or validate_cache:
new_pubkey = self._ENGINE.derive_bip32_privkey(
self._master_key, path)
if pubkey is None:
cache[b'P'] = pubkey = new_pubkey
elif pubkey != new_pubkey:
raise WalletError("Wallet cache validation failed")
return pubkey, self._ENGINE
WALLET_IMPLEMENTATIONS = { WALLET_IMPLEMENTATIONS = {
LegacyWallet.TYPE: LegacyWallet, LegacyWallet.TYPE: LegacyWallet,

23
src/jmclient/wallet_utils.py

@ -7,7 +7,7 @@ import sys
from datetime import datetime, timedelta from datetime import datetime, timedelta
from optparse import OptionParser from optparse import OptionParser
from numbers import Integral from numbers import Integral
from collections import Counter from collections import Counter, defaultdict
from itertools import islice, chain from itertools import islice, chain
from jmclient import (get_network, WALLET_IMPLEMENTATIONS, Storage, podle, from jmclient import (get_network, WALLET_IMPLEMENTATIONS, Storage, podle,
jm_single, WalletError, BaseWallet, VolatileStorage, jm_single, WalletError, BaseWallet, VolatileStorage,
@ -403,15 +403,15 @@ def get_tx_info(txid, tx_cache=None):
def get_imported_privkey_branch(wallet_service, m, showprivkey): def get_imported_privkey_branch(wallet_service, m, showprivkey):
entries = [] entries = []
balance_by_script = defaultdict(int)
for data in wallet_service.get_utxos_at_mixdepth(m,
include_disabled=True).values():
balance_by_script[data['script']] += data['value']
for path in wallet_service.yield_imported_paths(m): for path in wallet_service.yield_imported_paths(m):
addr = wallet_service.get_address_from_path(path) addr = wallet_service.get_address_from_path(path)
script = wallet_service.get_script_from_path(path) script = wallet_service.get_script_from_path(path)
balance = 0.0 balance = balance_by_script.get(script, 0)
for data in wallet_service.get_utxos_by_mixdepth( status = ('used' if balance else 'empty')
include_disabled=True)[m].values():
if script == data['script']:
balance += data['value']
status = ('used' if balance > 0.0 else 'empty')
if showprivkey: if showprivkey:
wip_privkey = wallet_service.get_wif_path(path) wip_privkey = wallet_service.get_wif_path(path)
else: else:
@ -431,9 +431,6 @@ def wallet_showutxos(wallet_service, showprivkey):
includeconfs=True) includeconfs=True)
for md in utxos: for md in utxos:
(enabled, disabled) = get_utxos_enabled_disabled(wallet_service, md) (enabled, disabled) = get_utxos_enabled_disabled(wallet_service, md)
utxo_d = []
for k, v in disabled.items():
utxo_d.append(k)
for u, av in utxos[md].items(): for u, av in utxos[md].items():
success, us = utxo_to_utxostr(u) success, us = utxo_to_utxostr(u)
assert success assert success
@ -453,7 +450,7 @@ def wallet_showutxos(wallet_service, showprivkey):
'external': False, 'external': False,
'mixdepth': mixdepth, 'mixdepth': mixdepth,
'confirmations': av['confs'], 'confirmations': av['confs'],
'frozen': True if u in utxo_d else False} 'frozen': u in disabled}
if showprivkey: if showprivkey:
unsp[us]['privkey'] = wallet_service.get_wif_path(av['path']) unsp[us]['privkey'] = wallet_service.get_wif_path(av['path'])
if locktime: if locktime:
@ -1279,8 +1276,8 @@ def display_utxos_for_disable_choice_default(wallet_service, utxos_enabled,
def get_utxos_enabled_disabled(wallet_service, md): def get_utxos_enabled_disabled(wallet_service, md):
""" Returns dicts for enabled and disabled separately """ Returns dicts for enabled and disabled separately
""" """
utxos_enabled = wallet_service.get_utxos_by_mixdepth()[md] utxos_enabled = wallet_service.get_utxos_at_mixdepth(md)
utxos_all = wallet_service.get_utxos_by_mixdepth(include_disabled=True)[md] utxos_all = wallet_service.get_utxos_at_mixdepth(md, include_disabled=True)
utxos_disabled_keyset = set(utxos_all).difference(set(utxos_enabled)) utxos_disabled_keyset = set(utxos_all).difference(set(utxos_enabled))
utxos_disabled = {} utxos_disabled = {}
for u in utxos_disabled_keyset: for u in utxos_disabled_keyset:

16
test/jmclient/test_taker.py

@ -121,6 +121,12 @@ class DummyWallet(LegacyWallet):
""" """
return 'p2wpkh' return 'p2wpkh'
def _get_key_from_path(self, path,
validate_cache: bool = False):
if path[0] == b'dummy':
return struct.pack(b'B', path[2] + 1)*32 + b'\x01', self._ENGINE
raise NotImplementedError()
def get_key_from_addr(self, addr): def get_key_from_addr(self, addr):
"""usable addresses: privkey all 1s, 2s, 3s, ... :""" """usable addresses: privkey all 1s, 2s, 3s, ... :"""
privs = [x*32 + b"\x01" for x in [struct.pack(b'B', y) for y in range(1,6)]] privs = [x*32 + b"\x01" for x in [struct.pack(b'B', y) for y in range(1,6)]]
@ -139,18 +145,20 @@ class DummyWallet(LegacyWallet):
return p return p
raise ValueError("No such keypair") raise ValueError("No such keypair")
def _is_my_bip32_path(self, path): def get_path_repr(self, path):
return True return '/'.join(map(str, path))
def is_standard_wallet_script(self, path): def is_standard_wallet_script(self, path):
if path[0] == "nonstandard_path": if path[0] == "nonstandard_path":
return False return False
return True return True
def script_to_addr(self, script): def script_to_addr(self, script,
validate_cache: bool = False):
if self.script_to_path(script)[0] == "nonstandard_path": if self.script_to_path(script)[0] == "nonstandard_path":
return "dummyaddr" return "dummyaddr"
return super().script_to_addr(script) return super().script_to_addr(script,
validate_cache=validate_cache)
def dummy_order_chooser(): def dummy_order_chooser():

24
test/jmclient/test_utxomanager.py

@ -56,14 +56,12 @@ def test_utxomanager_persist(setup_env_nodeps):
assert not um.is_disabled(txid, index+2) assert not um.is_disabled(txid, index+2)
um.disable_utxo(txid, index+2) um.disable_utxo(txid, index+2)
utxos = um.get_utxos_by_mixdepth() assert len(um.get_utxos_at_mixdepth(mixdepth)) == 1
assert len(utxos[mixdepth]) == 1 assert len(um.get_utxos_at_mixdepth(mixdepth+1)) == 2
assert len(utxos[mixdepth+1]) == 2 assert len(um.get_utxos_at_mixdepth(mixdepth+2)) == 0
assert len(utxos[mixdepth+2]) == 0
balances = um.get_balance_by_mixdepth() assert um.get_balance_at_mixdepth(mixdepth) == value
assert balances[mixdepth] == value assert um.get_balance_at_mixdepth(mixdepth+1) == value * 2
assert balances[mixdepth+1] == value * 2
um.remove_utxo(txid, index, mixdepth) um.remove_utxo(txid, index, mixdepth)
assert um.have_utxo(txid, index) == False assert um.have_utxo(txid, index) == False
@ -79,14 +77,12 @@ def test_utxomanager_persist(setup_env_nodeps):
assert um.have_utxo(txid, index) == False assert um.have_utxo(txid, index) == False
assert um.have_utxo(txid, index+1) == mixdepth + 1 assert um.have_utxo(txid, index+1) == mixdepth + 1
utxos = um.get_utxos_by_mixdepth() assert len(um.get_utxos_at_mixdepth(mixdepth)) == 0
assert len(utxos[mixdepth]) == 0 assert len(um.get_utxos_at_mixdepth(mixdepth+1)) == 1
assert len(utxos[mixdepth+1]) == 1
balances = um.get_balance_by_mixdepth() assert um.get_balance_at_mixdepth(mixdepth) == 0
assert balances[mixdepth] == 0 assert um.get_balance_at_mixdepth(mixdepth+1) == value
assert balances[mixdepth+1] == value assert um.get_balance_at_mixdepth(mixdepth+2) == 0
assert balances[mixdepth+2] == 0
def test_utxomanager_select(setup_env_nodeps): def test_utxomanager_select(setup_env_nodeps):

14
test/jmclient/test_wallet.py

@ -17,7 +17,6 @@ from jmclient import load_test_config, jm_single, BaseWallet, \
wallet_gettimelockaddress, UnknownAddressForLabel wallet_gettimelockaddress, UnknownAddressForLabel
from test_blockchaininterface import sync_test_wallet from test_blockchaininterface import sync_test_wallet
from freezegun import freeze_time from freezegun import freeze_time
from bitcointx.wallet import CCoinAddressError
pytestmark = pytest.mark.usefixtures("setup_regtest_bitcoind") pytestmark = pytest.mark.usefixtures("setup_regtest_bitcoind")
@ -264,9 +263,6 @@ def test_bip32_timelocked_addresses(setup_wallet, timenumber, address, wif):
mixdepth = FidelityBondMixin.FIDELITY_BOND_MIXDEPTH mixdepth = FidelityBondMixin.FIDELITY_BOND_MIXDEPTH
address_type = FidelityBondMixin.BIP32_TIMELOCK_ID address_type = FidelityBondMixin.BIP32_TIMELOCK_ID
#wallet needs to know about the script beforehand
wallet.get_script_and_update_map(mixdepth, address_type, timenumber)
assert address == wallet.get_addr(mixdepth, address_type, timenumber) assert address == wallet.get_addr(mixdepth, address_type, timenumber)
assert wif == wallet.get_wif_path(wallet.get_path(mixdepth, address_type, timenumber)) assert wif == wallet.get_wif_path(wallet.get_path(mixdepth, address_type, timenumber))
@ -287,7 +283,7 @@ def test_gettimelockaddress_method(setup_wallet, timenumber, locktime_string):
m = FidelityBondMixin.FIDELITY_BOND_MIXDEPTH m = FidelityBondMixin.FIDELITY_BOND_MIXDEPTH
address_type = FidelityBondMixin.BIP32_TIMELOCK_ID address_type = FidelityBondMixin.BIP32_TIMELOCK_ID
script = wallet.get_script_and_update_map(m, address_type, timenumber) script = wallet.get_script(m, address_type, timenumber)
addr = wallet.script_to_addr(script) addr = wallet.script_to_addr(script)
addr_from_method = wallet_gettimelockaddress(wallet, locktime_string) addr_from_method = wallet_gettimelockaddress(wallet, locktime_string)
@ -456,7 +452,7 @@ def test_timelocked_output_signing(setup_wallet):
wallet = SegwitWalletFidelityBonds(storage) wallet = SegwitWalletFidelityBonds(storage)
timenumber = 0 timenumber = 0
script = wallet.get_script_and_update_map( script = wallet.get_script(
FidelityBondMixin.FIDELITY_BOND_MIXDEPTH, FidelityBondMixin.FIDELITY_BOND_MIXDEPTH,
FidelityBondMixin.BIP32_TIMELOCK_ID, timenumber) FidelityBondMixin.BIP32_TIMELOCK_ID, timenumber)
utxo = fund_wallet_addr(wallet, wallet.script_to_addr(script)) utxo = fund_wallet_addr(wallet, wallet.script_to_addr(script))
@ -477,7 +473,7 @@ def test_get_bbm(setup_wallet):
wallet = get_populated_wallet(amount, num_tx) wallet = get_populated_wallet(amount, num_tx)
# disable a utxo and check we can correctly report # disable a utxo and check we can correctly report
# balance with the disabled flag off: # balance with the disabled flag off:
utxo_1 = list(wallet._utxos.get_utxos_by_mixdepth()[0].keys())[0] utxo_1 = list(wallet._utxos.get_utxos_at_mixdepth(0).keys())[0]
wallet.disable_utxo(*utxo_1) wallet.disable_utxo(*utxo_1)
balances = wallet.get_balance_by_mixdepth(include_disabled=True) balances = wallet.get_balance_by_mixdepth(include_disabled=True)
assert balances[0] == num_tx * amount assert balances[0] == num_tx * amount
@ -610,7 +606,9 @@ def test_address_labels(setup_wallet):
wallet.get_address_label("2MzY5yyonUY7zpHspg7jB7WQs1uJxKafQe4") wallet.get_address_label("2MzY5yyonUY7zpHspg7jB7WQs1uJxKafQe4")
wallet.set_address_label("2MzY5yyonUY7zpHspg7jB7WQs1uJxKafQe4", wallet.set_address_label("2MzY5yyonUY7zpHspg7jB7WQs1uJxKafQe4",
"test") "test")
with pytest.raises(CCoinAddressError): # we no longer decode addresses just to see if we know about them,
# so we won't get a CCoinAddressError for invalid addresses
#with pytest.raises(CCoinAddressError):
wallet.get_address_label("badaddress") wallet.get_address_label("badaddress")
wallet.set_address_label("badaddress", "test") wallet.set_address_label("badaddress", "test")

Loading…
Cancel
Save