Browse Source

channel_db: add TTLs to channel_updates_for_private_channels

eclair sends CHANNEL_DISABLED if its peer is offline. E.g. we might be
trying to pay a mobile phone with the app closed. In that case we
should not cache the CHANNEL_DISABLED for too long.
master
SomberNight 2 years ago
parent
commit
98ccad68c1
No known key found for this signature in database
GPG Key ID: B33B5F232C6271E9
  1. 51
      electrum/channel_db.py
  2. 4
      electrum/lnonion.py
  3. 5
      electrum/lnrouter.py
  4. 16
      electrum/lnworker.py

51
electrum/channel_db.py

@ -307,6 +307,8 @@ PRIMARY KEY(node_id)
class ChannelDB(SqlDB): class ChannelDB(SqlDB):
NUM_MAX_RECENT_PEERS = 20 NUM_MAX_RECENT_PEERS = 20
PRIVATE_CHAN_UPD_CACHE_TTL_NORMAL = 600
PRIVATE_CHAN_UPD_CACHE_TTL_SHORT = 120
def __init__(self, network: 'Network'): def __init__(self, network: 'Network'):
path = self.get_file_path(network.config) path = self.get_file_path(network.config)
@ -314,7 +316,10 @@ class ChannelDB(SqlDB):
self.lock = threading.RLock() self.lock = threading.RLock()
self.num_nodes = 0 self.num_nodes = 0
self.num_channels = 0 self.num_channels = 0
self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict] self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], Tuple[dict, int]]
# note: ^ we could maybe move this cache into PaySession instead of being global.
# That would only make sense though if PaySessions were never too short
# (e.g. consider trampoline forwarding).
self.ca_verifier = LNChannelVerifier(network, self) self.ca_verifier = LNChannelVerifier(network, self)
# initialized in load_data # initialized in load_data
@ -667,19 +672,46 @@ class ChannelDB(SqlDB):
self.update_counts() self.update_counts()
self.logger.info(f'Deleting {len(orphaned_chans)} orphaned channels') self.logger.info(f'Deleting {len(orphaned_chans)} orphaned channels')
def add_channel_update_for_private_channel(self, msg_payload: dict, start_node_id: bytes) -> bool: def _get_channel_update_for_private_channel(
self,
start_node_id: bytes,
short_channel_id: ShortChannelID,
*,
now: int = None, # unix ts
) -> Optional[dict]:
if now is None:
now = int(time.time())
key = (start_node_id, short_channel_id)
chan_upd_dict, cache_expiration = self._channel_updates_for_private_channels.get(key, (None, 0))
if cache_expiration < now:
chan_upd_dict = None # already expired
# TODO rm expired entries from cache (note: perf vs thread-safety)
return chan_upd_dict
def add_channel_update_for_private_channel(
self,
msg_payload: dict,
start_node_id: bytes,
*,
cache_ttl: int = None, # seconds
) -> bool:
"""Returns True iff the channel update was successfully added and it was different than """Returns True iff the channel update was successfully added and it was different than
what we had before (if any). what we had before (if any).
""" """
if not verify_sig_for_channel_update(msg_payload, start_node_id): if not verify_sig_for_channel_update(msg_payload, start_node_id):
return False # ignore return False # ignore
now = int(time.time())
short_channel_id = ShortChannelID(msg_payload['short_channel_id']) short_channel_id = ShortChannelID(msg_payload['short_channel_id'])
msg_payload['start_node'] = start_node_id msg_payload['start_node'] = start_node_id
key = (start_node_id, short_channel_id) prev_chanupd = self._get_channel_update_for_private_channel(start_node_id, short_channel_id, now=now)
prev_chanupd = self._channel_updates_for_private_channels.get(key)
if prev_chanupd == msg_payload: if prev_chanupd == msg_payload:
return False return False
self._channel_updates_for_private_channels[key] = msg_payload if cache_ttl is None:
cache_ttl = self.PRIVATE_CHAN_UPD_CACHE_TTL_NORMAL
cache_expiration = now + cache_ttl
key = (start_node_id, short_channel_id)
with self.lock:
self._channel_updates_for_private_channels[key] = msg_payload, cache_expiration
return True return True
def remove_channel(self, short_channel_id: ShortChannelID): def remove_channel(self, short_channel_id: ShortChannelID):
@ -791,21 +823,20 @@ class ChannelDB(SqlDB):
def get_policy_for_node( def get_policy_for_node(
self, self,
short_channel_id: bytes, short_channel_id: ShortChannelID,
node_id: bytes, node_id: bytes,
*, *,
my_channels: Dict[ShortChannelID, 'Channel'] = None, my_channels: Dict[ShortChannelID, 'Channel'] = None,
private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None, private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None,
now: int = None, # unix ts
) -> Optional['Policy']: ) -> Optional['Policy']:
channel_info = self.get_channel_info(short_channel_id) channel_info = self.get_channel_info(short_channel_id)
if channel_info is not None: # publicly announced channel if channel_info is not None: # publicly announced channel
policy = self._policies.get((node_id, short_channel_id)) policy = self._policies.get((node_id, short_channel_id))
if policy: if policy:
return policy return policy
else: # private channel elif chan_upd_dict := self._get_channel_update_for_private_channel(node_id, short_channel_id, now=now):
chan_upd_dict = self._channel_updates_for_private_channels.get((node_id, short_channel_id)) return Policy.from_msg(chan_upd_dict)
if chan_upd_dict:
return Policy.from_msg(chan_upd_dict)
# check if it's one of our own channels # check if it's one of our own channels
if my_channels: if my_channels:
policy = get_mychannel_policy(short_channel_id, node_id, my_channels) policy = get_mychannel_policy(short_channel_id, node_id, my_channels)

