diff --git a/electrum/channel_db.py b/electrum/channel_db.py index e557261bf..9eb9abd61 100644 --- a/electrum/channel_db.py +++ b/electrum/channel_db.py @@ -307,6 +307,8 @@ PRIMARY KEY(node_id) class ChannelDB(SqlDB): NUM_MAX_RECENT_PEERS = 20 + PRIVATE_CHAN_UPD_CACHE_TTL_NORMAL = 600 + PRIVATE_CHAN_UPD_CACHE_TTL_SHORT = 120 def __init__(self, network: 'Network'): path = self.get_file_path(network.config) @@ -314,7 +316,10 @@ class ChannelDB(SqlDB): self.lock = threading.RLock() self.num_nodes = 0 self.num_channels = 0 - self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict] + self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], Tuple[dict, int]] + # note: ^ we could maybe move this cache into PaySession instead of being global. + # That would only make sense though if PaySessions were never too short + # (e.g. consider trampoline forwarding). self.ca_verifier = LNChannelVerifier(network, self) # initialized in load_data @@ -667,19 +672,46 @@ class ChannelDB(SqlDB): self.update_counts() self.logger.info(f'Deleting {len(orphaned_chans)} orphaned channels') - def add_channel_update_for_private_channel(self, msg_payload: dict, start_node_id: bytes) -> bool: + def _get_channel_update_for_private_channel( + self, + start_node_id: bytes, + short_channel_id: ShortChannelID, + *, + now: int = None, # unix ts + ) -> Optional[dict]: + if now is None: + now = int(time.time()) + key = (start_node_id, short_channel_id) + chan_upd_dict, cache_expiration = self._channel_updates_for_private_channels.get(key, (None, 0)) + if cache_expiration < now: + chan_upd_dict = None # already expired + # TODO rm expired entries from cache (note: perf vs thread-safety) + return chan_upd_dict + + def add_channel_update_for_private_channel( + self, + msg_payload: dict, + start_node_id: bytes, + *, + cache_ttl: int = None, # seconds + ) -> bool: """Returns True iff the channel update was successfully added and it was different than what we had before (if any). """ if not verify_sig_for_channel_update(msg_payload, start_node_id): return False # ignore + now = int(time.time()) short_channel_id = ShortChannelID(msg_payload['short_channel_id']) msg_payload['start_node'] = start_node_id - key = (start_node_id, short_channel_id) - prev_chanupd = self._channel_updates_for_private_channels.get(key) + prev_chanupd = self._get_channel_update_for_private_channel(start_node_id, short_channel_id, now=now) if prev_chanupd == msg_payload: return False - self._channel_updates_for_private_channels[key] = msg_payload + if cache_ttl is None: + cache_ttl = self.PRIVATE_CHAN_UPD_CACHE_TTL_NORMAL + cache_expiration = now + cache_ttl + key = (start_node_id, short_channel_id) + with self.lock: + self._channel_updates_for_private_channels[key] = msg_payload, cache_expiration return True def remove_channel(self, short_channel_id: ShortChannelID): @@ -791,21 +823,20 @@ class ChannelDB(SqlDB): def get_policy_for_node( self, - short_channel_id: bytes, + short_channel_id: ShortChannelID, node_id: bytes, *, my_channels: Dict[ShortChannelID, 'Channel'] = None, private_route_edges: Dict[ShortChannelID, 'RouteEdge'] = None, + now: int = None, # unix ts ) -> Optional['Policy']: channel_info = self.get_channel_info(short_channel_id) if channel_info is not None: # publicly announced channel policy = self._policies.get((node_id, short_channel_id)) if policy: return policy - else: # private channel - chan_upd_dict = self._channel_updates_for_private_channels.get((node_id, short_channel_id)) - if chan_upd_dict: - return Policy.from_msg(chan_upd_dict) + elif chan_upd_dict := self._get_channel_update_for_private_channel(node_id, short_channel_id, now=now): + return Policy.from_msg(chan_upd_dict) # check if it's one of our own channels if my_channels: policy = get_mychannel_policy(short_channel_id, node_id, my_channels) diff --git a/electrum/lnonion.py b/electrum/lnonion.py index c7cd9d44e..5e00c151c 100644 --- a/electrum/lnonion.py +++ b/electrum/lnonion.py @@ -25,7 +25,7 @@ import io import hashlib -from typing import Sequence, List, Tuple, NamedTuple, TYPE_CHECKING, Dict, Any, Optional +from typing import Sequence, List, Tuple, NamedTuple, TYPE_CHECKING, Dict, Any, Optional, Union from enum import IntEnum from . import ecc @@ -356,7 +356,7 @@ class FailedToDecodeOnionError(Exception): pass class OnionRoutingFailure(Exception): - def __init__(self, code: int, data: bytes): + def __init__(self, code: Union[int, 'OnionFailureCode'], data: bytes): self.code = code self.data = data diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py index b2872ff3b..75a195e1e 100644 --- a/electrum/lnrouter.py +++ b/electrum/lnrouter.py @@ -407,6 +407,7 @@ class LNPathFinder(Logger): return False if blacklist_expiration < now: return False + # TODO rm expired entries from cache (note: perf vs thread-safety) return True def add_edge_to_blacklist( @@ -480,12 +481,12 @@ class LNPathFinder(Logger): if channel_info is None: return float('inf'), 0 channel_policy = self.channel_db.get_policy_for_node( - short_channel_id, start_node, my_channels=my_channels, private_route_edges=private_route_edges) + short_channel_id, start_node, my_channels=my_channels, private_route_edges=private_route_edges, now=now) if channel_policy is None: return float('inf'), 0 # channels that did not publish both policies often return temporary channel failure channel_policy_backwards = self.channel_db.get_policy_for_node( - short_channel_id, end_node, my_channels=my_channels, private_route_edges=private_route_edges) + short_channel_id, end_node, my_channels=my_channels, private_route_edges=private_route_edges, now=now) if (channel_policy_backwards is None and not is_mine and short_channel_id not in private_route_edges): diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 48830d160..180609340 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -1589,7 +1589,7 @@ class LNWallet(LNWorker): else: # apply the channel update or get blacklisted blacklist, update = self._handle_chanupd_from_failed_htlc( - payload, route=route, sender_idx=sender_idx) + payload, route=route, sender_idx=sender_idx, failure_msg=failure_msg) # we interpret a temporary channel failure as a liquidity issue # in the channel and update our liquidity hints accordingly if code == OnionFailureCode.TEMPORARY_CHANNEL_FAILURE: @@ -1606,7 +1606,12 @@ class LNWallet(LNWorker): if blacklist: self.network.path_finder.add_edge_to_blacklist(short_channel_id=failing_channel) - def _handle_chanupd_from_failed_htlc(self, payload, *, route, sender_idx) -> Tuple[bool, bool]: + def _handle_chanupd_from_failed_htlc( + self, payload, *, + route: LNPaymentRoute, + sender_idx: int, + failure_msg: OnionRoutingFailure, + ) -> Tuple[bool, bool]: blacklist = False update = False try: @@ -1626,7 +1631,12 @@ class LNWallet(LNWorker): # maybe it is a private channel (and data in invoice was outdated) self.logger.info(f"Could not find {short_channel_id}. maybe update is for private channel?") start_node_id = route[sender_idx].node_id - update = self.channel_db.add_channel_update_for_private_channel(payload, start_node_id) + cache_ttl = None + if failure_msg.code == OnionFailureCode.CHANNEL_DISABLED: + # eclair sends CHANNEL_DISABLED if its peer is offline. E.g. we might be trying to pay + # a mobile phone with the app closed. So we cache this with a short TTL. + cache_ttl = self.channel_db.PRIVATE_CHAN_UPD_CACHE_TTL_SHORT + update = self.channel_db.add_channel_update_for_private_channel(payload, start_node_id, cache_ttl=cache_ttl) blacklist = not update elif r == UpdateStatus.EXPIRED: blacklist = True