# Copyright (C) 2018 The Electrum developers
# Distributed under the MIT software license, see the accompanying
# file LICENCE or http://www.opensource.org/licenses/mit-license.php

import os
import queue
import threading
import concurrent.futures
import asyncio
from collections import defaultdict
from enum import IntEnum, auto
from typing import NamedTuple, Iterable, Dict, TYPE_CHECKING

import jsonrpclib

from .util import PrintError, bh2u, bfh, log_exceptions, ignore_exceptions
from . import wallet
from .storage import WalletStorage
from .address_synchronizer import AddressSynchronizer, TX_HEIGHT_LOCAL, TX_HEIGHT_UNCONF_PARENT, TX_HEIGHT_UNCONFIRMED
from .transaction import Transaction

if TYPE_CHECKING:
    from .network import Network


class ListenerItem(NamedTuple):
    # set when the lnwatcher is all done with the outpoint used as index in LNWatcher.tx_progress
    all_done: asyncio.Event
    # txs we broadcast are put on this queue so that the test can wait for them to get mined
    tx_queue: asyncio.Queue


class TxMinedDepth(IntEnum):
    """ IntEnum because we call min() in get_deepest_tx_mined_depth_for_txids """
    DEEP = auto()
    SHALLOW = auto()
    MEMPOOL = auto()
    FREE = auto()


from sqlalchemy import create_engine, Column, ForeignKey, Integer, String, DateTime, Boolean
from sqlalchemy.pool import StaticPool
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.orm.query import Query
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.sql import not_, or_

Base = declarative_base()


class SweepTx(Base):
    __tablename__ = 'sweep_txs'
    funding_outpoint = Column(String(34))
    prev_txid = Column(String(32))
    tx = Column(String())
    txid = Column(String(32), primary_key=True)  # txid of tx


class ChannelInfo(Base):
    __tablename__ = 'channel_info'
    address = Column(String(32), primary_key=True)
    outpoint = Column(String(34))
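

# SweepStore persists, in a small SQLite database, the channel info we were
# asked to watch and the sweep transactions to publish if a channel's funding
# outpoint gets spent.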
class SweepStore(PrintError):

    def __init__(self, path, network):
        PrintError.__init__(self)
        self.path = path
        self.network = network
        self.db_requests = queue.Queue()
        threading.Thread(target=self.sql_thread).start()
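
    # All database access happens on a single dedicated thread: callers enqueue
    # (future, func, args, kwargs) tuples on db_requests, and the loop below
    # executes them against one long-lived session.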
    def sql_thread(self):
        engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)
        DBSession = sessionmaker(bind=engine, autoflush=False)
        self.DBSession = DBSession()
        if not os.path.exists(self.path):
            Base.metadata.create_all(engine)
        while self.network.asyncio_loop.is_running():
            try:
                future, func, args, kwargs = self.db_requests.get(timeout=0.1)
            except queue.Empty:
                continue
            try:
                result = func(self, *args, **kwargs)
            except BaseException as e:
                future.set_exception(e)
                continue
            future.set_result(result)
            # write
            self.DBSession.commit()
        self.print_error("SQL thread terminated")
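
    # Decorator: instead of running on the caller's thread, the wrapped method
    # is queued for the SQL thread above; the caller blocks on a
    # concurrent.futures.Future (for at most 10 seconds) until the result is ready.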
    def sql(func):
        def wrapper(self, *args, **kwargs):
            f = concurrent.futures.Future()
            self.db_requests.put((f, func, args, kwargs))
            return f.result(timeout=10)
        return wrapper

    @sql
    def get_sweep_tx(self, funding_outpoint, prev_txid):
        return [Transaction(r.tx) for r in self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint, SweepTx.prev_txid==prev_txid).all()]

    @sql
    def list_sweep_tx(self):
        return set(r.funding_outpoint for r in self.DBSession.query(SweepTx).all())

    @sql
    def add_sweep_tx(self, funding_outpoint, prev_txid, tx):
        self.DBSession.add(SweepTx(funding_outpoint=funding_outpoint, prev_txid=prev_txid, tx=str(tx), txid=tx.txid()))
        self.DBSession.commit()

    @sql
    def num_sweep_tx(self, funding_outpoint):
        return self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint).count()

    @sql
    def remove_sweep_tx(self, funding_outpoint):
        # Session.delete() takes a single instance, so delete the rows one by one
        for r in self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint).all():
            self.DBSession.delete(r)
        self.DBSession.commit()

    @sql
    def add_channel_info(self, address, outpoint):
        self.DBSession.add(ChannelInfo(address=address, outpoint=outpoint))
        self.DBSession.commit()

    @sql
    def remove_channel_info(self, address):
        v = self.DBSession.query(ChannelInfo).filter(ChannelInfo.address==address).one_or_none()
        if v is not None:
            self.DBSession.delete(v)
        self.DBSession.commit()

    @sql
    def has_channel_info(self, address):
        return bool(self.DBSession.query(ChannelInfo).filter(ChannelInfo.address==address).one_or_none())

    @sql
    def get_channel_info(self, address):
        r = self.DBSession.query(ChannelInfo).filter(ChannelInfo.address==address).one_or_none()
        return r.outpoint if r else None

    @sql
    def list_channel_info(self):
        return [(r.address, r.outpoint) for r in self.DBSession.query(ChannelInfo).all()]
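

# LNWatcher monitors channel funding outpoints on-chain. It is an
# AddressSynchronizer backed by its own "watchtower_wallet" storage; when a
# watched funding outpoint is spent, it publishes the sweep transactions
# stored for that channel and notifies the rest of the application.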
class LNWatcher(AddressSynchronizer):
    verbosity_filter = 'W'

    def __init__(self, network: 'Network'):
        path = os.path.join(network.config.path, "watchtower_wallet")
        storage = WalletStorage(path)
        AddressSynchronizer.__init__(self, storage)
        self.config = network.config
        self.start_network(network)
        self.lock = threading.RLock()
        self.sweepstore = SweepStore(os.path.join(network.config.path, "watchtower_db"), network)
        self.network.register_callback(self.on_network_update,
                                       ['network_updated', 'blockchain_updated', 'verified', 'wallet_updated'])
        self.set_remote_watchtower()
        # this maps funding_outpoints to ListenerItems, which have an event for when the watcher is done,
        # and a queue for seeing which txs are being published
        self.tx_progress = {}  # type: Dict[str, ListenerItem]
        # status gets populated when we run
        self.channel_status = {}

    def get_channel_status(self, outpoint):
        return self.channel_status.get(outpoint, 'unknown')

    def set_remote_watchtower(self):
        watchtower_url = self.config.get('watchtower_url')
        self.watchtower = jsonrpclib.Server(watchtower_url) if watchtower_url else None
        self.watchtower_queue = asyncio.Queue()
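
    # Decorator: run the method locally and, if a remote watchtower is
    # configured, also enqueue the same call so that watchtower_task can
    # forward it over JSON-RPC.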
    def with_watchtower(func):
        def wrapper(self, *args, **kwargs):
            if self.watchtower:
                self.watchtower_queue.put_nowait((func.__name__, args, kwargs))
            return func(self, *args, **kwargs)
        return wrapper

    @ignore_exceptions
    @log_exceptions
    async def watchtower_task(self):
        self.print_error('watchtower task started')
        while True:
            name, args, kwargs = await self.watchtower_queue.get()
            if self.watchtower is None:
                continue
            func = getattr(self.watchtower, name)
            try:
                r = func(*args, **kwargs)
                self.print_error("watchtower answer", r)
            except Exception:
                self.print_error('could not reach watchtower, will retry in 5s', name, args)
                await asyncio.sleep(5)
                await self.watchtower_queue.put((name, args, kwargs))

    @with_watchtower
    def watch_channel(self, address, outpoint):
        self.add_address(address)
        with self.lock:
            if not self.sweepstore.has_channel_info(address):
                self.sweepstore.add_channel_info(address, outpoint)

    def unwatch_channel(self, address, funding_outpoint):
        self.print_error('unwatching', funding_outpoint)
        self.sweepstore.remove_sweep_tx(funding_outpoint)
        self.sweepstore.remove_channel_info(address)
        if funding_outpoint in self.tx_progress:
            self.tx_progress[funding_outpoint].all_done.set()

    @log_exceptions
    async def on_network_update(self, event, *args):
        if event in ('verified', 'wallet_updated'):
            if args[0] != self:
                return
        if not self.synchronizer:
            self.print_error("synchronizer not set yet")
            return
        if not self.synchronizer.is_up_to_date():
            return
        for address, outpoint in self.sweepstore.list_channel_info():
            await self.check_onchain_situation(address, outpoint)
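
    # For a single channel: check whether (and how deeply) the funding outpoint
    # has been spent, notify listeners via 'channel_open'/'channel_closed',
    # publish any stored sweep transactions, and stop watching once the
    # spending transactions are buried deeply enough.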
    async def check_onchain_situation(self, address, funding_outpoint):
        keep_watching, spenders = self.inspect_tx_candidate(funding_outpoint, 0)
        funding_txid = funding_outpoint.split(':')[0]
        funding_height = self.get_tx_height(funding_txid)
        closing_txid = spenders.get(funding_outpoint)
        if closing_txid is None:
            self.network.trigger_callback('channel_open', funding_outpoint, funding_txid, funding_height)
        else:
            closing_height = self.get_tx_height(closing_txid)
            self.network.trigger_callback('channel_closed', funding_outpoint, spenders, funding_txid, funding_height, closing_txid, closing_height)
            await self.do_breach_remedy(funding_outpoint, spenders)
        if not keep_watching:
            self.unwatch_channel(address, funding_outpoint)
        else:
            self.print_error('we will keep_watching', funding_outpoint)

    def inspect_tx_candidate(self, outpoint, n):
        # FIXME: instead of stopping recursion at n == 2,
        # we should detect which outputs are HTLCs
        prev_txid, index = outpoint.split(':')
        txid = self.spent_outpoints[prev_txid].get(int(index))
        result = {outpoint: txid}
        if txid is None:
            self.channel_status[outpoint] = 'open'
            self.print_error('keep watching because outpoint is unspent')
            return True, result
        keep_watching = (self.get_tx_mined_depth(txid) != TxMinedDepth.DEEP)
        if keep_watching:
            self.channel_status[outpoint] = 'closed (%d)' % self.get_tx_height(txid).conf
            self.print_error('keep watching because spending tx is not deep')
        else:
            self.channel_status[outpoint] = 'closed (deep)'

        tx = self.transactions[txid]
        for i, o in enumerate(tx.outputs()):
            if o.address not in self.get_addresses():
                self.add_address(o.address)
                keep_watching = True
            elif n < 2:
                k, r = self.inspect_tx_candidate(txid + ':%d' % i, n + 1)
                keep_watching |= k
                result.update(r)
        return keep_watching, result
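
    # For each output of the closing tx that is still unspent, try to broadcast
    # the sweep transactions previously stored for it with add_sweep_tx().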
    async def do_breach_remedy(self, funding_outpoint, spenders):
        for prevout, spender in spenders.items():
            if spender is not None:
                continue
            prev_txid, prev_n = prevout.split(':')
            sweep_txns = self.sweepstore.get_sweep_tx(funding_outpoint, prev_txid)
            for tx in sweep_txns:
                if not await self.broadcast_or_log(funding_outpoint, tx):
                    self.print_error(tx.name, f'could not publish tx: {str(tx)}, prev_txid: {prev_txid}')
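
    # Only transactions that exist purely locally are broadcast; anything
    # already in the mempool or mined is left alone. Returns the txid on
    # success, None otherwise.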
    async def broadcast_or_log(self, funding_outpoint, tx):
        height = self.get_tx_height(tx.txid()).height
        if height != TX_HEIGHT_LOCAL:
            return
        try:
            txid = await self.network.broadcast_transaction(tx)
        except Exception as e:
            self.print_error(f'broadcast: {tx.name}: failure: {repr(e)}')
        else:
            self.print_error(f'broadcast: {tx.name}: success. txid: {txid}')
            if funding_outpoint in self.tx_progress:
                await self.tx_progress[funding_outpoint].tx_queue.put(tx)
            return txid

    @with_watchtower
    def add_sweep_tx(self, funding_outpoint: str, prev_txid: str, tx_dict):
        tx = Transaction.from_dict(tx_dict)
        self.sweepstore.add_sweep_tx(funding_outpoint, prev_txid, tx)
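
    # Classify how deeply a transaction is buried: DEEP (more than 100
    # confirmations), SHALLOW (mined, 100 confirmations or fewer), MEMPOOL
    # (unconfirmed, or claimed mined but not yet verified), FREE (local only).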
    def get_tx_mined_depth(self, txid: str):
        if not txid:
            return TxMinedDepth.FREE
        tx_mined_depth = self.get_tx_height(txid)
        height, conf = tx_mined_depth.height, tx_mined_depth.conf
        if conf > 100:
            return TxMinedDepth.DEEP
        elif conf > 0:
            return TxMinedDepth.SHALLOW
        elif height in (TX_HEIGHT_UNCONFIRMED, TX_HEIGHT_UNCONF_PARENT):
            return TxMinedDepth.MEMPOOL
        elif height == TX_HEIGHT_LOCAL:
            return TxMinedDepth.FREE
        elif height > 0 and conf == 0:
            # unverified but claimed to be mined
            return TxMinedDepth.MEMPOOL
        else:
            raise NotImplementedError()