From 184d76f7f7ac81b0e566c5c5f2ea89d9f7da2258 Mon Sep 17 00:00:00 2001 From: Matt Whitlock Date: Sat, 4 Nov 2023 16:22:14 -0400 Subject: [PATCH] wallet: add get_{balance,utxos}_at_mixdepth methods Sometimes calling code is only interested in the balance or UTXOs at a single mixdepth. In these cases, it is wasteful to get the balance or UTXOs at all mixdepths, only to throw away the returned information about all but the single mixdepth of interest. Implement new methods in BaseWallet to get the balance or UTXOs at a single mixdepth. Also, correct an apparent oversight due to apparently misplaced indentation: the maxheight parameter of get_balance_by_mixdepth was ignored unless the include_disabled parameter was passed as False. It appears that the intention was for include_disabled and maxheight to be independent filters on the returned information. --- src/jmclient/wallet.py | 90 ++++++++++++++++++------------- test/jmclient/test_utxomanager.py | 24 ++++----- test/jmclient/test_wallet.py | 2 +- 3 files changed, 64 insertions(+), 52 deletions(-) diff --git a/src/jmclient/wallet.py b/src/jmclient/wallet.py index 4ae1405..cbaceb6 100644 --- a/src/jmclient/wallet.py +++ b/src/jmclient/wallet.py @@ -19,6 +19,7 @@ from itertools import chain from decimal import Decimal from numbers import Integral from math import exp +from typing import Any, Dict, Optional, Tuple from .configure import jm_single @@ -280,32 +281,28 @@ class UTXOManager(object): 'value': utxos[s['utxo']][1]} for s in selected} - def get_balance_by_mixdepth(self, max_mixdepth=float('Inf'), - include_disabled=True, maxheight=None): - """ By default this returns a dict of aggregated bitcoin - balance per mixdepth: {0: N sats, 1: M sats, ...} for all - currently available mixdepths. - If max_mixdepth is set it will return balances only up - to that mixdepth. + def get_balance_at_mixdepth(self, mixdepth: int, + include_disabled: bool = True, + maxheight: Optional[int] = None) -> int: + """ By default this returns aggregated bitcoin balance at mixdepth. To get only enabled balance, set include_disabled=False. To get balances only with a certain number of confs, use maxheight. """ - balance_dict = collections.defaultdict(int) - for mixdepth, utxomap in self._utxo.items(): - if mixdepth > max_mixdepth: - continue - if not include_disabled: - utxomap = {k: v for k, v in utxomap.items( - ) if not self.is_disabled(*k)} - if maxheight is not None: - utxomap = {k: v for k, v in utxomap.items( - ) if v[2] <= maxheight} - value = sum(x[1] for x in utxomap.values()) - balance_dict[mixdepth] = value - return balance_dict - - def get_utxos_by_mixdepth(self): - return deepcopy(self._utxo) + utxomap = self._utxo.get(mixdepth) + if not utxomap: + return 0 + if not include_disabled: + utxomap = {k: v for k, v in utxomap.items( + ) if not self.is_disabled(*k)} + if maxheight is not None: + utxomap = {k: v for k, v in utxomap.items( + ) if v[2] <= maxheight} + return sum(x[1] for x in utxomap.values()) + + def get_utxos_at_mixdepth(self, mixdepth: int) -> \ + Dict[Tuple[bytes, int], Tuple[Tuple, int, int]]: + utxomap = self._utxo.get(mixdepth) + return deepcopy(utxomap) if utxomap else {} def __eq__(self, o): return self._utxo == o._utxo and \ @@ -836,10 +833,19 @@ class BaseWallet(object): confirmations, set maxheight to max acceptable blockheight. returns: {mixdepth: value} """ + balances = collections.defaultdict(int) + for md in range(self.mixdepth + 1): + balances[md] = self.get_balance_at_mixdepth(md, verbose=verbose, + include_disabled=include_disabled, maxheight=maxheight) + return balances + + def get_balance_at_mixdepth(self, mixdepth, + verbose: bool = True, + include_disabled: bool = False, + maxheight: Optional[int] = None) -> int: # TODO: verbose - return self._utxos.get_balance_by_mixdepth(max_mixdepth=self.mixdepth, - include_disabled=include_disabled, - maxheight=maxheight) + return self._utxos.get_balance_at_mixdepth(mixdepth, + include_disabled=include_disabled, maxheight=maxheight) def get_utxos_by_mixdepth(self, include_disabled=False, includeheight=False): """ @@ -850,25 +856,35 @@ class BaseWallet(object): {'script': bytes, 'path': tuple, 'value': int}}} (if `includeheight` is True, adds key 'height': int) """ - mix_utxos = self._utxos.get_utxos_by_mixdepth() - script_utxos = collections.defaultdict(dict) - for md, data in mix_utxos.items(): - if md > self.mixdepth: - continue + for md in range(self.mixdepth + 1): + script_utxos[md] = self.get_utxos_at_mixdepth(md, + include_disabled=include_disabled, includeheight=includeheight) + return script_utxos + + def get_utxos_at_mixdepth(self, mixdepth: int, + include_disabled: bool = False, + includeheight: bool = False) -> \ + Dict[Tuple[bytes, int], Dict[str, Any]]: + script_utxos = {} + if 0 <= mixdepth <= self.mixdepth: + data = self._utxos.get_utxos_at_mixdepth(mixdepth) for utxo, (path, value, height) in data.items(): if not include_disabled and self._utxos.is_disabled(*utxo): continue script = self.get_script_from_path(path) addr = self.get_address_from_path(path) label = self.get_address_label(addr) - script_utxos[md][utxo] = {'script': script, - 'path': path, - 'value': value, - 'address': addr, - 'label': label} + script_utxo = { + 'script': script, + 'path': path, + 'value': value, + 'address': addr, + 'label': label, + } if includeheight: - script_utxos[md][utxo]['height'] = height + script_utxo['height'] = height + script_utxos[utxo] = script_utxo return script_utxos diff --git a/test/jmclient/test_utxomanager.py b/test/jmclient/test_utxomanager.py index 2d3023f..1bd97e1 100644 --- a/test/jmclient/test_utxomanager.py +++ b/test/jmclient/test_utxomanager.py @@ -56,14 +56,12 @@ def test_utxomanager_persist(setup_env_nodeps): assert not um.is_disabled(txid, index+2) um.disable_utxo(txid, index+2) - utxos = um.get_utxos_by_mixdepth() - assert len(utxos[mixdepth]) == 1 - assert len(utxos[mixdepth+1]) == 2 - assert len(utxos[mixdepth+2]) == 0 + assert len(um.get_utxos_at_mixdepth(mixdepth)) == 1 + assert len(um.get_utxos_at_mixdepth(mixdepth+1)) == 2 + assert len(um.get_utxos_at_mixdepth(mixdepth+2)) == 0 - balances = um.get_balance_by_mixdepth() - assert balances[mixdepth] == value - assert balances[mixdepth+1] == value * 2 + assert um.get_balance_at_mixdepth(mixdepth) == value + assert um.get_balance_at_mixdepth(mixdepth+1) == value * 2 um.remove_utxo(txid, index, mixdepth) assert um.have_utxo(txid, index) == False @@ -79,14 +77,12 @@ def test_utxomanager_persist(setup_env_nodeps): assert um.have_utxo(txid, index) == False assert um.have_utxo(txid, index+1) == mixdepth + 1 - utxos = um.get_utxos_by_mixdepth() - assert len(utxos[mixdepth]) == 0 - assert len(utxos[mixdepth+1]) == 1 + assert len(um.get_utxos_at_mixdepth(mixdepth)) == 0 + assert len(um.get_utxos_at_mixdepth(mixdepth+1)) == 1 - balances = um.get_balance_by_mixdepth() - assert balances[mixdepth] == 0 - assert balances[mixdepth+1] == value - assert balances[mixdepth+2] == 0 + assert um.get_balance_at_mixdepth(mixdepth) == 0 + assert um.get_balance_at_mixdepth(mixdepth+1) == value + assert um.get_balance_at_mixdepth(mixdepth+2) == 0 def test_utxomanager_select(setup_env_nodeps): diff --git a/test/jmclient/test_wallet.py b/test/jmclient/test_wallet.py index 86d5d8e..ab68e72 100644 --- a/test/jmclient/test_wallet.py +++ b/test/jmclient/test_wallet.py @@ -477,7 +477,7 @@ def test_get_bbm(setup_wallet): wallet = get_populated_wallet(amount, num_tx) # disable a utxo and check we can correctly report # balance with the disabled flag off: - utxo_1 = list(wallet._utxos.get_utxos_by_mixdepth()[0].keys())[0] + utxo_1 = list(wallet._utxos.get_utxos_at_mixdepth(0).keys())[0] wallet.disable_utxo(*utxo_1) balances = wallet.get_balance_by_mixdepth(include_disabled=True) assert balances[0] == num_tx * amount