diff --git a/electrum/address_synchronizer.py b/electrum/address_synchronizer.py index 60bc1b939..6911b2e8c 100644 --- a/electrum/address_synchronizer.py +++ b/electrum/address_synchronizer.py @@ -27,6 +27,7 @@ import itertools from collections import defaultdict from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple, NamedTuple, Sequence, List +from .crypto import sha256 from . import bitcoin, util from .bitcoin import COINBASE_MATURITY from .util import profiler, bfh, TxMinedInfo, UnrelatedTransactionException, with_lock, OldTaskGroup @@ -94,7 +95,7 @@ class AddressSynchronizer(Logger): # thread local storage for caching stuff self.threadlocal_cache = threading.local() - self._get_addr_balance_cache = {} + self._get_balance_cache = {} self.load_and_cleanup() @@ -189,7 +190,7 @@ class AddressSynchronizer(Logger): util.register_callback(self.on_blockchain_updated, ['blockchain_updated']) def on_blockchain_updated(self, event, *args): - self._get_addr_balance_cache = {} # invalidate cache + self._get_balance_cache = {} # invalidate cache async def stop(self): if self.network: @@ -311,7 +312,7 @@ class AddressSynchronizer(Logger): pass else: self.db.add_txi_addr(tx_hash, addr, ser, v) - self._get_addr_balance_cache.pop(addr, None) # invalidate cache + self._get_balance_cache.clear() # invalidate cache for txi in tx.inputs(): if txi.is_coinbase_input(): continue @@ -329,7 +330,7 @@ class AddressSynchronizer(Logger): addr = txo.address if addr and self.is_mine(addr): self.db.add_txo_addr(tx_hash, addr, n, v, is_coinbase) - self._get_addr_balance_cache.pop(addr, None) # invalidate cache + self._get_balance_cache.clear() # invalidate cache # give v to txi that spends me next_tx = self.db.get_spent_outpoint(tx_hash, n) if next_tx is not None: @@ -379,7 +380,7 @@ class AddressSynchronizer(Logger): remove_from_spent_outpoints() self._remove_tx_from_local_history(tx_hash) for addr in itertools.chain(self.db.get_txi_addresses(tx_hash), self.db.get_txo_addresses(tx_hash)): - self._get_addr_balance_cache.pop(addr, None) # invalidate cache + self._get_balance_cache.clear() # invalidate cache self.db.remove_txi(tx_hash) self.db.remove_txo(tx_hash) self.db.remove_tx_fee(tx_hash) @@ -465,7 +466,7 @@ class AddressSynchronizer(Logger): with self.transaction_lock: self.db.clear_history() self._history_local.clear() - self._get_addr_balance_cache = {} # invalidate cache + self._get_balance_cache.clear() # invalidate cache def get_txpos(self, tx_hash: str) -> Tuple[int, int]: """Returns (height, txpos) tuple, even if the tx is unverified.""" @@ -835,42 +836,72 @@ class AddressSynchronizer(Logger): received, sent = self.get_addr_io(address) return sum([v for height, v, is_cb in received.values()]) + def get_addr_balance(self, address): + return self.get_balance([address]) + @with_local_height_cached - def get_addr_balance(self, address, *, excluded_coins: Set[str] = None) -> Tuple[int, int, int]: - """Return the balance of a bitcoin address: + def get_balance(self, domain=None, *, excluded_addresses: Set[str] = None, + excluded_coins: Set[str] = None) -> Tuple[int, int, int]: + """Return the balance of a set of addresses: confirmed and matured, unconfirmed, unmatured """ - if not excluded_coins: # cache is only used if there are no excluded_coins - cached_value = self._get_addr_balance_cache.get(address) - if cached_value: - return cached_value + if domain is None: + domain = self.get_addresses() + if excluded_addresses is None: + excluded_addresses = set() + assert isinstance(excluded_addresses, set), f"excluded_addresses should be set, not {type(excluded_addresses)}" + domain = set(domain) - excluded_addresses if excluded_coins is None: excluded_coins = set() assert isinstance(excluded_coins, set), f"excluded_coins should be set, not {type(excluded_coins)}" - received, sent = self.get_addr_io(address) + + cache_key = sha256(''.join(sorted(domain)) + ''.join(excluded_coins)) + cached_value = self._get_balance_cache.get(cache_key) + if cached_value: + return cached_value + + coins = {} + for address in domain: + coins.update(self.get_addr_outputs(address)) + c = u = x = 0 mempool_height = self.get_local_height() + 1 # height of next block - for txo, (tx_height, v, is_cb) in received.items(): - if txo in excluded_coins: + for utxo in coins.values(): + if utxo.spent_height is not None: continue + if utxo.prevout.to_str() in excluded_coins: + continue + v = utxo.value_sats() + tx_height = utxo.block_height + is_cb = utxo._is_coinbase_output if is_cb and tx_height + COINBASE_MATURITY > mempool_height: x += v elif tx_height > 0: c += v else: - u += v - if txo in sent: - sent_txid, sent_height = sent[txo] - if sent_height > 0: - c -= v + txid = utxo.prevout.txid.hex() + tx = self.db.get_transaction(txid) + assert tx is not None # txid comes from get_addr_io + # we look at the outputs that are spent by this transaction + # if those outputs are ours and confirmed, we count this coin as confirmed + confirmed_spent_amount = 0 + for txin in tx.inputs(): + if txin.prevout in coins: + coin = coins[txin.prevout] + if coin.block_height > 0: + confirmed_spent_amount += coin.value_sats() + # Compare amount, in case tx has confirmed and unconfirmed inputs, or is a coinjoin. + # (fixme: tx may have multiple change outputs) + if confirmed_spent_amount >= v: + c += v else: - u -= v + c += confirmed_spent_amount + u += v - confirmed_spent_amount result = c, u, x # cache result. - if not excluded_coins: - # Cache needs to be invalidated if a transaction is added to/ - # removed from history; or on new blocks (maturity...) - self._get_addr_balance_cache[address] = result + # Cache needs to be invalidated if a transaction is added to/ + # removed from history; or on new blocks (maturity...) + self._get_balance_cache[cache_key] = result return result @with_local_height_cached @@ -918,28 +949,12 @@ class AddressSynchronizer(Logger): continue return coins - def get_balance(self, domain=None, *, excluded_addresses: Set[str] = None, - excluded_coins: Set[str] = None) -> Tuple[int, int, int]: - if domain is None: - domain = self.get_addresses() - if excluded_addresses is None: - excluded_addresses = set() - assert isinstance(excluded_addresses, set), f"excluded_addresses should be set, not {type(excluded_addresses)}" - domain = set(domain) - excluded_addresses - cc = uu = xx = 0 - for addr in domain: - c, u, x = self.get_addr_balance(addr, excluded_coins=excluded_coins) - cc += c - uu += u - xx += x - return cc, uu, xx - def is_used(self, address: str) -> bool: return self.get_address_history_len(address) != 0 def is_empty(self, address: str) -> bool: - c, u, x = self.get_addr_balance(address) - return c+u+x == 0 + coins = self.get_addr_utxo(address) + return not bool(coins) def synchronize(self) -> int: """Returns the number of new addresses we generated.""" diff --git a/electrum/wallet.py b/electrum/wallet.py index 9f0273d9d..217431b07 100644 --- a/electrum/wallet.py +++ b/electrum/wallet.py @@ -735,11 +735,6 @@ class Abstract_Wallet(AddressSynchronizer, ABC): uu = u - fu xx = x - fx frozen = fc + fu + fx - # subtract unconfirmed if negative. - # (this does not make sense if positive and negative tx cancel eachother out) - if uu < 0: - cc = cc + uu - uu = 0 return cc, uu, xx, frozen, lightning - f_lightning, f_lightning def balance_at_timestamp(self, domain, target_timestamp):