Browse Source

wallet: add get_{balance,utxos}_at_mixdepth methods

Sometimes calling code is only interested in the balance or UTXOs at a
single mixdepth. In these cases, it is wasteful to get the balance or
UTXOs at all mixdepths, only to throw away the returned information
about all but the single mixdepth of interest. Implement new methods in
BaseWallet to get the balance or UTXOs at a single mixdepth.

Also, correct an apparent oversight due to apparently misplaced
indentation: the maxheight parameter of get_balance_by_mixdepth was
ignored unless the include_disabled parameter was passed as False. It
appears that the intention was for include_disabled and maxheight to be
independent filters on the returned information.
master
Matt Whitlock 2 years ago
parent
commit
184d76f7f7
  1. 90
      src/jmclient/wallet.py
  2. 24
      test/jmclient/test_utxomanager.py
  3. 2
      test/jmclient/test_wallet.py

90
src/jmclient/wallet.py

@ -19,6 +19,7 @@ from itertools import chain
from decimal import Decimal from decimal import Decimal
from numbers import Integral from numbers import Integral
from math import exp from math import exp
from typing import Any, Dict, Optional, Tuple
from .configure import jm_single from .configure import jm_single
@ -280,32 +281,28 @@ class UTXOManager(object):
'value': utxos[s['utxo']][1]} 'value': utxos[s['utxo']][1]}
for s in selected} for s in selected}
def get_balance_by_mixdepth(self, max_mixdepth=float('Inf'), def get_balance_at_mixdepth(self, mixdepth: int,
include_disabled=True, maxheight=None): include_disabled: bool = True,
""" By default this returns a dict of aggregated bitcoin maxheight: Optional[int] = None) -> int:
balance per mixdepth: {0: N sats, 1: M sats, ...} for all """ By default this returns aggregated bitcoin balance at mixdepth.
currently available mixdepths.
If max_mixdepth is set it will return balances only up
to that mixdepth.
To get only enabled balance, set include_disabled=False. To get only enabled balance, set include_disabled=False.
To get balances only with a certain number of confs, use maxheight. To get balances only with a certain number of confs, use maxheight.
""" """
balance_dict = collections.defaultdict(int) utxomap = self._utxo.get(mixdepth)
for mixdepth, utxomap in self._utxo.items(): if not utxomap:
if mixdepth > max_mixdepth: return 0
continue if not include_disabled:
if not include_disabled: utxomap = {k: v for k, v in utxomap.items(
utxomap = {k: v for k, v in utxomap.items( ) if not self.is_disabled(*k)}
) if not self.is_disabled(*k)} if maxheight is not None:
if maxheight is not None: utxomap = {k: v for k, v in utxomap.items(
utxomap = {k: v for k, v in utxomap.items( ) if v[2] <= maxheight}
) if v[2] <= maxheight} return sum(x[1] for x in utxomap.values())
value = sum(x[1] for x in utxomap.values())
balance_dict[mixdepth] = value def get_utxos_at_mixdepth(self, mixdepth: int) -> \
return balance_dict Dict[Tuple[bytes, int], Tuple[Tuple, int, int]]:
utxomap = self._utxo.get(mixdepth)
def get_utxos_by_mixdepth(self): return deepcopy(utxomap) if utxomap else {}
return deepcopy(self._utxo)
def __eq__(self, o): def __eq__(self, o):
return self._utxo == o._utxo and \ return self._utxo == o._utxo and \
@ -836,10 +833,19 @@ class BaseWallet(object):
confirmations, set maxheight to max acceptable blockheight. confirmations, set maxheight to max acceptable blockheight.
returns: {mixdepth: value} returns: {mixdepth: value}
""" """
balances = collections.defaultdict(int)
for md in range(self.mixdepth + 1):
balances[md] = self.get_balance_at_mixdepth(md, verbose=verbose,
include_disabled=include_disabled, maxheight=maxheight)
return balances
def get_balance_at_mixdepth(self, mixdepth,
verbose: bool = True,
include_disabled: bool = False,
maxheight: Optional[int] = None) -> int:
# TODO: verbose # TODO: verbose
return self._utxos.get_balance_by_mixdepth(max_mixdepth=self.mixdepth, return self._utxos.get_balance_at_mixdepth(mixdepth,
include_disabled=include_disabled, include_disabled=include_disabled, maxheight=maxheight)
maxheight=maxheight)
def get_utxos_by_mixdepth(self, include_disabled=False, includeheight=False): def get_utxos_by_mixdepth(self, include_disabled=False, includeheight=False):
""" """
@ -850,25 +856,35 @@ class BaseWallet(object):
{'script': bytes, 'path': tuple, 'value': int}}} {'script': bytes, 'path': tuple, 'value': int}}}
(if `includeheight` is True, adds key 'height': int) (if `includeheight` is True, adds key 'height': int)
""" """
mix_utxos = self._utxos.get_utxos_by_mixdepth()
script_utxos = collections.defaultdict(dict) script_utxos = collections.defaultdict(dict)
for md, data in mix_utxos.items(): for md in range(self.mixdepth + 1):
if md > self.mixdepth: script_utxos[md] = self.get_utxos_at_mixdepth(md,
continue include_disabled=include_disabled, includeheight=includeheight)
return script_utxos
def get_utxos_at_mixdepth(self, mixdepth: int,
include_disabled: bool = False,
includeheight: bool = False) -> \
Dict[Tuple[bytes, int], Dict[str, Any]]:
script_utxos = {}
if 0 <= mixdepth <= self.mixdepth:
data = self._utxos.get_utxos_at_mixdepth(mixdepth)
for utxo, (path, value, height) in data.items(): for utxo, (path, value, height) in data.items():
if not include_disabled and self._utxos.is_disabled(*utxo): if not include_disabled and self._utxos.is_disabled(*utxo):
continue continue
script = self.get_script_from_path(path) script = self.get_script_from_path(path)
addr = self.get_address_from_path(path) addr = self.get_address_from_path(path)
label = self.get_address_label(addr) label = self.get_address_label(addr)
script_utxos[md][utxo] = {'script': script, script_utxo = {
'path': path, 'script': script,
'value': value, 'path': path,
'address': addr, 'value': value,
'label': label} 'address': addr,
'label': label,
}
if includeheight: if includeheight:
script_utxos[md][utxo]['height'] = height script_utxo['height'] = height
script_utxos[utxo] = script_utxo
return script_utxos return script_utxos

