diff --git a/src/jmclient/wallet.py b/src/jmclient/wallet.py index 4ae1405..cbaceb6 100644 --- a/src/jmclient/wallet.py +++ b/src/jmclient/wallet.py @@ -19,6 +19,7 @@ from itertools import chain from decimal import Decimal from numbers import Integral from math import exp +from typing import Any, Dict, Optional, Tuple from .configure import jm_single @@ -280,32 +281,28 @@ class UTXOManager(object): 'value': utxos[s['utxo']][1]} for s in selected} - def get_balance_by_mixdepth(self, max_mixdepth=float('Inf'), - include_disabled=True, maxheight=None): - """ By default this returns a dict of aggregated bitcoin - balance per mixdepth: {0: N sats, 1: M sats, ...} for all - currently available mixdepths. - If max_mixdepth is set it will return balances only up - to that mixdepth. + def get_balance_at_mixdepth(self, mixdepth: int, + include_disabled: bool = True, + maxheight: Optional[int] = None) -> int: + """ By default this returns aggregated bitcoin balance at mixdepth. To get only enabled balance, set include_disabled=False. To get balances only with a certain number of confs, use maxheight. """ - balance_dict = collections.defaultdict(int) - for mixdepth, utxomap in self._utxo.items(): - if mixdepth > max_mixdepth: - continue - if not include_disabled: - utxomap = {k: v for k, v in utxomap.items( - ) if not self.is_disabled(*k)} - if maxheight is not None: - utxomap = {k: v for k, v in utxomap.items( - ) if v[2] <= maxheight} - value = sum(x[1] for x in utxomap.values()) - balance_dict[mixdepth] = value - return balance_dict - - def get_utxos_by_mixdepth(self): - return deepcopy(self._utxo) + utxomap = self._utxo.get(mixdepth) + if not utxomap: + return 0 + if not include_disabled: + utxomap = {k: v for k, v in utxomap.items( + ) if not self.is_disabled(*k)} + if maxheight is not None: + utxomap = {k: v for k, v in utxomap.items( + ) if v[2] <= maxheight} + return sum(x[1] for x in utxomap.values()) + + def get_utxos_at_mixdepth(self, mixdepth: int) -> \ + Dict[Tuple[bytes, int], Tuple[Tuple, int, int]]: + utxomap = self._utxo.get(mixdepth) + return deepcopy(utxomap) if utxomap else {} def __eq__(self, o): return self._utxo == o._utxo and \ @@ -836,10 +833,19 @@ class BaseWallet(object): confirmations, set maxheight to max acceptable blockheight. returns: {mixdepth: value} """ + balances = collections.defaultdict(int) + for md in range(self.mixdepth + 1): + balances[md] = self.get_balance_at_mixdepth(md, verbose=verbose, + include_disabled=include_disabled, maxheight=maxheight) + return balances + + def get_balance_at_mixdepth(self, mixdepth, + verbose: bool = True, + include_disabled: bool = False, + maxheight: Optional[int] = None) -> int: # TODO: verbose - return self._utxos.get_balance_by_mixdepth(max_mixdepth=self.mixdepth, - include_disabled=include_disabled, - maxheight=maxheight) + return self._utxos.get_balance_at_mixdepth(mixdepth, + include_disabled=include_disabled, maxheight=maxheight) def get_utxos_by_mixdepth(self, include_disabled=False, includeheight=False): """ @@ -850,25 +856,35 @@ class BaseWallet(object): {'script': bytes, 'path': tuple, 'value': int}}} (if `includeheight` is True, adds key 'height': int) """ - mix_utxos = self._utxos.get_utxos_by_mixdepth() - script_utxos = collections.defaultdict(dict) - for md, data in mix_utxos.items(): - if md > self.mixdepth: - continue + for md in range(self.mixdepth + 1): + script_utxos[md] = self.get_utxos_at_mixdepth(md, + include_disabled=include_disabled, includeheight=includeheight) + return script_utxos + + def get_utxos_at_mixdepth(self, mixdepth: int, + include_disabled: bool = False, + includeheight: bool = False) -> \ + Dict[Tuple[bytes, int], Dict[str, Any]]: + script_utxos = {} + if 0 <= mixdepth <= self.mixdepth: + data = self._utxos.get_utxos_at_mixdepth(mixdepth) for utxo, (path, value, height) in data.items(): if not include_disabled and self._utxos.is_disabled(*utxo): continue script = self.get_script_from_path(path) addr = self.get_address_from_path(path) label = self.get_address_label(addr) - script_utxos[md][utxo] = {'script': script, - 'path': path, - 'value': value, - 'address': addr, - 'label': label} + script_utxo = { + 'script': script, + 'path': path, + 'value': value, + 'address': addr, + 'label': label, + } if includeheight: - script_utxos[md][utxo]['height'] = height + script_utxo['height'] = height + script_utxos[utxo] = script_utxo return script_utxos diff --git a/test/jmclient/test_utxomanager.py b/test/jmclient/test_utxomanager.py index 2d3023f..1bd97e1 100644 --- a/test/jmclient/test_utxomanager.py +++ b/test/jmclient/test_utxomanager.py @@ -56,14 +56,12 @@ def test_utxomanager_persist(setup_env_nodeps): assert not um.is_disabled(txid, index+2) um.disable_utxo(txid, index+2) - utxos = um.get_utxos_by_mixdepth() - assert len(utxos[mixdepth]) == 1 - assert len(utxos[mixdepth+1]) == 2 - assert len(utxos[mixdepth+2]) == 0 + assert len(um.get_utxos_at_mixdepth(mixdepth)) == 1 + assert len(um.get_utxos_at_mixdepth(mixdepth+1)) == 2 + assert len(um.get_utxos_at_mixdepth(mixdepth+2)) == 0 - balances = um.get_balance_by_mixdepth() - assert balances[mixdepth] == value - assert balances[mixdepth+1] == value * 2 + assert um.get_balance_at_mixdepth(mixdepth) == value + assert um.get_balance_at_mixdepth(mixdepth+1) == value * 2 um.remove_utxo(txid, index, mixdepth) assert um.have_utxo(txid, index) == False @@ -79,14 +77,12 @@ def test_utxomanager_persist(setup_env_nodeps): assert um.have_utxo(txid, index) == False assert um.have_utxo(txid, index+1) == mixdepth + 1 - utxos = um.get_utxos_by_mixdepth() - assert len(utxos[mixdepth]) == 0 - assert len(utxos[mixdepth+1]) == 1 + assert len(um.get_utxos_at_mixdepth(mixdepth)) == 0 + assert len(um.get_utxos_at_mixdepth(mixdepth+1)) == 1 - balances = um.get_balance_by_mixdepth() - assert balances[mixdepth] == 0 - assert balances[mixdepth+1] == value - assert balances[mixdepth+2] == 0 + assert um.get_balance_at_mixdepth(mixdepth) == 0 + assert um.get_balance_at_mixdepth(mixdepth+1) == value + assert um.get_balance_at_mixdepth(mixdepth+2) == 0 def test_utxomanager_select(setup_env_nodeps): diff --git a/test/jmclient/test_wallet.py b/test/jmclient/test_wallet.py index 86d5d8e..ab68e72 100644 --- a/test/jmclient/test_wallet.py +++ b/test/jmclient/test_wallet.py @@ -477,7 +477,7 @@ def test_get_bbm(setup_wallet): wallet = get_populated_wallet(amount, num_tx) # disable a utxo and check we can correctly report # balance with the disabled flag off: - utxo_1 = list(wallet._utxos.get_utxos_by_mixdepth()[0].keys())[0] + utxo_1 = list(wallet._utxos.get_utxos_at_mixdepth(0).keys())[0] wallet.disable_utxo(*utxo_1) balances = wallet.get_balance_by_mixdepth(include_disabled=True) assert balances[0] == num_tx * amount