
optimize channel_db:

- use python objects mirrored by sql database
- write sql to file asynchronously
- the sql decorator is awaited in sweepstore, not in channel_db
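
The first two bullets, as a minimal self-contained sketch (illustrative names only, not the actual Electrum classes): reads are served from plain Python objects, while SQL writes are queued and flushed to disk in the background.

    import asyncio
    import sqlite3
    from typing import NamedTuple

    class Policy(NamedTuple):              # stand-in for the NamedTuples below
        key: bytes
        timestamp: int

    class GraphStore:
        def __init__(self, path: str):
            self._policies = {}                  # read path: dict lookup, no SQL
            self._pending = asyncio.Queue()      # write path: drained by a task
            self._conn = sqlite3.connect(path)
            self._conn.execute(
                'CREATE TABLE IF NOT EXISTS policy '
                '(key BLOB PRIMARY KEY ON CONFLICT REPLACE, timestamp INT)')

        def add_policy(self, p: Policy):
            self._policies[p.key] = p            # callers never block on disk
            self._pending.put_nowait(p)

        async def flush_loop(self):
            # background task: mirror the in-memory state into sqlite
            while True:
                p = await self._pending.get()
                self._conn.execute('INSERT INTO policy VALUES (?, ?)',
                                   (p.key, p.timestamp))
                self._conn.commit()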
master · ThomasV, 7 years ago · commit f2d58d0e3f
Changed files (lines changed):

  534  electrum/channel_db.py
    9  electrum/gui/qt/lightning_dialog.py
   11  electrum/lnaddr.py
   12  electrum/lnchannel.py
   87  electrum/lnpeer.py
   19  electrum/lnrouter.py
   49  electrum/lnwatcher.py
  131  electrum/lnworker.py
   22  electrum/sql_db.py
    3  electrum/tests/test_lnpeer.py
   12  electrum/tests/test_lnrouter.py
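
The third commit-message bullet relies on the `@sql` decorator in electrum/sql_db.py returning an awaitable. A rough sketch of that shape, assuming a per-database executor (`self.asyncio_loop` and `self.sql_executor` are illustrative attribute names, not necessarily the real ones): sweepstore callers await the result, channel_db drops the future and lets the write finish in the background.

    import asyncio
    import functools

    def sql(func):
        # Run the wrapped DB method on the database's executor thread and
        # return a Future: await it for the result, or ignore it to let
        # the write complete asynchronously.
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            return self.asyncio_loop.run_in_executor(
                self.sql_executor,
                functools.partial(func, self, *args, **kwargs))
        return wrapper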

electrum/channel_db.py  (534 lines changed)

@@ -51,6 +51,7 @@ from .crypto import sha256d
 from . import ecc
 from .lnutil import (LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, NUM_MAX_EDGES_IN_PAYMENT_PATH,
                      NotFoundChanAnnouncementForUpdate)
+from .lnverifier import verify_sig_for_channel_update
 from .lnmsg import encode_msg
 if TYPE_CHECKING:
@@ -70,85 +71,83 @@ Base = declarative_base()
 FLAG_DISABLE = 1 << 1
 FLAG_DIRECTION = 1 << 0
 
-class ChannelInfo(Base):
-    __tablename__ = 'channel_info'
-    short_channel_id = Column(String(64), primary_key=True)
-    node1_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False)
-    node2_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False)
-    capacity_sat = Column(Integer)
-    msg_payload_hex = Column(String(1024), nullable=False)
-    trusted = Column(Boolean, nullable=False)
+class ChannelInfo(NamedTuple):
+    short_channel_id: bytes
+    node1_id: bytes
+    node2_id: bytes
+    capacity_sat: int
+    msg_payload: bytes
+    trusted: bool
 
     @staticmethod
     def from_msg(payload):
         features = int.from_bytes(payload['features'], 'big')
         validate_features(features)
-        channel_id = payload['short_channel_id'].hex()
-        node_id_1 = payload['node_id_1'].hex()
-        node_id_2 = payload['node_id_2'].hex()
+        channel_id = payload['short_channel_id']
+        node_id_1 = payload['node_id_1']
+        node_id_2 = payload['node_id_2']
         assert list(sorted([node_id_1, node_id_2])) == [node_id_1, node_id_2]
-        msg_payload_hex = encode_msg('channel_announcement', **payload).hex()
+        msg_payload = encode_msg('channel_announcement', **payload)
         capacity_sat = None
-        return ChannelInfo(short_channel_id = channel_id, node1_id = node_id_1,
-                node2_id = node_id_2, capacity_sat = capacity_sat, msg_payload_hex = msg_payload_hex,
-                trusted = False)
-
-    @property
-    def msg_payload(self):
-        return bytes.fromhex(self.msg_payload_hex)
-
-
-class Policy(Base):
-    __tablename__ = 'policy'
-    start_node = Column(String(66), ForeignKey('node_info.node_id'), primary_key=True)
-    short_channel_id = Column(String(64), ForeignKey('channel_info.short_channel_id'), primary_key=True)
-    cltv_expiry_delta = Column(Integer, nullable=False)
-    htlc_minimum_msat = Column(Integer, nullable=False)
-    htlc_maximum_msat = Column(Integer)
-    fee_base_msat = Column(Integer, nullable=False)
-    fee_proportional_millionths = Column(Integer, nullable=False)
-    channel_flags = Column(Integer, nullable=False)
-    timestamp = Column(Integer, nullable=False)
+        return ChannelInfo(
+            short_channel_id = channel_id,
+            node1_id = node_id_1,
+            node2_id = node_id_2,
+            capacity_sat = capacity_sat,
+            msg_payload = msg_payload,
+            trusted = False)
+
+
+class Policy(NamedTuple):
+    key: bytes
+    cltv_expiry_delta: int
+    htlc_minimum_msat: int
+    htlc_maximum_msat: int
+    fee_base_msat: int
+    fee_proportional_millionths: int
+    channel_flags: int
+    timestamp: int
 
     @staticmethod
     def from_msg(payload):
-        cltv_expiry_delta = int.from_bytes(payload['cltv_expiry_delta'], "big")
-        htlc_minimum_msat = int.from_bytes(payload['htlc_minimum_msat'], "big")
-        htlc_maximum_msat = int.from_bytes(payload['htlc_maximum_msat'], "big") if 'htlc_maximum_msat' in payload else None
-        fee_base_msat = int.from_bytes(payload['fee_base_msat'], "big")
-        fee_proportional_millionths = int.from_bytes(payload['fee_proportional_millionths'], "big")
-        channel_flags = int.from_bytes(payload['channel_flags'], "big")
-        timestamp = int.from_bytes(payload['timestamp'], "big")
-        start_node = payload['start_node'].hex()
-        short_channel_id = payload['short_channel_id'].hex()
-
-        return Policy(start_node=start_node,
-                short_channel_id=short_channel_id,
-                cltv_expiry_delta=cltv_expiry_delta,
-                htlc_minimum_msat=htlc_minimum_msat,
-                fee_base_msat=fee_base_msat,
-                fee_proportional_millionths=fee_proportional_millionths,
-                channel_flags=channel_flags,
-                timestamp=timestamp,
-                htlc_maximum_msat=htlc_maximum_msat)
+        return Policy(
+            key = payload['short_channel_id'] + payload['start_node'],
+            cltv_expiry_delta = int.from_bytes(payload['cltv_expiry_delta'], "big"),
+            htlc_minimum_msat = int.from_bytes(payload['htlc_minimum_msat'], "big"),
+            htlc_maximum_msat = int.from_bytes(payload['htlc_maximum_msat'], "big") if 'htlc_maximum_msat' in payload else None,
+            fee_base_msat = int.from_bytes(payload['fee_base_msat'], "big"),
+            fee_proportional_millionths = int.from_bytes(payload['fee_proportional_millionths'], "big"),
+            channel_flags = int.from_bytes(payload['channel_flags'], "big"),
+            timestamp = int.from_bytes(payload['timestamp'], "big")
+        )
 
     def is_disabled(self):
         return self.channel_flags & FLAG_DISABLE
 
-class NodeInfo(Base):
-    __tablename__ = 'node_info'
-    node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
-    features = Column(Integer, nullable=False)
-    timestamp = Column(Integer, nullable=False)
-    alias = Column(String(64), nullable=False)
+    @property
+    def short_channel_id(self):
+        return self.key[0:8]
+
+    @property
+    def start_node(self):
+        return self.key[8:]
+
+
+class NodeInfo(NamedTuple):
+    node_id: bytes
+    features: int
+    timestamp: int
+    alias: str
 
     @staticmethod
     def from_msg(payload):
-        node_id = payload['node_id'].hex()
+        node_id = payload['node_id']
         features = int.from_bytes(payload['features'], "big")
         validate_features(features)
         addresses = NodeInfo.parse_addresses_field(payload['addresses'])
-        alias = payload['alias'].rstrip(b'\x00').hex()
+        alias = payload['alias'].rstrip(b'\x00')
         timestamp = int.from_bytes(payload['timestamp'], "big")
         return NodeInfo(node_id=node_id, features=features, timestamp=timestamp, alias=alias), [
             Address(host=host, port=port, node_id=node_id, last_connected_date=None) for host, port in addresses]
@@ -193,110 +192,136 @@ class NodeInfo(Base):
             break
         return addresses
 
-class Address(Base):
+
+class Address(NamedTuple):
+    node_id: bytes
+    host: str
+    port: int
+    last_connected_date: int
+
+
+class ChannelInfoBase(Base):
+    __tablename__ = 'channel_info'
+    short_channel_id = Column(String(64), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
+    node1_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False)
+    node2_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False)
+    capacity_sat = Column(Integer)
+    msg_payload = Column(String(1024), nullable=False)
+    trusted = Column(Boolean, nullable=False)
+
+    def to_nametuple(self):
+        return ChannelInfo(
+            short_channel_id=self.short_channel_id,
+            node1_id=self.node1_id,
+            node2_id=self.node2_id,
+            capacity_sat=self.capacity_sat,
+            msg_payload=self.msg_payload,
+            trusted=self.trusted
+        )
+
+
+class PolicyBase(Base):
+    __tablename__ = 'policy'
+    key = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
+    cltv_expiry_delta = Column(Integer, nullable=False)
+    htlc_minimum_msat = Column(Integer, nullable=False)
+    htlc_maximum_msat = Column(Integer)
+    fee_base_msat = Column(Integer, nullable=False)
+    fee_proportional_millionths = Column(Integer, nullable=False)
+    channel_flags = Column(Integer, nullable=False)
+    timestamp = Column(Integer, nullable=False)
+
+    def to_nametuple(self):
+        return Policy(
+            key=self.key,
+            cltv_expiry_delta=self.cltv_expiry_delta,
+            htlc_minimum_msat=self.htlc_minimum_msat,
+            htlc_maximum_msat=self.htlc_maximum_msat,
+            fee_base_msat=self.fee_base_msat,
+            fee_proportional_millionths=self.fee_proportional_millionths,
+            channel_flags=self.channel_flags,
+            timestamp=self.timestamp
+        )
+
+
+class NodeInfoBase(Base):
+    __tablename__ = 'node_info'
+    node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
+    features = Column(Integer, nullable=False)
+    timestamp = Column(Integer, nullable=False)
+    alias = Column(String(64), nullable=False)
+
+
+class AddressBase(Base):
     __tablename__ = 'address'
-    node_id = Column(String(66), ForeignKey('node_info.node_id'), primary_key=True)
+    node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
     host = Column(String(256), primary_key=True)
     port = Column(Integer, primary_key=True)
     last_connected_date = Column(Integer(), nullable=True)
 
 class ChannelDB(SqlDB):
 
     NUM_MAX_RECENT_PEERS = 20
 
     def __init__(self, network: 'Network'):
         path = os.path.join(get_headers_dir(network.config), 'channel_db')
-        super().__init__(network, path, Base)
+        super().__init__(network, path, Base, commit_interval=100)
         self.num_nodes = 0
         self.num_channels = 0
         self._channel_updates_for_private_channels = {}  # type: Dict[Tuple[bytes, bytes], dict]
         self.ca_verifier = LNChannelVerifier(network, self)
-        self.update_counts()
+        # initialized in load_data
+        self._channels = {}
+        self._policies = {}
+        self._nodes = {}
+        self._addresses = defaultdict(set)
+        self._channels_for_node = defaultdict(set)
 
-    @sql
     def update_counts(self):
-        self._update_counts()
-
-    def _update_counts(self):
-        self.num_channels = self.DBSession.query(ChannelInfo).count()
-        self.num_policies = self.DBSession.query(Policy).count()
-        self.num_nodes = self.DBSession.query(NodeInfo).count()
+        self.num_channels = len(self._channels)
+        self.num_policies = len(self._policies)
+        self.num_nodes = len(self._nodes)
 
-    @sql
-    def known_ids(self):
-        known = self.DBSession.query(ChannelInfo.short_channel_id).all()
-        return set(bfh(r.short_channel_id) for r in known)
+    def get_channel_ids(self):
+        return set(self._channels.keys())
-    @sql
     def add_recent_peer(self, peer: LNPeerAddr):
         now = int(time.time())
-        node_id = peer.pubkey.hex()
-        addr = self.DBSession.query(Address).filter_by(node_id=node_id, host=peer.host, port=peer.port).one_or_none()
+        node_id = peer.pubkey
+        self._addresses[node_id].add((peer.host, peer.port, now))
+        self.save_address(node_id, peer, now)
+
+    @sql
+    def save_address(self, node_id, peer, now):
+        addr = self.DBSession.query(AddressBase).filter_by(node_id=node_id, host=peer.host, port=peer.port).one_or_none()
         if addr:
             addr.last_connected_date = now
         else:
-            addr = Address(node_id=node_id, host=peer.host, port=peer.port, last_connected_date=now)
+            addr = AddressBase(node_id=node_id, host=peer.host, port=peer.port, last_connected_date=now)
             self.DBSession.add(addr)
-        self.DBSession.commit()
 
-    @sql
-    def get_200_randomly_sorted_nodes_not_in(self, node_ids_bytes):
-        unshuffled = self.DBSession \
-            .query(NodeInfo) \
-            .filter(not_(NodeInfo.node_id.in_(x.hex() for x in node_ids_bytes))) \
-            .limit(200) \
-            .all()
-        return random.sample(unshuffled, len(unshuffled))
-
-    @sql
-    def nodes_get(self, node_id):
-        return self.DBSession \
-            .query(NodeInfo) \
-            .filter_by(node_id = node_id.hex()) \
-            .one_or_none()
+    def get_200_randomly_sorted_nodes_not_in(self, node_ids):
+        unshuffled = set(self._nodes.keys()) - node_ids
+        return random.sample(unshuffled, min(200, len(unshuffled)))
 
-    @sql
     def get_last_good_address(self, node_id) -> Optional[LNPeerAddr]:
-        r = self.DBSession.query(Address).filter_by(node_id=node_id.hex()).order_by(Address.last_connected_date.desc()).all()
+        r = self._addresses.get(node_id)
         if not r:
             return None
-        addr = r[0]
-        return LNPeerAddr(addr.host, addr.port, bytes.fromhex(addr.node_id))
+        addr = sorted(list(r), key=lambda x: x[2])[0]
+        host, port, timestamp = addr
+        return LNPeerAddr(host, port, node_id)
 
-    @sql
     def get_recent_peers(self):
-        r = self.DBSession.query(Address).filter(Address.last_connected_date.isnot(None)).order_by(Address.last_connected_date.desc()).limit(self.NUM_MAX_RECENT_PEERS).all()
-        return [LNPeerAddr(x.host, x.port, bytes.fromhex(x.node_id)) for x in r]
+        r = [self.get_last_good_address(x) for x in self._addresses.keys()]
+        r = r[-self.NUM_MAX_RECENT_PEERS:]
+        return r
-    @sql
-    def missing_channel_announcements(self) -> Set[int]:
-        expr = not_(Policy.short_channel_id.in_(self.DBSession.query(ChannelInfo.short_channel_id)))
-        return set(x[0] for x in self.DBSession.query(Policy.short_channel_id).filter(expr).all())
-
-    @sql
-    def missing_channel_updates(self) -> Set[int]:
-        expr = not_(ChannelInfo.short_channel_id.in_(self.DBSession.query(Policy.short_channel_id)))
-        return set(x[0] for x in self.DBSession.query(ChannelInfo.short_channel_id).filter(expr).all())
-
-    @sql
-    def add_verified_channel_info(self, short_id, capacity):
-        # called from lnchannelverifier
-        channel_info = self.DBSession.query(ChannelInfo).filter_by(short_channel_id = short_id.hex()).one_or_none()
-        channel_info.trusted = True
-        channel_info.capacity = capacity
-        self.DBSession.commit()
-
-    @sql
-    @profiler
-    def on_channel_announcement(self, msg_payloads, trusted=True):
+    def add_channel_announcement(self, msg_payloads, trusted=True):
         if type(msg_payloads) is dict:
             msg_payloads = [msg_payloads]
-        new_channels = {}
+        added = 0
         for msg in msg_payloads:
-            short_channel_id = bh2u(msg['short_channel_id'])
-            if self.DBSession.query(ChannelInfo).filter_by(short_channel_id=short_channel_id).count():
+            short_channel_id = msg['short_channel_id']
+            if short_channel_id in self._channels:
                 continue
             if constants.net.rev_genesis_bytes() != msg['chain_hash']:
                 self.logger.info("ChanAnn has unexpected chain_hash {}".format(bh2u(msg['chain_hash'])))
@@ -306,24 +331,24 @@ class ChannelDB(SqlDB):
             except UnknownEvenFeatureBits:
                 self.logger.info("unknown feature bits")
                 continue
-            channel_info.trusted = trusted
-            new_channels[short_channel_id] = channel_info
+            #channel_info.trusted = trusted
+            added += 1
+            self._channels[short_channel_id] = channel_info
+            self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id)
+            self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id)
+            self.save_channel(channel_info)
             if not trusted:
                 self.ca_verifier.add_new_channel_info(channel_info.short_channel_id, channel_info.msg_payload)
-        for channel_info in new_channels.values():
-            self.DBSession.add(channel_info)
-        self.DBSession.commit()
-        self._update_counts()
-        self.logger.debug('on_channel_announcement: %d/%d'%(len(new_channels), len(msg_payloads)))
-
-    @sql
-    def get_last_timestamp(self):
-        return self._get_last_timestamp()
-
-    def _get_last_timestamp(self):
-        from sqlalchemy.sql import func
-        r = self.DBSession.query(func.max(Policy.timestamp).label('max_timestamp')).one()
-        return r.max_timestamp or 0
+        self.update_counts()
+        self.logger.debug('add_channel_announcement: %d/%d'%(added, len(msg_payloads)))
+
+    #def add_verified_channel_info(self, short_id, capacity):
+    #    # called from lnchannelverifier
+    #    channel_info = self.DBSession.query(ChannelInfoBase).filter_by(short_channel_id = short_id).one_or_none()
+    #    channel_info.trusted = True
+    #    channel_info.capacity = capacity
 
     def print_change(self, old_policy, new_policy):
         # print what changed between policies
@@ -340,89 +365,74 @@ class ChannelDB(SqlDB):
         if old_policy.channel_flags != new_policy.channel_flags:
             self.logger.info(f'channel_flags: {old_policy.channel_flags} -> {new_policy.channel_flags}')
 
-    @sql
-    def get_info_for_updates(self, payloads):
-        short_channel_ids = [payload['short_channel_id'].hex() for payload in payloads]
-        channel_infos_list = self.DBSession.query(ChannelInfo).filter(ChannelInfo.short_channel_id.in_(short_channel_ids)).all()
-        channel_infos = {bfh(x.short_channel_id): x for x in channel_infos_list}
-        return channel_infos
-
-    @sql
-    def get_policies_for_updates(self, payloads):
-        out = {}
-        for payload in payloads:
-            short_channel_id = payload['short_channel_id'].hex()
-            start_node = payload['start_node'].hex()
-            policy = self.DBSession.query(Policy).filter_by(short_channel_id=short_channel_id, start_node=start_node).one_or_none()
-            if policy:
-                out[short_channel_id+start_node] = policy
-        return out
-
-    @profiler
-    def filter_channel_updates(self, payloads, max_age=None):
+    def add_channel_updates(self, payloads, max_age=None, verify=True):
         orphaned = []      # no channel announcement for channel update
         expired = []       # update older than two weeks
         deprecated = []    # update older than database entry
-        good = {}          # good updates
+        good = []          # good updates
         to_delete = []     # database entries to delete
         # filter orphaned and expired first
         known = []
         now = int(time.time())
-        channel_infos = self.get_info_for_updates(payloads)
         for payload in payloads:
             short_channel_id = payload['short_channel_id']
             timestamp = int.from_bytes(payload['timestamp'], "big")
             if max_age and now - timestamp > max_age:
                 expired.append(short_channel_id)
                 continue
-            channel_info = channel_infos.get(short_channel_id)
+            channel_info = self._channels.get(short_channel_id)
             if not channel_info:
                 orphaned.append(short_channel_id)
                 continue
             flags = int.from_bytes(payload['channel_flags'], 'big')
             direction = flags & FLAG_DIRECTION
             start_node = channel_info.node1_id if direction == 0 else channel_info.node2_id
-            payload['start_node'] = bfh(start_node)
+            payload['start_node'] = start_node
             known.append(payload)
         # compare updates to existing database entries
-        old_policies = self.get_policies_for_updates(known)
         for payload in known:
             timestamp = int.from_bytes(payload['timestamp'], "big")
             start_node = payload['start_node']
             short_channel_id = payload['short_channel_id']
-            key = (short_channel_id+start_node).hex()
-            old_policy = old_policies.get(key)
-            if old_policy:
-                if timestamp <= old_policy.timestamp:
-                    deprecated.append(short_channel_id)
-                else:
-                    good[key] = payload
-                    to_delete.append(old_policy)
-            else:
-                good[key] = payload
-        good = list(good.values())
+            key = (start_node, short_channel_id)
+            old_policy = self._policies.get(key)
+            if old_policy and timestamp <= old_policy.timestamp:
+                deprecated.append(short_channel_id)
+                continue
+            good.append(payload)
+            if verify:
+                self.verify_channel_update(payload)
+            policy = Policy.from_msg(payload)
+            self._policies[key] = policy
+            self.save_policy(policy)
+        #
+        self.update_counts()
         return orphaned, expired, deprecated, good, to_delete
     def add_channel_update(self, payload):
-        orphaned, expired, deprecated, good, to_delete = self.filter_channel_updates([payload])
+        orphaned, expired, deprecated, good, to_delete = self.add_channel_updates([payload], verify=False)
         assert len(good) == 1
-        self.update_policies(good, to_delete)
 
     @sql
-    @profiler
-    def update_policies(self, to_add, to_delete):
-        for policy in to_delete:
-            self.DBSession.delete(policy)
-        self.DBSession.commit()
-        for payload in to_add:
-            policy = Policy.from_msg(payload)
-            self.DBSession.add(policy)
-        self.DBSession.commit()
-        self._update_counts()
+    def save_policy(self, policy):
+        self.DBSession.execute(PolicyBase.__table__.insert().values(policy))
 
     @sql
-    @profiler
-    def on_node_announcement(self, msg_payloads):
+    def delete_policy(self, short_channel_id, node_id):
+        self.DBSession.execute(PolicyBase.__table__.delete().values(policy))
+
+    @sql
+    def save_channel(self, channel_info):
+        self.DBSession.execute(ChannelInfoBase.__table__.insert().values(channel_info))
+
+    def verify_channel_update(self, payload):
+        short_channel_id = payload['short_channel_id']
+        if constants.net.rev_genesis_bytes() != payload['chain_hash']:
+            raise Exception('wrong chain hash')
+        if not verify_sig_for_channel_update(payload, payload['start_node']):
+            raise BaseException('verify error')
+
+    def add_node_announcement(self, msg_payloads):
         if type(msg_payloads) is dict:
             msg_payloads = [msg_payloads]
         old_addr = None
@@ -435,29 +445,35 @@ class ChannelDB(SqlDB):
                 continue
             node_id = node_info.node_id
             # Ignore node if it has no associated channel (DoS protection)
-            # FIXME this is slow
-            expr = or_(ChannelInfo.node1_id==node_id, ChannelInfo.node2_id==node_id)
-            if len(self.DBSession.query(ChannelInfo.short_channel_id).filter(expr).limit(1).all()) == 0:
+            if node_id not in self._channels_for_node:
                 #self.logger.info('ignoring orphan node_announcement')
                 continue
-            node = self.DBSession.query(NodeInfo).filter_by(node_id=node_id).one_or_none()
+            node = self._nodes.get(node_id)
             if node and node.timestamp >= node_info.timestamp:
                 continue
             node = new_nodes.get(node_id)
             if node and node.timestamp >= node_info.timestamp:
                 continue
-            new_nodes[node_id] = node_info
+            # save
+            self._nodes[node_id] = node_info
+            self.save_node(node_info)
             for addr in node_addresses:
-                new_addresses[(addr.node_id,addr.host,addr.port)] = addr
-
+                self._addresses[node_id].add((addr.host, addr.port, 0))
+            self.save_node_addresses(node_id, node_addresses)
         self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads)))
-        for node_info in new_nodes.values():
-            self.DBSession.add(node_info)
-        for new_addr in new_addresses.values():
-            old_addr = self.DBSession.query(Address).filter_by(node_id=new_addr.node_id, host=new_addr.host, port=new_addr.port).one_or_none()
+        self.update_counts()
+
+    @sql
+    def save_node_addresses(self, node_if, node_addresses):
+        for new_addr in node_addresses:
+            old_addr = self.DBSession.query(AddressBase).filter_by(node_id=new_addr.node_id, host=new_addr.host, port=new_addr.port).one_or_none()
             if not old_addr:
-                self.DBSession.add(new_addr)
-        self.DBSession.commit()
-        self._update_counts()
+                self.DBSession.execute(AddressBase.__table__.insert().values(new_addr))
+
+    @sql
+    def save_node(self, node_info):
+        self.DBSession.execute(NodeInfoBase.__table__.insert().values(node_info))
 
     def get_routing_policy_for_channel(self, start_node_id: bytes,
                                        short_channel_id: bytes) -> Optional[bytes]:
@@ -470,41 +486,28 @@ class ChannelDB(SqlDB):
             return None
         return Policy.from_msg(msg)  # won't actually be written to DB
 
-    @sql
-    @profiler
     def get_old_policies(self, delta):
-        timestamp = int(time.time()) - delta
-        old_policies = self.DBSession.query(Policy.short_channel_id).filter(Policy.timestamp <= timestamp)
-        return old_policies.distinct().count()
+        now = int(time.time())
+        return list(k for k, v in list(self._policies.items()) if v.timestamp <= now - delta)
 
-    @sql
-    @profiler
     def prune_old_policies(self, delta):
-        # note: delete queries are order sensitive
-        timestamp = int(time.time()) - delta
-        old_policies = self.DBSession.query(Policy.short_channel_id).filter(Policy.timestamp <= timestamp)
-        delete_old_channels = ChannelInfo.__table__.delete().where(ChannelInfo.short_channel_id.in_(old_policies))
-        delete_old_policies = Policy.__table__.delete().where(Policy.timestamp <= timestamp)
-        self.DBSession.execute(delete_old_channels)
-        self.DBSession.execute(delete_old_policies)
-        self.DBSession.commit()
-        self._update_counts()
+        l = self.get_old_policies(delta)
+        for k in l:
+            self._policies.pop(k)
+        if l:
+            self.logger.info(f'Deleting {len(l)} old policies')
 
-    @sql
-    @profiler
     def get_orphaned_channels(self):
-        subquery = self.DBSession.query(Policy.short_channel_id)
-        orphaned = self.DBSession.query(ChannelInfo).filter(not_(ChannelInfo.short_channel_id.in_(subquery)))
-        return orphaned.count()
+        ids = set(x[1] for x in self._policies.keys())
+        return list(x for x in self._channels.keys() if x not in ids)
 
-    @sql
-    @profiler
     def prune_orphaned_channels(self):
-        subquery = self.DBSession.query(Policy.short_channel_id)
-        delete_orphaned = ChannelInfo.__table__.delete().where(not_(ChannelInfo.short_channel_id.in_(subquery)))
-        self.DBSession.execute(delete_orphaned)
-        self.DBSession.commit()
-        self._update_counts()
+        l = self.get_orphaned_channels()
+        for k in l:
+            self._channels.pop(k)
+        self.update_counts()
+        if l:
+            self.logger.info(f'Deleting {len(l)} orphaned channels')
 
     def add_channel_update_for_private_channel(self, msg_payload: dict, start_node_id: bytes):
         if not verify_sig_for_channel_update(msg_payload, start_node_id):
@@ -513,67 +516,27 @@ class ChannelDB(SqlDB):
         msg_payload['start_node'] = start_node_id
         self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload
 
-    @sql
     def remove_channel(self, short_channel_id):
-        r = self.DBSession.query(ChannelInfo).filter_by(short_channel_id = short_channel_id.hex()).one_or_none()
-        if not r:
-            return
-        self.DBSession.delete(r)
-        self.DBSession.commit()
-
-    def print_graph(self, full_ids=False):
-        # used for debugging.
-        # FIXME there is a race here - iterables could change size from another thread
-        def other_node_id(node_id, channel_id):
-            channel_info = self.get_channel_info(channel_id)
-            if node_id == channel_info.node1_id:
-                other = channel_info.node2_id
-            else:
-                other = channel_info.node1_id
-            return other if full_ids else other[-4:]
-
-        print_msg('nodes')
-        for node in self.DBSession.query(NodeInfo).all():
-            print_msg(node)
-
-        print_msg('channels')
-        for channel_info in self.DBSession.query(ChannelInfo).all():
-            short_channel_id = channel_info.short_channel_id
-            node1 = channel_info.node1_id
-            node2 = channel_info.node2_id
-            direction1 = self.get_policy_for_node(channel_info, node1) is not None
-            direction2 = self.get_policy_for_node(channel_info, node2) is not None
-            if direction1 and direction2:
-                direction = 'both'
-            elif direction1:
-                direction = 'forward'
-            elif direction2:
-                direction = 'backward'
-            else:
-                direction = 'none'
-            print_msg('{}: {}, {}, {}'
-                      .format(bh2u(short_channel_id),
-                              bh2u(node1) if full_ids else bh2u(node1[-4:]),
-                              bh2u(node2) if full_ids else bh2u(node2[-4:]),
-                              direction))
-
-    @sql
-    def get_node_addresses(self, node_info):
-        return self.DBSession.query(Address).join(NodeInfo).filter_by(node_id = node_info.node_id).all()
+        self._channels.pop(short_channel_id, None)
+
+    def get_node_addresses(self, node_id):
+        return self._addresses.get(node_id)
 
     @sql
     @profiler
     def load_data(self):
-        r = self.DBSession.query(ChannelInfo).all()
-        self._channels = dict([(bfh(x.short_channel_id), x) for x in r])
-        r = self.DBSession.query(Policy).filter_by().all()
-        self._policies = dict([((bfh(x.start_node), bfh(x.short_channel_id)), x) for x in r])
-        self._channels_for_node = defaultdict(set)
+        for x in self.DBSession.query(AddressBase).all():
+            self._addresses[x.node_id].add((str(x.host), int(x.port), int(x.last_connected_date or 0)))
+        for x in self.DBSession.query(ChannelInfoBase).all():
+            self._channels[x.short_channel_id] = x.to_nametuple()
+        for x in self.DBSession.query(PolicyBase).filter_by().all():
+            p = x.to_nametuple()
+            self._policies[(p.start_node, p.short_channel_id)] = p
         for channel_info in self._channels.values():
             self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id)
             self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id)
         self.logger.info(f'load data {len(self._channels)} {len(self._policies)} {len(self._channels_for_node)}')
+        self.update_counts()
 
     def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes) -> Optional['Policy']:
         return self._policies.get((node_id, short_channel_id))
@@ -584,6 +547,3 @@ class ChannelDB(SqlDB):
     def get_channels_for_node(self, node_id) -> Set[bytes]:
         """Returns the set of channels that have node_id as one of the endpoints."""
         return self._channels_for_node.get(node_id) or set()
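
A worked example of the `Policy.key` packing introduced above: an 8-byte short_channel_id concatenated with the 33-byte start-node pubkey, recovered by the two properties' slices (the values below are made up):

    short_channel_id = bytes.fromhex('0001020304050607')   # 8 bytes
    start_node = b'\x02' + b'\x11' * 32                     # 33-byte compressed pubkey
    key = short_channel_id + start_node
    assert key[0:8] == short_channel_id                     # Policy.short_channel_id
    assert key[8:] == start_node                            # Policy.start_node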

electrum/gui/qt/lightning_dialog.py  (9 lines changed)

@@ -56,10 +56,11 @@ class WatcherList(MyTreeView):
             return
         self.model().clear()
         self.update_headers({0:_('Outpoint'), 1:_('Tx'), 2:_('Status')})
-        sweepstore = self.parent.lnwatcher.sweepstore
-        for outpoint in sweepstore.list_sweep_tx():
-            n = sweepstore.get_num_tx(outpoint)
-            status = self.parent.lnwatcher.get_channel_status(outpoint)
+        lnwatcher = self.parent.lnwatcher
+        l = lnwatcher.list_sweep_tx()
+        for outpoint in l:
+            n = lnwatcher.get_num_tx(outpoint)
+            status = lnwatcher.get_channel_status(outpoint)
             items = [QStandardItem(e) for e in [outpoint, "%d"%n, status]]
             self.model().insertRow(self.model().rowCount(), items)

electrum/lnaddr.py  (11 lines changed)

@@ -258,14 +258,21 @@ class LnAddr(object):
     def get_min_final_cltv_expiry(self) -> int:
         return self._min_final_cltv_expiry
 
-    def get_description(self):
+    def get_tag(self, tag):
         description = ''
         for k,v in self.tags:
-            if k == 'd':
+            if k == tag:
                 description = v
                 break
         return description
 
+    def get_description(self):
+        return self.get_tag('d')
+
+    def get_expiry(self):
+        return int(self.get_tag('x') or '3600')
+
 
 def lndecode(a, verbose=False, expected_hrp=None):
     if expected_hrp is None:

electrum/lnchannel.py  (12 lines changed)

@@ -163,8 +163,6 @@ class Channel(Logger):
         self._is_funding_txo_spent = None  # "don't know"
         self._state = None
         self.set_state('DISCONNECTED')
-        self.lnwatcher = None
-
         self.local_commitment = None
         self.remote_commitment = None
         self.sweep_info = None
@@ -453,13 +451,10 @@ class Channel(Logger):
         return secret, point
 
     def process_new_revocation_secret(self, per_commitment_secret: bytes):
-        if not self.lnwatcher:
-            return
         outpoint = self.funding_outpoint.to_str()
         ctx = self.remote_commitment_to_be_revoked  # FIXME can't we just reconstruct it?
         sweeptxs = create_sweeptxs_for_their_revoked_ctx(self, ctx, per_commitment_secret, self.sweep_address)
-        for tx in sweeptxs:
-            self.lnwatcher.add_sweep_tx(outpoint, tx.prevout(0), str(tx))
+        return sweeptxs
 
     def receive_revocation(self, revocation: RevokeAndAck):
         self.logger.info("receive_revocation")
@@ -477,9 +472,10 @@ class Channel(Logger):
         # be robust to exceptions raised in lnwatcher
         try:
-            self.process_new_revocation_secret(revocation.per_commitment_secret)
+            sweeptxs = self.process_new_revocation_secret(revocation.per_commitment_secret)
         except Exception as e:
             self.logger.info("Could not process revocation secret: {}".format(repr(e)))
+            sweeptxs = []
 
         ##### start applying fee/htlc changes
@@ -505,6 +501,8 @@ class Channel(Logger):
         self.set_remote_commitment()
         self.remote_commitment_to_be_revoked = prev_remote_commitment
+        # return sweep transactions for watchtower
+        return sweeptxs
 
     def balance(self, whose, *, ctx_owner=HTLCOwner.LOCAL, ctn=None):
         """

electrum/lnpeer.py  (87 lines changed)

@@ -42,7 +42,6 @@ from .lnutil import (Outpoint, LocalConfig, RECEIVED, UpdateAddHtlc,
                      MAXIMUM_REMOTE_TO_SELF_DELAY_ACCEPTED, RemoteMisbehaving, DEFAULT_TO_SELF_DELAY)
 from .lntransport import LNTransport, LNTransportBase
 from .lnmsg import encode_msg, decode_msg
-from .lnverifier import verify_sig_for_channel_update
 from .interface import GracefulDisconnect
 
 if TYPE_CHECKING:
@@ -242,22 +241,20 @@ class Peer(Logger):
         # channel announcements
         for chan_anns_chunk in chunks(chan_anns, 300):
             self.verify_channel_announcements(chan_anns_chunk)
-            self.channel_db.on_channel_announcement(chan_anns_chunk)
+            self.channel_db.add_channel_announcement(chan_anns_chunk)
         # node announcements
         for node_anns_chunk in chunks(node_anns, 100):
             self.verify_node_announcements(node_anns_chunk)
-            self.channel_db.on_node_announcement(node_anns_chunk)
+            self.channel_db.add_node_announcement(node_anns_chunk)
         # channel updates
         for chan_upds_chunk in chunks(chan_upds, 1000):
-            orphaned, expired, deprecated, good, to_delete = self.channel_db.filter_channel_updates(chan_upds_chunk,
-                                                                                                    max_age=self.network.lngossip.max_age)
+            orphaned, expired, deprecated, good, to_delete = self.channel_db.add_channel_updates(
+                chan_upds_chunk, max_age=self.network.lngossip.max_age)
             if orphaned:
                 self.logger.info(f'adding {len(orphaned)} unknown channel ids')
-                self.network.lngossip.add_new_ids(orphaned)
+                await self.network.lngossip.add_new_ids(orphaned)
             if good:
                 self.logger.debug(f'on_channel_update: {len(good)}/{len(chan_upds_chunk)}')
-                self.verify_channel_updates(good)
-                self.channel_db.update_policies(good, to_delete)
         # refresh gui
         if chan_anns or node_anns or chan_upds:
             self.network.lngossip.refresh_gui()
@@ -279,14 +276,6 @@ class Peer(Logger):
         if not ecc.verify_signature(pubkey, signature, h):
             raise Exception('signature failed')
 
-    def verify_channel_updates(self, chan_upds):
-        for payload in chan_upds:
-            short_channel_id = payload['short_channel_id']
-            if constants.net.rev_genesis_bytes() != payload['chain_hash']:
-                raise Exception('wrong chain hash')
-            if not verify_sig_for_channel_update(payload, payload['start_node']):
-                raise BaseException('verify error')
-
     async def query_gossip(self):
         try:
             await asyncio.wait_for(self.initialized.wait(), 10)
@@ -298,7 +287,7 @@ class Peer(Logger):
         except asyncio.TimeoutError as e:
             raise GracefulDisconnect("query_channel_range timed out") from e
         self.logger.info('Received {} channel ids. (complete: {})'.format(len(ids), complete))
-        self.lnworker.add_new_ids(ids)
+        await self.lnworker.add_new_ids(ids)
         while True:
             todo = self.lnworker.get_ids_to_query()
             if not todo:
@@ -658,7 +647,7 @@ class Peer(Logger):
         )
         chan.open_with_first_pcp(payload['first_per_commitment_point'], remote_sig)
         self.lnworker.save_channel(chan)
-        self.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address())
+        await self.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address())
         self.lnworker.on_channels_updated()
         while True:
             try:
@@ -862,8 +851,6 @@ class Peer(Logger):
             bitcoin_key_2=bitcoin_keys[1]
         )
 
-        print("SENT CHANNEL ANNOUNCEMENT")
-
     def mark_open(self, chan: Channel):
         assert chan.short_channel_id is not None
         if chan.get_state() == "OPEN":
@@ -872,6 +859,10 @@ class Peer(Logger):
         assert chan.config[LOCAL].funding_locked_received
         chan.set_state("OPEN")
         self.network.trigger_callback('channel', chan)
+        asyncio.ensure_future(self.add_own_channel(chan))
+        self.logger.info("CHANNEL OPENING COMPLETED")
+
+    async def add_own_channel(self, chan):
         # add channel to database
         bitcoin_keys = [chan.config[LOCAL].multisig_key.pubkey, chan.config[REMOTE].multisig_key.pubkey]
         sorted_node_ids = list(sorted(self.node_ids))
@@ -887,7 +878,7 @@ class Peer(Logger):
         # that the remote sends, even if the channel was not announced
         # (from BOLT-07: "MAY create a channel_update to communicate the channel
         # parameters to the final node, even though the channel has not yet been announced")
-        self.channel_db.on_channel_announcement(
+        self.channel_db.add_channel_announcement(
             {
                 "short_channel_id": chan.short_channel_id,
                 "node_id_1": node_ids[0],
@@ -922,8 +913,6 @@ class Peer(Logger):
         if pending_channel_update:
             self.channel_db.add_channel_update(pending_channel_update)
 
-        self.logger.info("CHANNEL OPENING COMPLETED")
-
     def send_announcement_signatures(self, chan: Channel):
         bitcoin_keys = [chan.config[REMOTE].multisig_key.pubkey,
@@ -962,36 +951,34 @@ class Peer(Logger):
     def on_update_fail_htlc(self, payload):
         channel_id = payload["channel_id"]
         htlc_id = int.from_bytes(payload["id"], "big")
-        key = (channel_id, htlc_id)
-        try:
-            route = self.attempted_route[key]
-        except KeyError:
-            # the remote might try to fail an htlc after we restarted...
-            # attempted_route is not persisted, so we will get here then
-            self.logger.info("UPDATE_FAIL_HTLC. cannot decode! attempted route is MISSING. {}".format(key))
-        else:
-            try:
-                self._handle_error_code_from_failed_htlc(payload["reason"], route, channel_id, htlc_id)
-            except Exception:
-                # exceptions are suppressed as failing to handle an error code
-                # should not block us from removing the htlc
-                traceback.print_exc(file=sys.stderr)
-        # process update_fail_htlc on channel
         chan = self.channels[channel_id]
         chan.receive_fail_htlc(htlc_id)
         local_ctn = chan.get_current_ctn(LOCAL)
-        asyncio.ensure_future(self._on_update_fail_htlc(chan, htlc_id, local_ctn))
+        asyncio.ensure_future(self._handle_error_code_from_failed_htlc(payload, channel_id, htlc_id))
+        asyncio.ensure_future(self._on_update_fail_htlc(channel_id, htlc_id, local_ctn))
 
     @log_exceptions
-    async def _on_update_fail_htlc(self, chan, htlc_id, local_ctn):
+    async def _on_update_fail_htlc(self, channel_id, htlc_id, local_ctn):
+        chan = self.channels[channel_id]
         await self.await_local(chan, local_ctn)
         self.lnworker.pending_payments[(chan.short_channel_id, htlc_id)].set_result(False)
 
-    def _handle_error_code_from_failed_htlc(self, error_reason, route: List['RouteEdge'], channel_id, htlc_id):
+    @log_exceptions
+    async def _handle_error_code_from_failed_htlc(self, payload, channel_id, htlc_id):
         chan = self.channels[channel_id]
-        failure_msg, sender_idx = decode_onion_error(error_reason,
-                                                     [x.node_id for x in route],
-                                                     chan.onion_keys[htlc_id])
+        key = (channel_id, htlc_id)
+        try:
+            route = self.attempted_route[key]
+        except KeyError:
+            # the remote might try to fail an htlc after we restarted...
+            # attempted_route is not persisted, so we will get here then
+            self.logger.info("UPDATE_FAIL_HTLC. cannot decode! attempted route is MISSING. {}".format(key))
+            return
+        error_reason = payload["reason"]
+        failure_msg, sender_idx = decode_onion_error(
+            error_reason,
+            [x.node_id for x in route],
+            chan.onion_keys[htlc_id])
         code, data = failure_msg.code, failure_msg.data
         self.logger.info(f"UPDATE_FAIL_HTLC {repr(code)} {data}")
         self.logger.info(f"error reported by {bh2u(route[sender_idx].node_id)}")
@@ -1009,11 +996,9 @@ class Peer(Logger):
             channel_update = (258).to_bytes(length=2, byteorder="big") + data[offset:]
             message_type, payload = decode_msg(channel_update)
             payload['raw'] = channel_update
-            orphaned, expired, deprecated, good, to_delete = self.channel_db.filter_channel_updates([payload])
+            orphaned, expired, deprecated, good, to_delete = self.channel_db.add_channel_updates([payload])
             blacklist = False
             if good:
-                self.verify_channel_updates(good)
-                self.channel_db.update_policies(good, to_delete)
                 self.logger.info("applied channel update on our db")
             elif orphaned:
                 # maybe it is a private channel (and data in invoice was outdated)
@@ -1276,11 +1261,17 @@ class Peer(Logger):
         self.logger.info("on_revoke_and_ack")
         channel_id = payload["channel_id"]
         chan = self.channels[channel_id]
-        chan.receive_revocation(RevokeAndAck(payload["per_commitment_secret"], payload["next_per_commitment_point"]))
+        sweeptxs = chan.receive_revocation(RevokeAndAck(payload["per_commitment_secret"], payload["next_per_commitment_point"]))
         self._remote_changed_events[chan.channel_id].set()
         self._remote_changed_events[chan.channel_id].clear()
         self.lnworker.save_channel(chan)
         self.maybe_send_commitment(chan)
+        asyncio.ensure_future(self._on_revoke_and_ack(chan, sweeptxs))
+
+    async def _on_revoke_and_ack(self, chan, sweeptxs):
+        outpoint = chan.funding_outpoint.to_str()
+        for tx in sweeptxs:
+            await self.lnwatcher.add_sweep_tx(outpoint, tx.prevout(0), str(tx))
 
     def on_update_fee(self, payload):
         channel_id = payload["channel_id"]

electrum/lnrouter.py  (19 lines changed)

@@ -37,7 +37,7 @@ import binascii
 import base64
 
 from . import constants
-from .util import bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits, print_msg, chunks
+from .util import bh2u, profiler, get_headers_dir, is_ip_address, list_enabled_bits, print_msg, chunks
 from .logging import Logger
 from .storage import JsonDB
 from .lnverifier import LNChannelVerifier, verify_sig_for_channel_update
@@ -169,7 +169,6 @@ class LNPathFinder(Logger):
         To get from node ret[n][0] to ret[n+1][0], use channel ret[n+1][1];
         i.e. an element reads as, "to get to node_id, travel through short_channel_id"
         """
-        self.channel_db.load_data()
         assert type(nodeA) is bytes
         assert type(nodeB) is bytes
         assert type(invoice_amount_msat) is int
@@ -195,11 +194,12 @@ class LNPathFinder(Logger):
             else:  # payment incoming, on our channel. (funny business, cycle weirdness)
                 assert edge_endnode == nodeA, (bh2u(edge_startnode), bh2u(edge_endnode))
                 pass  # TODO?
-            edge_cost, fee_for_edge_msat = self._edge_cost(edge_channel_id,
-                                                           start_node=edge_startnode,
-                                                           end_node=edge_endnode,
-                                                           payment_amt_msat=amount_msat,
-                                                           ignore_costs=(edge_startnode == nodeA))
+            edge_cost, fee_for_edge_msat = self._edge_cost(
+                edge_channel_id,
+                start_node=edge_startnode,
+                end_node=edge_endnode,
+                payment_amt_msat=amount_msat,
+                ignore_costs=(edge_startnode == nodeA))
             alt_dist_to_neighbour = distance_from_start[edge_endnode] + edge_cost
             if alt_dist_to_neighbour < distance_from_start[edge_startnode]:
                 distance_from_start[edge_startnode] = alt_dist_to_neighbour
@@ -219,9 +219,10 @@ class LNPathFinder(Logger):
                 continue
             for edge_channel_id in self.channel_db.get_channels_for_node(edge_endnode):
                 assert type(edge_channel_id) is bytes
-                if edge_channel_id in self.blacklist: continue
+                if edge_channel_id in self.blacklist:
+                    continue
                 channel_info = self.channel_db.get_channel_info(edge_channel_id)
-                edge_startnode = bfh(channel_info.node2_id) if bfh(channel_info.node1_id) == edge_endnode else bfh(channel_info.node1_id)
+                edge_startnode = channel_info.node2_id if channel_info.node1_id == edge_endnode else channel_info.node1_id
                 inspect_edge()
         else:
             return None  # no path found

electrum/lnwatcher.py  (49 lines changed)

@@ -70,11 +70,11 @@ class SweepStore(SqlDB):
     @sql
     def get_tx_by_index(self, funding_outpoint, index):
         r = self.DBSession.query(SweepTx).filter(SweepTx.funding_outpoint==funding_outpoint, SweepTx.index==index).one_or_none()
-        return r.prevout, bh2u(r.tx)
+        return str(r.prevout), bh2u(r.tx)
 
     @sql
     def list_sweep_tx(self):
-        return set(r.funding_outpoint for r in self.DBSession.query(SweepTx).all())
+        return set(str(r.funding_outpoint) for r in self.DBSession.query(SweepTx).all())
 
     @sql
     def add_sweep_tx(self, funding_outpoint, prevout, tx):
@@ -84,7 +84,7 @@ class SweepStore(SqlDB):
     @sql
     def get_num_tx(self, funding_outpoint):
-        return self.DBSession.query(SweepTx).filter(funding_outpoint==funding_outpoint).count()
+        return int(self.DBSession.query(SweepTx).filter(funding_outpoint==funding_outpoint).count())
 
     @sql
     def remove_sweep_tx(self, funding_outpoint):
@@ -111,11 +111,11 @@ class SweepStore(SqlDB):
     @sql
     def get_address(self, outpoint):
         r = self.DBSession.query(ChannelInfo).filter(ChannelInfo.outpoint==outpoint).one_or_none()
-        return r.address if r else None
+        return str(r.address) if r else None
 
     @sql
     def list_channel_info(self):
-        return [(r.address, r.outpoint) for r in self.DBSession.query(ChannelInfo).all()]
+        return [(str(r.address), str(r.outpoint)) for r in self.DBSession.query(ChannelInfo).all()]
 
 
 class LNWatcher(AddressSynchronizer):
@@ -150,14 +150,21 @@ class LNWatcher(AddressSynchronizer):
         self.watchtower_queue = asyncio.Queue()
 
     def get_num_tx(self, outpoint):
-        return self.sweepstore.get_num_tx(outpoint)
+        async def f():
+            return await self.sweepstore.get_num_tx(outpoint)
+        return self.network.run_from_another_thread(f())
+
+    def list_sweep_tx(self):
+        async def f():
+            return await self.sweepstore.list_sweep_tx()
+        return self.network.run_from_another_thread(f())
 
     @ignore_exceptions
     @log_exceptions
     async def watchtower_task(self):
         self.logger.info('watchtower task started')
         # initial check
-        for address, outpoint in self.sweepstore.list_channel_info():
+        for address, outpoint in await self.sweepstore.list_channel_info():
             await self.watchtower_queue.put(outpoint)
         while True:
             outpoint = await self.watchtower_queue.get()
@@ -165,30 +172,30 @@ class LNWatcher(AddressSynchronizer):
                 continue
             # synchronize with remote
             try:
-                local_n = self.sweepstore.get_num_tx(outpoint)
+                local_n = await self.sweepstore.get_num_tx(outpoint)
                 n = self.watchtower.get_num_tx(outpoint)
                 if n == 0:
-                    address = self.sweepstore.get_address(outpoint)
+                    address = await self.sweepstore.get_address(outpoint)
                     self.watchtower.add_channel(outpoint, address)
                 self.logger.info("sending %d transactions to watchtower"%(local_n - n))
                 for index in range(n, local_n):
-                    prevout, tx = self.sweepstore.get_tx_by_index(outpoint, index)
+                    prevout, tx = await self.sweepstore.get_tx_by_index(outpoint, index)
                     self.watchtower.add_sweep_tx(outpoint, prevout, tx)
             except ConnectionRefusedError:
                 self.logger.info('could not reach watchtower, will retry in 5s')
                 await asyncio.sleep(5)
                 await self.watchtower_queue.put(outpoint)
 
-    def add_channel(self, outpoint, address):
+    async def add_channel(self, outpoint, address):
         self.add_address(address)
         with self.lock:
-            if not self.sweepstore.has_channel(outpoint):
-                self.sweepstore.add_channel(outpoint, address)
+            if not await self.sweepstore.has_channel(outpoint):
+                await self.sweepstore.add_channel(outpoint, address)
 
-    def unwatch_channel(self, address, funding_outpoint):
+    async def unwatch_channel(self, address, funding_outpoint):
         self.logger.info(f'unwatching {funding_outpoint}')
-        self.sweepstore.remove_sweep_tx(funding_outpoint)
-        self.sweepstore.remove_channel(funding_outpoint)
+        await self.sweepstore.remove_sweep_tx(funding_outpoint)
+        await self.sweepstore.remove_channel(funding_outpoint)
         if funding_outpoint in self.tx_progress:
             self.tx_progress[funding_outpoint].all_done.set()
@@ -202,7 +209,7 @@ class LNWatcher(AddressSynchronizer):
             return
         if not self.synchronizer.is_up_to_date():
             return
-        for address, outpoint in self.sweepstore.list_channel_info():
+        for address, outpoint in await self.sweepstore.list_channel_info():
             await self.check_onchain_situation(address, outpoint)
 
     async def check_onchain_situation(self, address, funding_outpoint):
@@ -223,7 +230,7 @@ class LNWatcher(AddressSynchronizer):
                                         closing_height, closing_tx)  # FIXME sooo many args..
         await self.do_breach_remedy(funding_outpoint, spenders)
         if not keep_watching:
-            self.unwatch_channel(address, funding_outpoint)
+            await self.unwatch_channel(address, funding_outpoint)
         else:
             #self.logger.info(f'we will keep_watching {funding_outpoint}')
             pass
@ -260,7 +267,7 @@ class LNWatcher(AddressSynchronizer):
for prevout, spender in spenders.items(): for prevout, spender in spenders.items():
if spender is not None: if spender is not None:
continue continue
sweep_txns = self.sweepstore.get_sweep_tx(funding_outpoint, prevout) sweep_txns = await self.sweepstore.get_sweep_tx(funding_outpoint, prevout)
for tx in sweep_txns: for tx in sweep_txns:
if not await self.broadcast_or_log(funding_outpoint, tx): if not await self.broadcast_or_log(funding_outpoint, tx):
self.logger.info(f'{tx.name} could not publish tx: {str(tx)}, prevout: {prevout}') self.logger.info(f'{tx.name} could not publish tx: {str(tx)}, prevout: {prevout}')
@ -279,8 +286,8 @@ class LNWatcher(AddressSynchronizer):
await self.tx_progress[funding_outpoint].tx_queue.put(tx) await self.tx_progress[funding_outpoint].tx_queue.put(tx)
return txid return txid
def add_sweep_tx(self, funding_outpoint: str, prevout: str, tx: str): async def add_sweep_tx(self, funding_outpoint: str, prevout: str, tx: str):
self.sweepstore.add_sweep_tx(funding_outpoint, prevout, tx) await self.sweepstore.add_sweep_tx(funding_outpoint, prevout, tx)
if self.watchtower: if self.watchtower:
self.watchtower_queue.put_nowait(funding_outpoint) self.watchtower_queue.put_nowait(funding_outpoint)
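The synchronisation step in the first hunk is the part worth seeing in isolation: the remote watchtower holds n sweep transactions for an outpoint, the local store holds local_n, and the missing range [n, local_n) is pushed across. A minimal runnable sketch of that step, using hypothetical in-memory stand-ins (SweepStoreStub, WatchtowerStub) rather than electrum's real sweepstore and watchtower:

import asyncio

class SweepStoreStub:
    """Toy async store; the real sweepstore methods are sql-decorated."""
    def __init__(self, txs):
        self._txs = txs  # outpoint -> [(prevout, tx), ...]
    async def get_num_tx(self, outpoint):
        return len(self._txs.get(outpoint, []))
    async def get_tx_by_index(self, outpoint, index):
        return self._txs[outpoint][index]

class WatchtowerStub:
    """Toy synchronous remote."""
    def __init__(self):
        self._txs = {}
    def get_num_tx(self, outpoint):
        return len(self._txs.get(outpoint, []))
    def add_sweep_tx(self, outpoint, prevout, tx):
        self._txs.setdefault(outpoint, []).append((prevout, tx))

async def sync_outpoint(store, tower, outpoint):
    # push across only the sweep txs the remote side is missing
    local_n = await store.get_num_tx(outpoint)
    n = tower.get_num_tx(outpoint)
    for index in range(n, local_n):
        prevout, tx = await store.get_tx_by_index(outpoint, index)
        tower.add_sweep_tx(outpoint, prevout, tx)

store = SweepStoreStub({'funding:0': [('prevout1', 'rawtx1'), ('prevout2', 'rawtx2')]})
tower = WatchtowerStub()
asyncio.run(sync_outpoint(store, tower, 'funding:0'))
assert tower.get_num_tx('funding:0') == 2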

131
electrum/lnworker.py

@@ -108,12 +108,14 @@ class LNWorker(Logger):
     @log_exceptions
     async def main_loop(self):
+        # fixme: only lngossip should do that
+        await self.channel_db.load_data()
         while True:
             await asyncio.sleep(1)
             now = time.time()
             if len(self.peers) >= NUM_PEERS_TARGET:
                 continue
-            peers = self._get_next_peers_to_try()
+            peers = await self._get_next_peers_to_try()
             for peer in peers:
                 last_tried = self._last_tried_peer.get(peer, 0)
                 if last_tried + PEER_RETRY_INTERVAL < now:
@@ -130,7 +132,8 @@ class LNWorker(Logger):
         peer = Peer(self, node_id, transport)
         await self.network.main_taskgroup.spawn(peer.main_loop())
         self.peers[node_id] = peer
-        self.network.lngossip.refresh_gui()
+        #if self.network.lngossip:
+        #    self.network.lngossip.refresh_gui()
         return peer

     def start_network(self, network: 'Network'):
@@ -148,7 +151,7 @@ class LNWorker(Logger):
             self._add_peer(host, int(port), bfh(pubkey)),
             self.network.asyncio_loop)

-    def _get_next_peers_to_try(self) -> Sequence[LNPeerAddr]:
+    async def _get_next_peers_to_try(self) -> Sequence[LNPeerAddr]:
         now = time.time()
         recent_peers = self.channel_db.get_recent_peers()
         # maintenance for last tried times
@@ -158,19 +161,22 @@ class LNWorker(Logger):
             del self._last_tried_peer[peer]
         # first try from recent peers
         for peer in recent_peers:
-            if peer.pubkey in self.peers: continue
-            if peer in self._last_tried_peer: continue
+            if peer.pubkey in self.peers:
+                continue
+            if peer in self._last_tried_peer:
+                continue
             return [peer]
         # try random peer from graph
         unconnected_nodes = self.channel_db.get_200_randomly_sorted_nodes_not_in(self.peers.keys())
         if unconnected_nodes:
-            for node in unconnected_nodes:
-                addrs = self.channel_db.get_node_addresses(node)
+            for node_id in unconnected_nodes:
+                addrs = self.channel_db.get_node_addresses(node_id)
                 if not addrs:
                     continue
-                host, port = self.choose_preferred_address(addrs)
-                peer = LNPeerAddr(host, port, bytes.fromhex(node.node_id))
-                if peer in self._last_tried_peer: continue
+                host, port, timestamp = self.choose_preferred_address(addrs)
+                peer = LNPeerAddr(host, port, node_id)
+                if peer in self._last_tried_peer:
+                    continue
                 #self.logger.info('taking random ln peer from our channel db')
                 return [peer]
@@ -223,15 +229,13 @@ class LNWorker(Logger):
     def choose_preferred_address(addr_list: List[Tuple[str, int]]) -> Tuple[str, int]:
         assert len(addr_list) >= 1
         # choose first one that is an IP
-        for addr_in_db in addr_list:
-            host = addr_in_db.host
-            port = addr_in_db.port
+        for host, port, timestamp in addr_list:
             if is_ip_address(host):
-                return host, port
+                return host, port, timestamp
         # otherwise choose one at random
         # TODO maybe filter out onion if not on tor?
         choice = random.choice(addr_list)
-        return choice.host, choice.port
+        return choice

 class LNGossip(LNWorker):
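With the hunk above, node addresses are handled as plain (host, port, timestamp) tuples instead of ORM row objects with .host/.port attributes. A self-contained toy version of the selection logic; is_ip_address here is a stand-in for electrum's helper, and the sample data is made up:

import ipaddress
import random

def is_ip_address(host):
    # toy stand-in for electrum.util.is_ip_address
    try:
        ipaddress.ip_address(host)
        return True
    except ValueError:
        return False

def choose_preferred_address(addr_list):
    # prefer the first entry whose host is a literal IP address,
    # otherwise fall back to a random entry
    for host, port, timestamp in addr_list:
        if is_ip_address(host):
            return host, port, timestamp
    return random.choice(addr_list)

addrs = [('node.example.org', 9735, 1554000000),
         ('203.0.113.7', 9735, 1554100000)]
print(choose_preferred_address(addrs))  # -> ('203.0.113.7', 9735, 1554100000)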
@@ -260,26 +264,19 @@ class LNGossip(LNWorker):
         self.network.trigger_callback('ln_status', num_peers, num_nodes, known, unknown)

     async def maintain_db(self):
-        n = self.channel_db.get_orphaned_channels()
-        if n:
-            self.logger.info(f'Deleting {n} orphaned channels')
-            self.channel_db.prune_orphaned_channels()
-            self.refresh_gui()
+        self.channel_db.prune_orphaned_channels()
         while True:
-            n = self.channel_db.get_old_policies(self.max_age)
-            if n:
-                self.logger.info(f'Deleting {n} old channels')
-                self.channel_db.prune_old_policies(self.max_age)
-                self.refresh_gui()
+            self.channel_db.prune_old_policies(self.max_age)
+            self.refresh_gui()
             await asyncio.sleep(5)

-    def add_new_ids(self, ids):
-        known = self.channel_db.known_ids()
+    async def add_new_ids(self, ids):
+        known = self.channel_db.get_channel_ids()
         new = set(ids) - set(known)
         self.unknown_ids.update(new)

     def get_ids_to_query(self):
-        N = 500
+        N = 100
         l = list(self.unknown_ids)
         self.unknown_ids = set(l[N:])
         return l[0:N]
@@ -324,9 +321,10 @@ class LNWallet(LNWorker):
         self.network.register_callback(self.on_network_update, ['wallet_updated', 'network_updated', 'verified', 'fee'])  # thread safe
         self.network.register_callback(self.on_channel_open, ['channel_open'])
         self.network.register_callback(self.on_channel_closed, ['channel_closed'])
         for chan_id, chan in self.channels.items():
-            self.network.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address())
+            self.network.lnwatcher.add_address(chan.get_funding_address())
+            chan.lnwatcher = network.lnwatcher
         super().start_network(network)
         for coro in [
             self.maybe_listen(),
@@ -494,7 +492,7 @@ class LNWallet(LNWorker):
         chan = self.channel_by_txo(funding_outpoint)
         if not chan:
             return
-        self.logger.debug(f'on_channel_open {funding_outpoint}')
+        #self.logger.debug(f'on_channel_open {funding_outpoint}')
         self.channel_timestamps[bh2u(chan.channel_id)] = funding_txid, funding_height.height, funding_height.timestamp, None, None, None
         self.storage.put('lightning_channel_timestamps', self.channel_timestamps)
         chan.set_funding_txo_spentness(False)
@@ -606,7 +604,8 @@ class LNWallet(LNWorker):
             self.logger.info('REBROADCASTING CLOSING TX')
             await self.force_close_channel(chan.channel_id)

-    async def _open_channel_coroutine(self, peer, local_amount_sat, push_sat, password):
+    async def _open_channel_coroutine(self, connect_str, local_amount_sat, push_sat, password):
+        peer = await self.add_peer(connect_str)
         # peer might just have been connected to
         await asyncio.wait_for(peer.initialized.wait(), 5)
         chan = await peer.channel_establishment_flow(
@@ -615,24 +614,22 @@ class LNWallet(LNWorker):
             push_msat=push_sat * 1000,
             temp_channel_id=os.urandom(32))
         self.save_channel(chan)
-        self.network.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address())
+        self.network.lnwatcher.add_address(chan.get_funding_address())
+        await self.network.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address())
         self.on_channels_updated()
         return chan

     def on_channels_updated(self):
         self.network.trigger_callback('channels')

-    def add_peer(self, connect_str, timeout=20):
+    async def add_peer(self, connect_str, timeout=20):
         node_id, rest = extract_nodeid(connect_str)
         peer = self.peers.get(node_id)
         if not peer:
             if rest is not None:
                 host, port = split_host_port(rest)
             else:
-                node_info = self.network.channel_db.nodes_get(node_id)
-                if not node_info:
-                    raise ConnStringFormatError(_('Unknown node:') + ' ' + bh2u(node_id))
-                addrs = self.channel_db.get_node_addresses(node_info)
+                addrs = self.channel_db.get_node_addresses(node_id)
                 if len(addrs) == 0:
                     raise ConnStringFormatError(_('Don\'t know any addresses for node:') + ' ' + bh2u(node_id))
                 host, port = self.choose_preferred_address(addrs)
@@ -640,18 +637,12 @@ class LNWallet(LNWorker):
                 socket.getaddrinfo(host, int(port))
             except socket.gaierror:
                 raise ConnStringFormatError(_('Hostname does not resolve (getaddrinfo failed)'))
-            peer_future = asyncio.run_coroutine_threadsafe(
-                self._add_peer(host, port, node_id),
-                self.network.asyncio_loop)
-            try:
-                peer = peer_future.result(timeout)
-            except concurrent.futures.TimeoutError:
-                raise Exception(_("add_peer timed out"))
+            # add peer
+            peer = await self._add_peer(host, port, node_id)
         return peer

     def open_channel(self, connect_str, local_amt_sat, push_amt_sat, password=None, timeout=20):
-        peer = self.add_peer(connect_str, timeout)
-        coro = self._open_channel_coroutine(peer, local_amt_sat, push_amt_sat, password)
+        coro = self._open_channel_coroutine(connect_str, local_amt_sat, push_amt_sat, password)
         fut = asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop)
         try:
             chan = fut.result(timeout=timeout)
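open_channel stays callable from non-asyncio threads (e.g. the Qt GUI): the coroutine is scheduled onto the network's event loop with run_coroutine_threadsafe and the caller blocks on the returned future. A minimal sketch of that bridge pattern; the loop setup and names here are illustrative, not electrum's:

import asyncio
import concurrent.futures
import threading

# a background thread runs the event loop, as the network thread does in electrum
loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()

async def _open_channel_coroutine(connect_str):
    await asyncio.sleep(0.1)  # stand-in for handshake and funding flow
    return 'channel with ' + connect_str

def open_channel(connect_str, timeout=20):
    # callable from any thread; blocks until the coroutine finishes or times out
    coro = _open_channel_coroutine(connect_str)
    fut = asyncio.run_coroutine_threadsafe(coro, loop)
    try:
        return fut.result(timeout=timeout)
    except concurrent.futures.TimeoutError:
        raise Exception('open_channel timed out')

print(open_channel('02abcd...@203.0.113.7:9735'))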
@@ -664,6 +655,9 @@ class LNWallet(LNWorker):
         Can be called from other threads
         Raises timeout exception if htlc is not fulfilled
         """
+        addr = self._check_invoice(invoice, amount_sat)
+        self.save_invoice(addr.paymenthash, invoice, SENT, is_paid=False)
+        self.wallet.set_label(bh2u(addr.paymenthash), addr.get_description())
         fut = asyncio.run_coroutine_threadsafe(
             self._pay(invoice, attempts, amount_sat),
             self.network.asyncio_loop)
@@ -680,8 +674,6 @@ class LNWallet(LNWorker):

     async def _pay(self, invoice, attempts=1, amount_sat=None):
         addr = self._check_invoice(invoice, amount_sat)
-        self.save_invoice(addr.paymenthash, invoice, SENT, is_paid=False)
-        self.wallet.set_label(bh2u(addr.paymenthash), addr.get_description())
         for i in range(attempts):
             route = await self._create_route_from_invoice(decoded_invoice=addr)
             if not self.get_channel_by_short_id(route[0].short_channel_id):
@@ -691,7 +683,7 @@ class LNWallet(LNWorker):
                 return True
         return False

-    async def _pay_to_route(self, route, addr, pay_req):
+    async def _pay_to_route(self, route, addr, invoice):
         short_channel_id = route[0].short_channel_id
         chan = self.get_channel_by_short_id(short_channel_id)
         if not chan:
@@ -713,6 +705,9 @@ class LNWallet(LNWorker):
             raise InvoiceError("{}\n{}".format(
                 _("Invoice wants us to risk locking funds for unreasonably long."),
                 f"min_final_cltv_expiry: {addr.get_min_final_cltv_expiry()}"))
+        #now = int(time.time())
+        #if addr.date + addr.get_expiry() > now:
+        #    raise InvoiceError(_('Invoice expired'))
         return addr

     async def _create_route_from_invoice(self, decoded_invoice) -> List[RouteEdge]:
@@ -730,11 +725,14 @@ class LNWallet(LNWorker):
         with self.lock:
             channels = list(self.channels.values())
         for private_route in r_tags:
-            if len(private_route) == 0: continue
-            if len(private_route) > NUM_MAX_EDGES_IN_PAYMENT_PATH: continue
+            if len(private_route) == 0:
+                continue
+            if len(private_route) > NUM_MAX_EDGES_IN_PAYMENT_PATH:
+                continue
             border_node_pubkey = private_route[0][0]
             path = self.network.path_finder.find_path_for_payment(self.node_keypair.pubkey, border_node_pubkey, amount_msat, channels)
-            if not path: continue
+            if not path:
+                continue
             route = self.network.path_finder.create_route_from_path(path, self.node_keypair.pubkey)
             # we need to shift the node pubkey by one towards the destination:
             private_route_nodes = [edge[0] for edge in private_route][1:] + [invoice_pubkey]
@@ -770,10 +768,18 @@ class LNWallet(LNWorker):
         return route

     def add_invoice(self, amount_sat, message):
+        coro = self._add_invoice_coro(amount_sat, message)
+        fut = asyncio.run_coroutine_threadsafe(coro, self.network.asyncio_loop)
+        try:
+            return fut.result(timeout=5)
+        except concurrent.futures.TimeoutError:
+            raise Exception(_("add_invoice timed out"))
+
+    async def _add_invoice_coro(self, amount_sat, message):
         payment_preimage = os.urandom(32)
         payment_hash = sha256(payment_preimage)
         amount_btc = amount_sat/Decimal(COIN) if amount_sat else None
-        routing_hints = self._calc_routing_hints_for_invoice(amount_sat)
+        routing_hints = await self._calc_routing_hints_for_invoice(amount_sat)
         if not routing_hints:
             self.logger.info("Warning. No routing hints added to invoice. "
                              "Other clients will likely not be able to send to us.")
@@ -847,19 +853,20 @@ class LNWallet(LNWorker):
             })
         return out

-    def _calc_routing_hints_for_invoice(self, amount_sat):
+    async def _calc_routing_hints_for_invoice(self, amount_sat):
         """calculate routing hints (BOLT-11 'r' field)"""
-        self.channel_db.load_data()
         routing_hints = []
         with self.lock:
             channels = list(self.channels.values())
         # note: currently we add *all* our channels; but this might be a privacy leak?
         for chan in channels:
             # check channel is open
-            if chan.get_state() != "OPEN": continue
+            if chan.get_state() != "OPEN":
+                continue
             # check channel has sufficient balance
             # FIXME because of on-chain fees of ctx, this check is insufficient
-            if amount_sat and chan.balance(REMOTE) // 1000 < amount_sat: continue
+            if amount_sat and chan.balance(REMOTE) // 1000 < amount_sat:
+                continue
             chan_id = chan.short_channel_id
             assert type(chan_id) is bytes, chan_id
             channel_info = self.channel_db.get_channel_info(chan_id)
@@ -949,14 +956,10 @@ class LNWallet(LNWorker):
             await self._add_peer(peer.host, peer.port, peer.pubkey)
             return
         # try random address for node_id
-        node_info = self.channel_db.nodes_get(chan.node_id)
-        if not node_info:
-            return
-        addresses = self.channel_db.get_node_addresses(node_info)
+        addresses = self.channel_db.get_node_addresses(chan.node_id)
         if not addresses:
             return
-        adr_obj = random.choice(addresses)
-        host, port = adr_obj.host, adr_obj.port
+        host, port, t = random.choice(list(addresses))
         peer = LNPeerAddr(host, port, chan.node_id)
         last_tried = self._last_tried_peer.get(peer, 0)
         if last_tried + PEER_RETRY_INTERVAL_FOR_CHANNELS < now:

22
electrum/sql_db.py

@@ -2,6 +2,7 @@ import os
 import concurrent
 import queue
 import threading
+import asyncio

 from sqlalchemy import create_engine
 from sqlalchemy.pool import StaticPool
@@ -18,28 +19,32 @@ def sql(func):
     """wrapper for sql methods"""
     def wrapper(self, *args, **kwargs):
         assert threading.currentThread() != self.sql_thread
-        f = concurrent.futures.Future()
+        f = asyncio.Future()
         self.db_requests.put((f, func, args, kwargs))
-        return f.result(timeout=10)
+        return f
     return wrapper

 class SqlDB(Logger):

-    def __init__(self, network, path, base):
+    def __init__(self, network, path, base, commit_interval=None):
         Logger.__init__(self)
         self.base = base
         self.network = network
         self.path = path
+        self.commit_interval = commit_interval
         self.db_requests = queue.Queue()
         self.sql_thread = threading.Thread(target=self.run_sql)
         self.sql_thread.start()

     def run_sql(self):
+        #return
+        self.logger.info("SQL thread started")
         engine = create_engine('sqlite:///' + self.path, pool_reset_on_return=None, poolclass=StaticPool)#, echo=True)
         DBSession = sessionmaker(bind=engine, autoflush=False)
-        self.DBSession = DBSession()
         if not os.path.exists(self.path):
             self.base.metadata.create_all(engine)
+        self.DBSession = DBSession()
+        i = 0
         while self.network.asyncio_loop.is_running():
             try:
                 future, func, args, kwargs = self.db_requests.get(timeout=0.1)
@@ -50,7 +55,14 @@ class SqlDB(Logger):
             except BaseException as e:
                 future.set_exception(e)
                 continue
-            future.set_result(result)
+            if not future.cancelled():
+                future.set_result(result)
+            # note: in sweepstore session.commit() is called inside
+            # the sql-decorated methods, so commiting to disk is awaited
+            if self.commit_interval:
+                i = (i + 1) % self.commit_interval
+                if i == 0:
+                    self.DBSession.commit()
         # write
         self.DBSession.commit()
         self.logger.info("SQL thread terminated")

3
electrum/tests/test_lnpeer.py

@@ -16,7 +16,8 @@ from electrum.lnpeer import Peer
 from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
 from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
 from electrum.lnutil import PaymentFailure, LnLocalFeatures
-from electrum.lnrouter import ChannelDB, LNPathFinder
+from electrum.lnrouter import LNPathFinder
+from electrum.channel_db import ChannelDB
 from electrum.lnworker import LNWallet
 from electrum.lnmsg import encode_msg, decode_msg
 from electrum.logging import console_stderr_handler

12
electrum/tests/test_lnrouter.py

@@ -59,33 +59,33 @@ class Test_LNRouter(TestCaseForTestnet):
         cdb = fake_network.channel_db
         path_finder = lnrouter.LNPathFinder(cdb)
         self.assertEqual(cdb.num_channels, 0)
-        cdb.on_channel_announcement({'node_id_1': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'node_id_2': b'\x02cccccccccccccccccccccccccccccccc',
+        cdb.add_channel_announcement({'node_id_1': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'node_id_2': b'\x02cccccccccccccccccccccccccccccccc',
                                      'bitcoin_key_1': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'bitcoin_key_2': b'\x02cccccccccccccccccccccccccccccccc',
                                      'short_channel_id': bfh('0000000000000001'),
                                      'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'),
                                      'len': b'\x00\x00', 'features': b''}, trusted=True)
         self.assertEqual(cdb.num_channels, 1)
-        cdb.on_channel_announcement({'node_id_1': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'node_id_2': b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee',
+        cdb.add_channel_announcement({'node_id_1': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'node_id_2': b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee',
                                      'bitcoin_key_1': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb', 'bitcoin_key_2': b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee',
                                      'short_channel_id': bfh('0000000000000002'),
                                      'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'),
                                      'len': b'\x00\x00', 'features': b''}, trusted=True)
-        cdb.on_channel_announcement({'node_id_1': b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', 'node_id_2': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb',
+        cdb.add_channel_announcement({'node_id_1': b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', 'node_id_2': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb',
                                      'bitcoin_key_1': b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', 'bitcoin_key_2': b'\x02bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb',
                                      'short_channel_id': bfh('0000000000000003'),
                                      'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'),
                                      'len': b'\x00\x00', 'features': b''}, trusted=True)
-        cdb.on_channel_announcement({'node_id_1': b'\x02cccccccccccccccccccccccccccccccc', 'node_id_2': b'\x02dddddddddddddddddddddddddddddddd',
+        cdb.add_channel_announcement({'node_id_1': b'\x02cccccccccccccccccccccccccccccccc', 'node_id_2': b'\x02dddddddddddddddddddddddddddddddd',
                                      'bitcoin_key_1': b'\x02cccccccccccccccccccccccccccccccc', 'bitcoin_key_2': b'\x02dddddddddddddddddddddddddddddddd',
                                      'short_channel_id': bfh('0000000000000004'),
                                      'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'),
                                      'len': b'\x00\x00', 'features': b''}, trusted=True)
-        cdb.on_channel_announcement({'node_id_1': b'\x02dddddddddddddddddddddddddddddddd', 'node_id_2': b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee',
+        cdb.add_channel_announcement({'node_id_1': b'\x02dddddddddddddddddddddddddddddddd', 'node_id_2': b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee',
                                      'bitcoin_key_1': b'\x02dddddddddddddddddddddddddddddddd', 'bitcoin_key_2': b'\x02eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee',
                                      'short_channel_id': bfh('0000000000000005'),
                                      'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'),
                                      'len': b'\x00\x00', 'features': b''}, trusted=True)
-        cdb.on_channel_announcement({'node_id_1': b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', 'node_id_2': b'\x02dddddddddddddddddddddddddddddddd',
+        cdb.add_channel_announcement({'node_id_1': b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', 'node_id_2': b'\x02dddddddddddddddddddddddddddddddd',
                                      'bitcoin_key_1': b'\x02aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', 'bitcoin_key_2': b'\x02dddddddddddddddddddddddddddddddd',
                                      'short_channel_id': bfh('0000000000000006'),
                                      'chain_hash': bfh('43497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000'),
