You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

592 lines
22 KiB

# -*- coding: utf-8 -*-
import asyncio
import attr
import threading
from aiorpcx import run_in_thread
from electrum import util
from electrum.bitcoin import (pubkey_to_address, address_to_scripthash,
is_address)
from electrum.interface import NetworkException
from electrum.transaction import Transaction
from electrum.util import EventListener, event_listener, ignore_exceptions
from .jm_base_code import JMBaseCodeMixin
from .jm_util import JMAddress, JMUtxo, KPStates
class KeypairNotFound(Exception):
...
class KeyPairsMixin:
'''Cached keypairs for automatic tx signing'''
def __init__(self):
self.keypairs_state_lock = threading.Lock()
self._keypairs_state = KPStates.Empty
self._keypairs_cache = {}
@property
def keypairs_state(self):
'''Get keypairs cache state'''
return self._keypairs_state
@keypairs_state.setter
def keypairs_state(self, keypairs_state):
'''Set keypairs cache state'''
assert isinstance(keypairs_state, KPStates)
self._keypairs_state = keypairs_state
def check_need_new_keypairs(self):
'''Check if there is a need to cache new keypairs in addition
to possibly already cached ones'''
if not self.jmman.need_password():
return False
with self.keypairs_state_lock:
return self.keypairs_state == KPStates.Empty
async def cleanup_keypairs(self):
'''Async task which cleans keypairs after mixing is stopped'''
def cleanup_keypairs_cache():
'''Cleanup keypairs cache'''
with self.keypairs_state_lock:
self.logger.info('Cleaning Keypairs Cache')
if self._keypairs_cache:
for addr in list(self._keypairs_cache.keys()):
self._keypairs_cache.pop(addr)
self.keypairs_state = KPStates.Empty
self.logger.info('Cleaned Keypairs Cache')
await self.loop.run_in_executor(None, cleanup_keypairs_cache)
async def make_keypairs_cache(self, password, keypairs_cached_callback,
*, tx_cnt=None):
'''Make keypairs cache after mixing is started'''
if self.keypairs_state == KPStates.Ready:
return False
try:
def _cache_keypairs():
return self._cache_keypairs(password, tx_cnt=tx_cnt)
self.logger.info('Making Keypairs Cache')
cached = await self.loop.run_in_executor(None, _cache_keypairs)
self.logger.info(f'Keypairs Cache Done, cached {cached} keys')
if keypairs_cached_callback:
try:
keypairs_cached_callback()
except BaseException as e:
self.logger.info(f'make_keypairs_cache: '
f'keypairs_cached_callback: {str(e)}')
return True
except Exception as e:
self.logger.warning(f'make_keypairs_cache: {str(e)}')
await self.cleanup_keypairs()
return False
def _cache_keypairs(self, password, *, tx_cnt=None):
'''Cache keypairs on mixing start'''
w = self.wallet
jmconf = self.jmconf
cached = 0
out_cnt = 2
with self.keypairs_state_lock:
if self.keypairs_state == KPStates.Ready:
return cached
if isinstance(tx_cnt, int) and tx_cnt >= 0:
# cache keys for existing JM utxos
for jm_utxo in self.get_jm_utxos().values():
addr = jm_utxo.addr
sequence = w.get_address_index(addr)
pubkey = w.keystore.derive_pubkey(*sequence)
sec, _ = w.keystore.get_private_key(sequence, password)
self._keypairs_cache[addr] = (pubkey, sec)
cached += 1
if tx_cnt > 0:
def filter_unused(addr):
return w.adb.get_address_history_len(addr) == 0
mixdepth = jmconf.mixdepth + 1
gaplimit = jmconf.gaplimit
key_cnt = out_cnt * tx_cnt * mixdepth * gaplimit + cached
unused_idxs = []
# cache keys for unused adresses by mixdepth
for d in range(mixdepth):
addrs = [a for a, jm_address in sorted(
self.get_jm_addresses(mixdepth=d,
internal=True).items(),
key=lambda x: x[1].index[1])]
addrs = list(filter(filter_unused, addrs))
if addrs:
sequence = w.get_address_index(addrs[0])
if sequence:
unused_idxs.append(sequence[1])
depth_cached = 0
for addr in addrs:
if depth_cached + 1 > out_cnt * tx_cnt:
break
sequence = w.get_address_index(addr)
if addr in self._keypairs_cache:
continue # skip cached
pubkey = w.keystore.derive_pubkey(*sequence)
sec, _ = w.keystore.get_private_key(sequence,
password)
self._keypairs_cache[addr] = (pubkey, sec)
depth_cached += 1
cached += 1
# cache keys for unused adresses not in cache
# starting from mininum unused index
idx = min(unused_idxs) if unused_idxs else 0
while cached < key_cnt:
sequence = (1, idx)
idx += 1
pubkey = w.keystore.derive_pubkey(*sequence)
addr = pubkey_to_address(w.txin_type, pubkey.hex())
if addr in self._keypairs_cache:
continue # skip cached
if w.adb.get_address_history_len(addr) > 0:
continue # skip used
sec, _ = w.keystore.get_private_key(sequence, password)
self._keypairs_cache[addr] = (pubkey, sec)
cached += 1
else:
for addr in self.get_unspent_jm_addresses():
sequence = w.get_address_index(addr)
pubkey = w.keystore.derive_pubkey(*sequence)
sec, _ = w.keystore.get_private_key(sequence, password)
self._keypairs_cache[addr] = (pubkey, sec)
cached += 1
self.keypairs_state = KPStates.Ready
return cached
def get_cached_key(self, addr):
if addr in self._keypairs_cache:
return self._keypairs_cache[addr][1]
else:
self.logger.error(f'get_key_from_addr: keypair'
f' for {addr} not found')
raise KeypairNotFound(f'Address {addr} not found '
f'in the keypairs cache')
def get_keypairs(self):
'''Transform keypairs cache to dict suitable for Transaction.sign'''
keypairs = {}
for pubkey, sec in self._keypairs_cache.values():
keypairs[pubkey] = sec
return keypairs
def get_keypairs_for_coinjoin_tx(self, tx, password):
'''Derive keypairs for coinjoin tx to fix add_info_from_wallet
problem on transactions with unknown addresses (post 4.1.х fix)'''
w = self.wallet
keypairs = {}
for txin in tx.inputs():
addr = txin.address
if addr is None or not w.is_mine(addr):
continue
sequence = w.get_address_index(addr)
pubkey = w.keystore.derive_pubkey(*sequence)
sec, _ = w.keystore.get_private_key(sequence, password)
keypairs[pubkey] = sec
return keypairs
def sign_coinjoin_transaction(self, tx, password=None):
'''Sign coinjoin transactions (with keypairs if cached)'''
if self._keypairs_cache:
keypairs = self.get_keypairs()
else:
keypairs = self.get_keypairs_for_coinjoin_tx(tx, password)
tx.sign(keypairs)
tx.finalize_psbt()
keypairs.clear()
if tx.is_complete():
return tx
class WalletDBMixin:
def wallet_db_modifier(func):
def wrapper(self, *args, **kwargs):
with self.db.lock:
self.db._modified = True
return func(self, *args, **kwargs)
return wrapper
def wallet_db_locked(func):
def wrapper(self, *args, **kwargs):
with self.db.lock:
return func(self, *args, **kwargs)
return wrapper
@wallet_db_locked
def get_jm_data(self, key, default_val=None):
return self.jm_data.get(key, default_val)
@wallet_db_modifier
def set_jm_data(self, key, val):
self.jm_data[key] = val
@wallet_db_modifier
def pop_jm_data(self, key):
return self.jm_data.pop(key, None)
@wallet_db_modifier
def set_jm_commitments(self, *, used, external):
self.jm_commitments['used'] = used
self.jm_commitments['external'] = external
@wallet_db_locked
def get_jm_commitments(self):
return self.jm_commitments
@wallet_db_modifier
def add_jm_address(self, address, jm_address):
'''
add add jm_address (mixdepth, branch, index, address) tuple
at address BIP32 path
'''
assert isinstance(jm_address, JMAddress)
self.jm_addresses[address] = attr.astuple(jm_address)
@wallet_db_locked
def get_jm_address(self, address):
jm_address_tuple = self.jm_addresses.get(address)
if jm_address_tuple:
return JMAddress(*jm_address_tuple)
@wallet_db_locked
def is_jm_address(self, address):
return address in self.jm_addresses
@wallet_db_locked
def get_jm_addresses(self, *, mixdepth=None, internal=None):
addresses = {k: JMAddress(*v) for k, v in self.jm_addresses.items()}
if mixdepth is not None:
addresses = {k: v for k, v in addresses.items()
if v.mixdepth == mixdepth}
if internal is not None:
addresses = {k: v for k, v in addresses.items()
if v.branch == int(internal)}
return addresses
def get_jm_utxos(self, *, mixdepth=None, internal=None):
addrs = self.get_jm_addresses(mixdepth=mixdepth, internal=internal)
coins = self.wallet.get_utxos(list(addrs.keys()))
res = {}
for c in coins:
outpoint = c.prevout.to_str()
addr = c.address
jm_addr = addrs[addr]
res[outpoint] = JMUtxo(addr, c.value_sats(), jm_addr.mixdepth)
return res
@wallet_db_modifier
def add_jm_tx(self, txid, address, amount, date):
self.jm_txs[txid] = (address, amount, date)
@wallet_db_locked
def get_jm_tx(self, txid):
return self.jm_txs.get(txid, None)
@wallet_db_locked
def get_jm_txs(self):
return self.jm_txs
class JMWallet(KeyPairsMixin, WalletDBMixin, JMBaseCodeMixin, EventListener):
def __init__(self, jmman):
KeyPairsMixin.__init__(self)
JMBaseCodeMixin.__init__(self)
self._jm_conf = None
self.jmman = jmman
self.logger = jmman.logger
self.wallet = jmman.wallet
self.config = jmman.config
self.debug = False
self.db = self.wallet.db
self.jm_data = None
self.jm_addresses = None
self.jm_commitments = None
self.jm_txs = None
# sycnhronizer unsubsribed addresses
self.spent_addrs = set()
self.unsubscribed_addrs = set()
# ignored makers list persisted across entire app run
self.ignored_makers = []
# from jmclient wallet service
self.callbacks = {
"all": [], # note: list, not dict
"unconfirmed": {},
"confirmed": {},
}
# transactions we are actively monitoring,
# i.e. they are not new but we want to track:
self.active_txs = {}
# to ensure transactions are only processed once:
self.processed_txids = set()
self.taskgroup = util.OldTaskGroup()
def init_jm_data(self):
if not self.jmman.enabled:
return
db = self.db
self.jm_data = db.get_dict('jm_data')
self.jm_addresses = db.get_dict('jm_addresses')
self.jm_commitments = db.get_dict('jm_commitments')
self.jm_txs = db.get_dict('jm_txs')
@property
def jmconf(self):
return self._jm_conf
@jmconf.setter
def jmconf(self, jmconf):
self._jm_conf = jmconf
def load_and_cleanup(self):
'''Start on wallet load_and_cleanup if JM enabled
or when JM enabled first time'''
if not self.jmman.enabled:
return
self.synchronize()
# load and unsubscribe spent JM addresses
self.add_spent_addrs(self.get_jm_addresses().keys())
def get_address_label(self, addr):
return self.wallet.get_label_for_address(addr)
def set_address_label(self, addr, label):
self.wallet.set_label(addr, label)
def add_spent_addrs(self, addrs):
'''Save addresses as spent, to minimize electrum server
usage for spent denoms'''
unspent = self.get_unspent_jm_addresses()
for addr in addrs:
if addr in unspent:
continue
self.spent_addrs.add(addr)
if self.jmconf.subscribe_spent:
continue
self.unsubscribe_spent_addr(addr)
def restore_spent_addrs(self, addrs):
'''Remove addresses from spent and subscribe on server again'''
for addr in addrs:
self.subscribe_spent_addr(addr)
self.spent_addrs.remove(addr)
def subscribe_spent_addr(self, addr):
'''Return previously unsubscribed address to synchronizer'''
if addr not in self.spent_addrs or addr not in self.unsubscribed_addrs:
return
w = self.wallet
self.unsubscribed_addrs.remove(addr)
if w.adb.synchronizer:
self.logger.debug(f'Add {addr} to synchronizer')
w.adb.synchronizer.add(addr)
def unsubscribe_spent_addr(self, addr):
'''Unsubscribe spent address from synchronizer/electrum server'''
if (self.jmconf.subscribe_spent
or addr not in self.spent_addrs
or addr in self.unsubscribed_addrs):
return
self.unsubscribed_addrs.add(addr)
self.synchronizer_remove_addr(addr)
def synchronizer_remove_addr(self, addr):
w = self.wallet
synchronizer = w.adb.synchronizer
if synchronizer:
if self.debug:
self.logger.debug(f'Remove {addr} from synchronizer')
async def _remove_address(addr: str):
if not is_address(addr):
raise ValueError(f"invalid bitcoin address {addr}")
h = address_to_scripthash(addr)
synchronizer._requests_sent += 1
async with synchronizer._network_request_semaphore:
await synchronizer.session.send_request(
'blockchain.scripthash.unsubscribe', [h])
synchronizer._requests_answered += 1
asyncio.run_coroutine_threadsafe(_remove_address(addr), self.loop)
def reserve_jm_addrs(self, addrs_count, *, internal=False):
'''Reserve addresses for JM use'''
result = []
w = self.wallet
with w.lock:
while len(result) < addrs_count:
if internal:
unused = w.calc_unused_change_addresses()
else:
unused = w.get_unused_addresses()
unused = [addr for addr in unused
if not w.is_address_reserved(addr)]
if unused:
addr = unused[0]
else:
addr = w.create_new_address(internal)
self.wallet.set_reserved_state_of_address(addr, reserved=True)
result.append(addr)
return result
def last_few_addresses(self, jm_addrs, limit=0):
sorted_addrs = sorted(jm_addrs.items(), key=lambda x: x[1].index[1])
return [a for a, data in sorted_addrs][-limit:]
def generate_jm_address(self, *, mixdepth, internal):
addrs = self.reserve_jm_addrs(1, internal=internal)
if not addrs:
self.logger.error(f'Error generating new address for'
f' mixdepth={mixdepth}, internal={internal}')
return False
addr = addrs[0]
index = self.wallet.get_address_index(addr)
jm_addr = JMAddress(mixdepth=mixdepth, branch=int(internal),
index=index)
self.add_jm_address(addr, jm_addr)
return True
def synchronize_sequence(self, mixdepth: int, internal: bool) -> int:
w = self.wallet
gen_cnt = 0 # num new addresses we generated
limit = self.jmconf.gaplimit
while True:
jm_addrs = self.get_jm_addresses(mixdepth=mixdepth,
internal=internal)
addr_cnt = len(jm_addrs)
if addr_cnt < limit:
if not self.generate_jm_address(mixdepth=mixdepth,
internal=internal):
return gen_cnt
gen_cnt += 1
continue
last_few_addrs = self.last_few_addresses(jm_addrs, limit)
if any(map(w.adb.address_is_old, last_few_addrs)):
if not self.generate_jm_address(mixdepth=mixdepth,
internal=internal):
return gen_cnt
gen_cnt += 1
else:
break
return gen_cnt
def synchronize(self):
if not self.jmman.enabled:
return
count = 0
with self.wallet.lock:
for d in range(self.jmconf.max_mixdepth + 1):
for i in range(2):
count += self.synchronize_sequence(mixdepth=d,
internal=bool(i))
return count
def on_network_start(self, network):
'''Run when network is connected to the wallet'''
asyncio.run_coroutine_threadsafe(self.main_loop(), self.loop)
@ignore_exceptions # don't kill outer taskgroup
async def main_loop(self):
try:
async with self.taskgroup as group:
await group.spawn(self.do_synchronize_loop())
except BaseException:
self.logger.exception("taskgroup died.")
finally:
self.logger.info("taskgroup stopped.")
async def do_synchronize_loop(self):
while True:
if self.jmman.enabled:
# note: we only generate new HD addresses if the existing ones
# have history that are mined and SPV-verified.
await run_in_thread(self.synchronize)
await asyncio.sleep(1)
async def get_tx(self, txid, *, ignore_network_issues=True, timeout=None):
tx = self.wallet.db.get_transaction(txid)
if tx:
return tx
if self.network and self.network.has_internet_connection():
try:
raw_tx = await self.network.get_transaction(txid,
timeout=timeout)
except NetworkException as e:
self.logger.info(f'got network error getting input txn. err:'
f' {repr(e)}. txid: {txid}.')
if not ignore_network_issues:
raise e
else:
tx = Transaction(raw_tx)
if not tx and not ignore_network_issues:
raise NetworkException('failed to get prev tx from network')
return tx
def get_spent_jm_addresses(self, *, mixdepth=None, internal=None):
jm_addr_list = set()
for addr in self.get_jm_addresses(mixdepth=mixdepth,
internal=internal).keys():
if self.wallet.adb.get_address_history_len(addr) >= 2:
jm_addr_list.add(addr)
return jm_addr_list
def get_unspent_jm_addresses(self, *, mixdepth=None, internal=None):
jm_addr_list = set()
for addr in self.get_jm_addresses(mixdepth=mixdepth,
internal=internal).keys():
if self.wallet.adb.get_address_history_len(addr) < 2:
jm_addr_list.add(addr)
return set(jm_addr_list)
@event_listener
async def on_event_adb_added_tx(self, adb, txid: str, tx: Transaction):
if self.wallet.adb != adb:
return
try:
await self.transaction_monitor(tx, txid)
except Exception as e:
self.logger.warning(f'on_event_adb_added_tx: {str(e)}')
@event_listener
async def on_event_adb_added_verified_tx(self, adb, txid):
if self.wallet.adb != adb:
return
try:
tx = self.wallet.adb.get_transaction(txid)
if tx:
await self.transaction_monitor(tx, txid)
else:
self.logger.debug(f'on_event_adb_added_verified_tx: tx not'
f' found for txid={txid}')
except Exception as e:
self.logger.warning(f'on_event_adb_added_verified_tx: {str(e)}')
@event_listener
async def on_event_adb_tx_height_changed(self, adb, txid,
old_height, tx_height):
if self.wallet.adb != adb:
return
try:
tx = self.wallet.adb.get_transaction(txid)
await self.transaction_monitor(tx, txid)
except Exception as e:
self.logger.warning(f'on_event_adb_tx_height_changed: {str(e)}')