24
test/jmclient/test_utxomanager.py

@ -56,14 +56,12 @@ def test_utxomanager_persist(setup_env_nodeps):
assert not um.is_disabled(txid, index+2) assert not um.is_disabled(txid, index+2)
um.disable_utxo(txid, index+2) um.disable_utxo(txid, index+2)
utxos = um.get_utxos_by_mixdepth() assert len(um.get_utxos_at_mixdepth(mixdepth)) == 1
assert len(utxos[mixdepth]) == 1 assert len(um.get_utxos_at_mixdepth(mixdepth+1)) == 2
assert len(utxos[mixdepth+1]) == 2 assert len(um.get_utxos_at_mixdepth(mixdepth+2)) == 0
assert len(utxos[mixdepth+2]) == 0
balances = um.get_balance_by_mixdepth() assert um.get_balance_at_mixdepth(mixdepth) == value
assert balances[mixdepth] == value assert um.get_balance_at_mixdepth(mixdepth+1) == value * 2
assert balances[mixdepth+1] == value * 2
um.remove_utxo(txid, index, mixdepth) um.remove_utxo(txid, index, mixdepth)
assert um.have_utxo(txid, index) == False assert um.have_utxo(txid, index) == False
@ -79,14 +77,12 @@ def test_utxomanager_persist(setup_env_nodeps):
assert um.have_utxo(txid, index) == False assert um.have_utxo(txid, index) == False
assert um.have_utxo(txid, index+1) == mixdepth + 1 assert um.have_utxo(txid, index+1) == mixdepth + 1
utxos = um.get_utxos_by_mixdepth() assert len(um.get_utxos_at_mixdepth(mixdepth)) == 0
assert len(utxos[mixdepth]) == 0 assert len(um.get_utxos_at_mixdepth(mixdepth+1)) == 1
assert len(utxos[mixdepth+1]) == 1
balances = um.get_balance_by_mixdepth() assert um.get_balance_at_mixdepth(mixdepth) == 0
assert balances[mixdepth] == 0 assert um.get_balance_at_mixdepth(mixdepth+1) == value
assert balances[mixdepth+1] == value assert um.get_balance_at_mixdepth(mixdepth+2) == 0
assert balances[mixdepth+2] == 0
def test_utxomanager_select(setup_env_nodeps): def test_utxomanager_select(setup_env_nodeps):

2
test/jmclient/test_wallet.py

@ -477,7 +477,7 @@ def test_get_bbm(setup_wallet):
wallet = get_populated_wallet(amount, num_tx) wallet = get_populated_wallet(amount, num_tx)
# disable a utxo and check we can correctly report # disable a utxo and check we can correctly report
# balance with the disabled flag off: # balance with the disabled flag off:
utxo_1 = list(wallet._utxos.get_utxos_by_mixdepth()[0].keys())[0] utxo_1 = list(wallet._utxos.get_utxos_at_mixdepth(0).keys())[0]
wallet.disable_utxo(*utxo_1) wallet.disable_utxo(*utxo_1)
balances = wallet.get_balance_by_mixdepth(include_disabled=True) balances = wallet.get_balance_by_mixdepth(include_disabled=True)
assert balances[0] == num_tx * amount assert balances[0] == num_tx * amount

Loading…
Cancel
Save