From 1d498eeefc4d241589293a66863e255c276bb76d Mon Sep 17 00:00:00 2001 From: ThomasV Date: Tue, 12 Apr 2022 09:53:30 +0200 Subject: [PATCH] Change the semantics of get_balance: It does not make sense to count change outputs in our unconfirmed balance, because our balance will not be negatively affected if the transaction does not get confirmed. It is also incorrect to add signed values of get_addr_balance in order to compute the balance over a domain. For example, this leads to incoming and outgoing transactions cancelling out in our total unconfirmed balance. This commit looks at the coins that are spent by a transaction. If those coins belong to us and are confirmed, we do not count the transaction outputs in our unconfirmed balance. As a result, get_balance always returns positive values for unconfirmed balance. --- electrum/address_synchronizer.py | 101 ++++++++++++++++++------------- electrum/wallet.py | 5 -- 2 files changed, 58 insertions(+), 48 deletions(-) diff --git a/electrum/address_synchronizer.py b/electrum/address_synchronizer.py index 60bc1b939..6911b2e8c 100644 --- a/electrum/address_synchronizer.py +++ b/electrum/address_synchronizer.py @@ -27,6 +27,7 @@ import itertools from collections import defaultdict from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple, NamedTuple, Sequence, List +from .crypto import sha256 from . import bitcoin, util from .bitcoin import COINBASE_MATURITY from .util import profiler, bfh, TxMinedInfo, UnrelatedTransactionException, with_lock, OldTaskGroup @@ -94,7 +95,7 @@ class AddressSynchronizer(Logger): # thread local storage for caching stuff self.threadlocal_cache = threading.local() - self._get_addr_balance_cache = {} + self._get_balance_cache = {} self.load_and_cleanup() @@ -189,7 +190,7 @@ class AddressSynchronizer(Logger): util.register_callback(self.on_blockchain_updated, ['blockchain_updated']) def on_blockchain_updated(self, event, *args): - self._get_addr_balance_cache = {} # invalidate cache + self._get_balance_cache = {} # invalidate cache async def stop(self): if self.network: @@ -311,7 +312,7 @@ class AddressSynchronizer(Logger): pass else: self.db.add_txi_addr(tx_hash, addr, ser, v) - self._get_addr_balance_cache.pop(addr, None) # invalidate cache + self._get_balance_cache.clear() # invalidate cache for txi in tx.inputs(): if txi.is_coinbase_input(): continue @@ -329,7 +330,7 @@ class AddressSynchronizer(Logger): addr = txo.address if addr and self.is_mine(addr): self.db.add_txo_addr(tx_hash, addr, n, v, is_coinbase) - self._get_addr_balance_cache.pop(addr, None) # invalidate cache + self._get_balance_cache.clear() # invalidate cache # give v to txi that spends me next_tx = self.db.get_spent_outpoint(tx_hash, n) if next_tx is not None: @@ -379,7 +380,7 @@ class AddressSynchronizer(Logger): remove_from_spent_outpoints() self._remove_tx_from_local_history(tx_hash) for addr in itertools.chain(self.db.get_txi_addresses(tx_hash), self.db.get_txo_addresses(tx_hash)): - self._get_addr_balance_cache.pop(addr, None) # invalidate cache + self._get_balance_cache.clear() # invalidate cache self.db.remove_txi(tx_hash) self.db.remove_txo(tx_hash) self.db.remove_tx_fee(tx_hash) @@ -465,7 +466,7 @@ class AddressSynchronizer(Logger): with self.transaction_lock: self.db.clear_history() self._history_local.clear() - self._get_addr_balance_cache = {} # invalidate cache + self._get_balance_cache.clear() # invalidate cache def get_txpos(self, tx_hash: str) -> Tuple[int, int]: """Returns (height, txpos) tuple, even if the tx is unverified.""" @@ -835,42 +836,72 @@ class AddressSynchronizer(Logger): received, sent = self.get_addr_io(address) return sum([v for height, v, is_cb in received.values()]) + def get_addr_balance(self, address): + return self.get_balance([address]) + @with_local_height_cached - def get_addr_balance(self, address, *, excluded_coins: Set[str] = None) -> Tuple[int, int, int]: - """Return the balance of a bitcoin address: + def get_balance(self, domain=None, *, excluded_addresses: Set[str] = None, + excluded_coins: Set[str] = None) -> Tuple[int, int, int]: + """Return the balance of a set of addresses: confirmed and matured, unconfirmed, unmatured """ - if not excluded_coins: # cache is only used if there are no excluded_coins - cached_value = self._get_addr_balance_cache.get(address) - if cached_value: - return cached_value + if domain is None: + domain = self.get_addresses() + if excluded_addresses is None: + excluded_addresses = set() + assert isinstance(excluded_addresses, set), f"excluded_addresses should be set, not {type(excluded_addresses)}" + domain = set(domain) - excluded_addresses if excluded_coins is None: excluded_coins = set() assert isinstance(excluded_coins, set), f"excluded_coins should be set, not {type(excluded_coins)}" - received, sent = self.get_addr_io(address) + + cache_key = sha256(''.join(sorted(domain)) + ''.join(excluded_coins)) + cached_value = self._get_balance_cache.get(cache_key) + if cached_value: + return cached_value + + coins = {} + for address in domain: + coins.update(self.get_addr_outputs(address)) + c = u = x = 0 mempool_height = self.get_local_height() + 1 # height of next block - for txo, (tx_height, v, is_cb) in received.items(): - if txo in excluded_coins: + for utxo in coins.values(): + if utxo.spent_height is not None: continue + if utxo.prevout.to_str() in excluded_coins: + continue + v = utxo.value_sats() + tx_height = utxo.block_height + is_cb = utxo._is_coinbase_output if is_cb and tx_height + COINBASE_MATURITY > mempool_height: x += v elif tx_height > 0: c += v else: - u += v - if txo in sent: - sent_txid, sent_height = sent[txo] - if sent_height > 0: - c -= v + txid = utxo.prevout.txid.hex() + tx = self.db.get_transaction(txid) + assert tx is not None # txid comes from get_addr_io + # we look at the outputs that are spent by this transaction + # if those outputs are ours and confirmed, we count this coin as confirmed + confirmed_spent_amount = 0 + for txin in tx.inputs(): + if txin.prevout in coins: + coin = coins[txin.prevout] + if coin.block_height > 0: + confirmed_spent_amount += coin.value_sats() + # Compare amount, in case tx has confirmed and unconfirmed inputs, or is a coinjoin. + # (fixme: tx may have multiple change outputs) + if confirmed_spent_amount >= v: + c += v else: - u -= v + c += confirmed_spent_amount + u += v - confirmed_spent_amount result = c, u, x # cache result. - if not excluded_coins: - # Cache needs to be invalidated if a transaction is added to/ - # removed from history; or on new blocks (maturity...) - self._get_addr_balance_cache[address] = result + # Cache needs to be invalidated if a transaction is added to/ + # removed from history; or on new blocks (maturity...) + self._get_balance_cache[cache_key] = result return result @with_local_height_cached @@ -918,28 +949,12 @@ class AddressSynchronizer(Logger): continue return coins - def get_balance(self, domain=None, *, excluded_addresses: Set[str] = None, - excluded_coins: Set[str] = None) -> Tuple[int, int, int]: - if domain is None: - domain = self.get_addresses() - if excluded_addresses is None: - excluded_addresses = set() - assert isinstance(excluded_addresses, set), f"excluded_addresses should be set, not {type(excluded_addresses)}" - domain = set(domain) - excluded_addresses - cc = uu = xx = 0 - for addr in domain: - c, u, x = self.get_addr_balance(addr, excluded_coins=excluded_coins) - cc += c - uu += u - xx += x - return cc, uu, xx - def is_used(self, address: str) -> bool: return self.get_address_history_len(address) != 0 def is_empty(self, address: str) -> bool: - c, u, x = self.get_addr_balance(address) - return c+u+x == 0 + coins = self.get_addr_utxo(address) + return not bool(coins) def synchronize(self) -> int: """Returns the number of new addresses we generated.""" diff --git a/electrum/wallet.py b/electrum/wallet.py index 9f0273d9d..217431b07 100644 --- a/electrum/wallet.py +++ b/electrum/wallet.py @@ -735,11 +735,6 @@ class Abstract_Wallet(AddressSynchronizer, ABC): uu = u - fu xx = x - fx frozen = fc + fu + fx - # subtract unconfirmed if negative. - # (this does not make sense if positive and negative tx cancel eachother out) - if uu < 0: - cc = cc + uu - uu = 0 return cc, uu, xx, frozen, lightning - f_lightning, f_lightning def balance_at_timestamp(self, domain, target_timestamp):