4
electrum/lnonion.py

@ -25,7 +25,7 @@
import io import io
import hashlib import hashlib
from typing import Sequence, List, Tuple, NamedTuple, TYPE_CHECKING, Dict, Any, Optional from typing import Sequence, List, Tuple, NamedTuple, TYPE_CHECKING, Dict, Any, Optional, Union
from enum import IntEnum from enum import IntEnum
from . import ecc from . import ecc
@ -356,7 +356,7 @@ class FailedToDecodeOnionError(Exception): pass
class OnionRoutingFailure(Exception): class OnionRoutingFailure(Exception):
def __init__(self, code: int, data: bytes): def __init__(self, code: Union[int, 'OnionFailureCode'], data: bytes):
self.code = code self.code = code
self.data = data self.data = data

5
electrum/lnrouter.py

@ -407,6 +407,7 @@ class LNPathFinder(Logger):
return False return False
if blacklist_expiration < now: if blacklist_expiration < now:
return False return False
# TODO rm expired entries from cache (note: perf vs thread-safety)
return True return True
def add_edge_to_blacklist( def add_edge_to_blacklist(
@ -480,12 +481,12 @@ class LNPathFinder(Logger):
if channel_info is None: if channel_info is None:
return float('inf'), 0 return float('inf'), 0
channel_policy = self.channel_db.get_policy_for_node( channel_policy = self.channel_db.get_policy_for_node(
short_channel_id, start_node, my_channels=my_channels, private_route_edges=private_route_edges) short_channel_id, start_node, my_channels=my_channels, private_route_edges=private_route_edges, now=now)
if channel_policy is None: if channel_policy is None:
return float('inf'), 0 return float('inf'), 0
# channels that did not publish both policies often return temporary channel failure # channels that did not publish both policies often return temporary channel failure
channel_policy_backwards = self.channel_db.get_policy_for_node( channel_policy_backwards = self.channel_db.get_policy_for_node(
short_channel_id, end_node, my_channels=my_channels, private_route_edges=private_route_edges) short_channel_id, end_node, my_channels=my_channels, private_route_edges=private_route_edges, now=now)
if (channel_policy_backwards is None if (channel_policy_backwards is None
and not is_mine and not is_mine
and short_channel_id not in private_route_edges): and short_channel_id not in private_route_edges):

16
electrum/lnworker.py

@ -1589,7 +1589,7 @@ class LNWallet(LNWorker):
else: else:
# apply the channel update or get blacklisted # apply the channel update or get blacklisted
blacklist, update = self._handle_chanupd_from_failed_htlc( blacklist, update = self._handle_chanupd_from_failed_htlc(
payload, route=route, sender_idx=sender_idx) payload, route=route, sender_idx=sender_idx, failure_msg=failure_msg)
# we interpret a temporary channel failure as a liquidity issue # we interpret a temporary channel failure as a liquidity issue
# in the channel and update our liquidity hints accordingly # in the channel and update our liquidity hints accordingly
if code == OnionFailureCode.TEMPORARY_CHANNEL_FAILURE: if code == OnionFailureCode.TEMPORARY_CHANNEL_FAILURE:
@ -1606,7 +1606,12 @@ class LNWallet(LNWorker):
if blacklist: if blacklist:
self.network.path_finder.add_edge_to_blacklist(short_channel_id=failing_channel) self.network.path_finder.add_edge_to_blacklist(short_channel_id=failing_channel)
def _handle_chanupd_from_failed_htlc(self, payload, *, route, sender_idx) -> Tuple[bool, bool]: def _handle_chanupd_from_failed_htlc(
self, payload, *,
route: LNPaymentRoute,
sender_idx: int,
failure_msg: OnionRoutingFailure,
) -> Tuple[bool, bool]:
blacklist = False blacklist = False
update = False update = False
try: try:
@ -1626,7 +1631,12 @@ class LNWallet(LNWorker):
# maybe it is a private channel (and data in invoice was outdated) # maybe it is a private channel (and data in invoice was outdated)
self.logger.info(f"Could not find {short_channel_id}. maybe update is for private channel?") self.logger.info(f"Could not find {short_channel_id}. maybe update is for private channel?")
start_node_id = route[sender_idx].node_id start_node_id = route[sender_idx].node_id
update = self.channel_db.add_channel_update_for_private_channel(payload, start_node_id) cache_ttl = None
if failure_msg.code == OnionFailureCode.CHANNEL_DISABLED:
# eclair sends CHANNEL_DISABLED if its peer is offline. E.g. we might be trying to pay
# a mobile phone with the app closed. So we cache this with a short TTL.
cache_ttl = self.channel_db.PRIVATE_CHAN_UPD_CACHE_TTL_SHORT
update = self.channel_db.add_channel_update_for_private_channel(payload, start_node_id, cache_ttl=cache_ttl)
blacklist = not update blacklist = not update
elif r == UpdateStatus.EXPIRED: elif r == UpdateStatus.EXPIRED:
blacklist = True blacklist = True

Loading…
Cancel
Save