From 9b1c40e3964ba7036600cbb0f8bd5938cb3edb22 Mon Sep 17 00:00:00 2001 From: ThomasV Date: Sun, 22 Oct 2023 12:49:26 +0200 Subject: [PATCH] Refactor payment forwarding: - all forwarding types use the same flow - forwarding callback returns a htlc_key or None - forwarding info is persisted in lnworker: - ongoing_forwardings - downstream to upstream htlc_key - htlc_key -> error_bytes --- electrum/lnchannel.py | 19 --- electrum/lnpeer.py | 225 +++++++++++++++------------------- electrum/lnutil.py | 7 ++ electrum/lnworker.py | 86 ++++++++----- electrum/submarine_swaps.py | 4 +- electrum/tests/test_lnpeer.py | 14 ++- 6 files changed, 180 insertions(+), 175 deletions(-) diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py index 22f0e4f76..fe0ec0d22 100644 --- a/electrum/lnchannel.py +++ b/electrum/lnchannel.py @@ -657,7 +657,6 @@ class Channel(AbstractChannel): self.onion_keys = state['onion_keys'] # type: Dict[int, bytes] self.data_loss_protect_remote_pcp = state['data_loss_protect_remote_pcp'] self.hm = HTLCManager(log=state['log'], initial_feerate=initial_feerate) - self.fail_htlc_reasons = state["fail_htlc_reasons"] self.unfulfilled_htlcs = state["unfulfilled_htlcs"] self._state = ChannelState[state['state']] self.peer_state = PeerState.DISCONNECTED @@ -1222,26 +1221,8 @@ class Channel(AbstractChannel): error_bytes, failure_message = self._receive_fail_reasons.pop(htlc.htlc_id) except KeyError: error_bytes, failure_message = None, None - # if we are forwarding, save error message to disk - if self.lnworker.get_payment_info(htlc.payment_hash) is None: - self.save_fail_htlc_reason(htlc.htlc_id, error_bytes, failure_message) self.lnworker.htlc_failed(self, htlc.payment_hash, htlc.htlc_id, error_bytes, failure_message) - def save_fail_htlc_reason( - self, - htlc_id: int, - error_bytes: Optional[bytes], - failure_message: Optional['OnionRoutingFailure']): - error_hex = error_bytes.hex() if error_bytes else None - failure_hex = failure_message.to_bytes().hex() if failure_message else None - self.fail_htlc_reasons[htlc_id] = (error_hex, failure_hex) - - def pop_fail_htlc_reason(self, htlc_id): - error_hex, failure_hex = self.fail_htlc_reasons.pop(htlc_id, (None, None)) - error_bytes = bytes.fromhex(error_hex) if error_hex else None - failure_message = OnionRoutingFailure.from_bytes(bytes.fromhex(failure_hex)) if failure_hex else None - return error_bytes, failure_message - def extract_preimage_from_htlc_txin(self, txin: TxInput) -> None: witness = txin.witness_elements() if len(witness) == 5: # HTLC success tx diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index 70ab8d3b3..7b1cd1acf 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -45,6 +45,7 @@ from .lnutil import (Outpoint, LocalConfig, RECEIVED, UpdateAddHtlc, ChannelConf IncompatibleLightningFeatures, derive_payment_secret_from_payment_preimage, ChannelType, LNProtocolWarning, validate_features, IncompatibleOrInsaneFeatures) from .lnutil import FeeUpdate, channel_id_from_funding_tx, PaymentFeeBudget +from .lnutil import serialize_htlc_key from .lntransport import LNTransport, LNTransportBase from .lnmsg import encode_msg, decode_msg, UnknownOptionalMsgType, FailedToParseMsg from .interface import GracefulDisconnect @@ -120,7 +121,6 @@ class Peer(Logger): self._received_revack_event = asyncio.Event() self.received_commitsig_event = asyncio.Event() self.downstream_htlc_resolved_event = asyncio.Event() - self.jit_failures = {} def send_message(self, message_name: str, **kwargs): assert util.get_running_loop() == util.get_asyncio_loop(), f"this must be run on the asyncio thread!" @@ -1688,7 +1688,8 @@ class Peer(Logger): chan.receive_htlc(htlc, onion_packet) util.trigger_callback('htlc_added', chan, htlc, RECEIVED) - def maybe_forward_htlc( + + async def maybe_forward_htlc( self, *, incoming_chan: Channel, htlc: UpdateAddHtlc, @@ -1742,20 +1743,12 @@ class Peer(Logger): if next_chan.can_pay(next_amount_msat_htlc): break else: - async def wrapped_callback(): - coro = self.lnworker.open_channel_just_in_time( - next_peer, - next_amount_msat_htlc, - next_cltv_abs, - htlc.payment_hash, - processed_onion.next_packet) - try: - await coro - except OnionRoutingFailure as e: - self.jit_failures[next_chan_scid.hex()] = e - - asyncio.ensure_future(wrapped_callback()) - return next_chan_scid, -1 + return await self.lnworker.open_channel_just_in_time( + next_peer, + next_amount_msat_htlc, + next_cltv_abs, + htlc.payment_hash, + processed_onion.next_packet) local_height = chain.height() if next_chan is None: @@ -1812,7 +1805,8 @@ class Peer(Logger): except BaseException as e: log_fail_reason(f"error sending message to next_peer={next_chan.node_id.hex()}") raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, data=outgoing_chan_upd_message) - return next_chan_scid, next_htlc.htlc_id + htlc_key = serialize_htlc_key(next_chan.get_scid_or_local_alias(), next_htlc.htlc_id) + return htlc_key @log_exceptions async def maybe_forward_trampoline( @@ -1961,10 +1955,25 @@ class Peer(Logger): htlc: UpdateAddHtlc, processed_onion: ProcessedOnionPacket, onion_packet_bytes: bytes, + already_forwarded = False, ) -> Tuple[Optional[bytes], Optional[Callable]]: - """As a final recipient of an HTLC, decide if we should fulfill it. - Return (preimage, forwarding_callback) with at most a single element not None """ + Decide what to do with an HTLC: return preimage if it can be fulfilled, forwarding callback if it can be forwarded. + Return (payment_key, preimage, callback) with at most a single element of the last two not None + Side effect: populates lnworker.received_mpp (which is not persisted, needs to be re-populated after restart) + """ + + if not processed_onion.are_we_final: + if not self.lnworker.enable_htlc_forwarding: + return None, None, None + # use the htlc key if we are forwarding + payment_key = serialize_htlc_key(chan.get_scid_or_local_alias(), htlc.htlc_id) + callback = lambda: self.maybe_forward_htlc( + incoming_chan=chan, + htlc=htlc, + processed_onion=processed_onion) + return payment_key, None, callback + def log_fail_reason(reason: str): self.logger.info( f"maybe_fulfill_htlc. will FAIL HTLC: chan {chan.short_channel_id}. " @@ -1986,9 +1995,6 @@ class Peer(Logger): exc_incorrect_or_unknown_pd = OnionRoutingFailure( code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=amt_to_forward.to_bytes(8, byteorder="big") + local_height.to_bytes(4, byteorder="big")) - if local_height + MIN_FINAL_CLTV_DELTA_ACCEPTED > htlc.cltv_abs: - log_fail_reason(f"htlc.cltv_abs is unreasonably close") - raise exc_incorrect_or_unknown_pd try: cltv_abs_from_onion = processed_onion.hop_data.payload["outgoing_cltv_value"]["outgoing_cltv_value"] except Exception: @@ -2025,15 +2031,19 @@ class Peer(Logger): log_fail_reason(f"'payment_secret' missing from onion") raise exc_incorrect_or_unknown_pd + # payment key for final onions + payment_hash = htlc.payment_hash + payment_key = (payment_hash + payment_secret_from_onion).hex() + from .lnworker import RecvMPPResolution mpp_resolution = self.lnworker.check_mpp_status( payment_secret=payment_secret_from_onion, - short_channel_id=chan.short_channel_id, + short_channel_id=chan.get_scid_or_local_alias(), htlc=htlc, expected_msat=total_msat, ) if mpp_resolution == RecvMPPResolution.WAITING: - return None, None + return payment_key, None, None elif mpp_resolution == RecvMPPResolution.EXPIRED: log_fail_reason(f"MPP_TIMEOUT") raise OnionRoutingFailure(code=OnionFailureCode.MPP_TIMEOUT, data=b'') @@ -2045,7 +2055,12 @@ class Peer(Logger): else: raise Exception(f"unexpected {mpp_resolution=}") - payment_hash = htlc.payment_hash + if local_height + MIN_FINAL_CLTV_DELTA_ACCEPTED > htlc.cltv_abs: + if not already_forwarded: + log_fail_reason(f"htlc.cltv_abs is unreasonably close") + raise exc_incorrect_or_unknown_pd + else: + return payment_key, None, None # detect callback # if there is a trampoline_onion, maybe_fulfill_htlc will be called again @@ -2060,23 +2075,19 @@ class Peer(Logger): is_trampoline=True) if trampoline_onion.are_we_final: # trampoline- we are final recipient of HTLC - preimage, cb = self.maybe_fulfill_htlc( + return self.maybe_fulfill_htlc( chan=chan, htlc=htlc, processed_onion=trampoline_onion, onion_packet_bytes=onion_packet_bytes, ) - if preimage: - return preimage, None - else: - return None, cb else: callback = lambda: self.maybe_forward_trampoline( payment_hash=payment_hash, inc_cltv_abs=htlc.cltv_abs, # TODO: use max or enforce same value across mpp parts outer_onion=processed_onion, trampoline_onion=trampoline_onion) - return None, callback + return payment_key, None, callback # TODO don't accept payments twice for same invoice # TODO check invoice expiry @@ -2102,15 +2113,19 @@ class Peer(Logger): hold_invoice_callback = self.lnworker.hold_invoice_callbacks.get(payment_hash) if hold_invoice_callback and not preimage: - return None, lambda: hold_invoice_callback(payment_hash) + callback = lambda: hold_invoice_callback(payment_hash) + return payment_key, None, callback if not preimage: - self.logger.info(f"missing preimage and no hold invoice callback {payment_hash.hex()}") - raise exc_incorrect_or_unknown_pd + if not already_forwarded: + log_fail_reason(f"missing preimage and no hold invoice callback {payment_hash.hex()}") + raise exc_incorrect_or_unknown_pd + else: + return payment_key, None, None chan.opening_fee = None self.logger.info(f"maybe_fulfill_htlc. will FULFILL HTLC: chan {chan.short_channel_id}. htlc={str(htlc)}") - return preimage, None + return payment_key, preimage, None def fulfill_htlc(self, chan: Channel, htlc_id: int, preimage: bytes): self.logger.info(f"_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}") @@ -2512,17 +2527,13 @@ class Peer(Logger): self.maybe_send_commitment(chan) done = set() unfulfilled = chan.unfulfilled_htlcs - for htlc_id, (local_ctn, remote_ctn, onion_packet_hex, forwarding_info) in unfulfilled.items(): - if forwarding_info: - forwarding_info = tuple(forwarding_info) # storage converts to list - self.lnworker.downstream_htlc_to_upstream_peer_map[forwarding_info] = self.pubkey + for htlc_id, (local_ctn, remote_ctn, onion_packet_hex, forwarding_key) in unfulfilled.items(): if not chan.hm.is_htlc_irrevocably_added_yet(htlc_proposer=REMOTE, htlc_id=htlc_id): continue htlc = chan.hm.get_htlc_by_id(REMOTE, htlc_id) error_reason = None # type: Optional[OnionRoutingFailure] error_bytes = None # type: Optional[bytes] preimage = None - fw_info = None onion_packet_bytes = bytes.fromhex(onion_packet_hex) onion_packet = None try: @@ -2531,20 +2542,21 @@ class Peer(Logger): error_reason = e else: try: - preimage, fw_info, error_bytes = self.process_unfulfilled_htlc( + preimage, _forwarding_key, error_bytes = self.process_unfulfilled_htlc( chan=chan, htlc=htlc, - forwarding_info=forwarding_info, + forwarding_key=forwarding_key, onion_packet_bytes=onion_packet_bytes, onion_packet=onion_packet) + if _forwarding_key: + assert forwarding_key is False + unfulfilled[htlc_id] = local_ctn, remote_ctn, onion_packet_hex, _forwarding_key except OnionRoutingFailure as e: error_bytes = construct_onion_error(e, onion_packet.public_key, our_onion_private_key=self.privkey) if error_bytes: error_bytes = obfuscate_onion_error(error_bytes, onion_packet.public_key, our_onion_private_key=self.privkey) - if fw_info: - unfulfilled[htlc_id] = local_ctn, remote_ctn, onion_packet_hex, fw_info - self.lnworker.downstream_htlc_to_upstream_peer_map[fw_info] = self.pubkey - elif preimage or error_reason or error_bytes: + + if preimage or error_reason or error_bytes: if preimage: self.lnworker.set_request_status(htlc.payment_hash, PR_PAID) if not self.lnworker.enable_htlc_settle: @@ -2563,10 +2575,7 @@ class Peer(Logger): done.add(htlc_id) # cleanup for htlc_id in done: - local_ctn, remote_ctn, onion_packet_hex, forwarding_info = unfulfilled.pop(htlc_id) - if forwarding_info: - forwarding_info = tuple(forwarding_info) # storage converts to list - self.lnworker.downstream_htlc_to_upstream_peer_map.pop(forwarding_info, None) + unfulfilled.pop(htlc_id) self.maybe_send_commitment(chan) def _maybe_cleanup_received_htlcs_pending_removal(self) -> None: @@ -2596,11 +2605,11 @@ class Peer(Logger): self, *, chan: Channel, htlc: UpdateAddHtlc, - forwarding_info: Tuple[str, int], + forwarding_key: str, onion_packet_bytes: bytes, onion_packet: OnionPacket) -> Tuple[Optional[bytes], Union[bool, None, Tuple[str, int]], Optional[bytes]]: """ - return (preimage, fw_info, error_bytes) with at most a single element that is not None + return (preimage, payment_key, error_bytes) with at most a single element that is not None raise an OnionRoutingFailure if we need to fail the htlc """ payment_hash = htlc.payment_hash @@ -2608,80 +2617,50 @@ class Peer(Logger): onion_packet, payment_hash=payment_hash, onion_packet_bytes=onion_packet_bytes) - if processed_onion.are_we_final: - # either we are final recipient; or if trampoline, see cases below - if not forwarding_info: - preimage, forwarding_callback = self.maybe_fulfill_htlc( - chan=chan, - htlc=htlc, - processed_onion=processed_onion, - onion_packet_bytes=onion_packet_bytes) - if forwarding_callback: - payment_secret = processed_onion.hop_data.payload["payment_data"]["payment_secret"] - payment_key = payment_hash + payment_secret - # trampoline- HTLC we are supposed to forward, but haven't forwarded yet - if not self.lnworker.enable_htlc_forwarding: - return None, None, None - elif payment_key in self.lnworker.final_onion_forwardings: - # we are already forwarding this payment - self.logger.info(f"we are already forwarding this.") - else: - # add to list of ongoing payments - self.lnworker.final_onion_forwardings.add(payment_key) - # clear previous failures - self.lnworker.final_onion_forwarding_failures.pop(payment_key, None) - async def wrapped_callback(): - forwarding_coro = forwarding_callback() - try: - await forwarding_coro - except OnionRoutingFailure as e: - self.lnworker.final_onion_forwarding_failures[payment_key] = e - finally: - # remove from list of payments, so that another attempt can be initiated - self.lnworker.final_onion_forwardings.remove(payment_key) - asyncio.ensure_future(wrapped_callback()) - # return fw_info so that maybe_fulfill_htlc will not be called again - fw_info = (payment_key.hex(), -1) - return None, fw_info, None - else: - # trampoline- HTLC we are supposed to forward, and have already forwarded - payment_key_outer_onion = bytes.fromhex(forwarding_info[0]) - preimage = self.lnworker.get_preimage(payment_hash) - payment_secret_inner_onion = self.lnworker.get_payment_secret(payment_hash) - payment_key_inner_onion = payment_hash + payment_secret_inner_onion - for payment_key in [payment_key_inner_onion, payment_key_outer_onion]: - # get (and not pop) failure because the incoming payment might be multi-part - error_reason = self.lnworker.final_onion_forwarding_failures.get(payment_key) - if error_reason: - self.logger.info(f'trampoline forwarding failure: {error_reason.code_name()}') - raise error_reason - - elif not forwarding_info: - # HTLC we are supposed to forward, but haven't forwarded yet - if not self.lnworker.enable_htlc_forwarding: - return None, None, None - next_chan_id, next_htlc_id = self.maybe_forward_htlc( - incoming_chan=chan, - htlc=htlc, - processed_onion=processed_onion) - fw_info = (next_chan_id.hex(), next_htlc_id) - return None, fw_info, None + + htlc_key = serialize_htlc_key(chan.get_scid_or_local_alias(), htlc.htlc_id) + error_bytes = error_reason = None + + # fixme: do we need the outer key? + payment_key, preimage, forwarding_callback = self.maybe_fulfill_htlc( + chan=chan, + htlc=htlc, + processed_onion=processed_onion, + onion_packet_bytes=onion_packet_bytes, + already_forwarded=bool(forwarding_key)) + + if not forwarding_key: + if forwarding_callback: + # HTLC we are supposed to forward, but haven't forwarded yet + if not self.lnworker.enable_htlc_forwarding: + return None, None, None + if payment_key not in self.lnworker.active_forwardings: + async def wrapped_callback(): + forwarding_coro = forwarding_callback() + try: + next_htlc = await forwarding_coro + if next_htlc: + self.lnworker.downstream_to_upstream_htlc[next_htlc] = htlc_key + except OnionRoutingFailure as e: + self.lnworker.save_forwarding_failure(payment_key, failure_message=e) + # add to list + self.lnworker.active_forwardings[payment_key] = True + fut = asyncio.ensure_future(wrapped_callback()) + # return payment_key so this branch will not be executed again + return None, payment_key, None else: + assert payment_key == forwarding_key # HTLC we are supposed to forward, and have already forwarded preimage = self.lnworker.get_preimage(payment_hash) - next_chan_id_hex, htlc_id = forwarding_info - next_chan = self.lnworker.get_channel_by_short_id(bytes.fromhex(next_chan_id_hex)) - if next_chan: - error_bytes, error_reason = next_chan.pop_fail_htlc_reason(htlc_id) - if error_bytes: - return None, None, error_bytes - if error_reason: - raise error_reason - # just-in-time channel - if htlc_id == -1: - error_reason = self.jit_failures.pop(next_chan_id_hex, None) - if error_reason: - raise error_reason + error_bytes, error_reason = self.lnworker.get_forwarding_failure(payment_key) + + if error_bytes or error_reason or preimage: + self.lnworker.maybe_cleanup_forwarding(payment_key, chan.get_scid_or_local_alias(), htlc) + + if error_bytes: + return None, None, error_bytes + if error_reason: + raise error_reason if preimage: return preimage, None, None return None, None, None diff --git a/electrum/lnutil.py b/electrum/lnutil.py index 04086fefa..6e587089a 100644 --- a/electrum/lnutil.py +++ b/electrum/lnutil.py @@ -66,6 +66,13 @@ hex_to_bytes = lambda v: v if isinstance(v, bytes) else bytes.fromhex(v) if v is json_to_keypair = lambda v: v if isinstance(v, OnlyPubkeyKeypair) else Keypair(**v) if len(v)==2 else OnlyPubkeyKeypair(**v) +def serialize_htlc_key(scid:bytes, htlc_id: int): + return scid.hex() + ':%d'%htlc_id + +def deserialize_htlc_key(htlc_key:str): + scid, htlc_id = htlc_key.split(':') + return bytes.fromhex(scid), int(htlc_id) + @attr.s class OnlyPubkeyKeypair(StoredObject): pubkey = attr.ib(type=bytes, converter=hex_to_bytes) diff --git a/electrum/lnworker.py b/electrum/lnworker.py index b3a5484b9..983aa141a 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -56,6 +56,7 @@ from .lnchannel import ChannelState, PeerState, HTLCWithStatus from .lnrater import LNRater from . import lnutil from .lnutil import funding_output_script +from .lnutil import serialize_htlc_key, deserialize_htlc_key from .bitcoin import redeem_script_to_address, DummyAddress from .lnutil import (Outpoint, LNPeerAddr, get_compressed_pubkey_from_bech32, extract_nodeid, @@ -840,10 +841,11 @@ class LNWallet(LNWorker): for payment_hash in self.get_payments(status='inflight').keys(): self.set_invoice_status(payment_hash.hex(), PR_INFLIGHT) - self.final_onion_forwardings = set() - self.final_onion_forwarding_failures = {} # todo: should be persisted - # map forwarded htlcs (fw_info=(scid_hex, htlc_id)) to originating peer pubkeys - self.downstream_htlc_to_upstream_peer_map = {} # type: Dict[Tuple[str, int], bytes] + # payment forwarding + self.active_forwardings = self.db.get_dict('active_forwardings') # list of payment_keys + self.forwarding_failures = self.db.get_dict('forwarding_failures') # payment_key -> (error_bytes, error_message) + self.downstream_to_upstream_htlc = {} # Dict: htlc_key -> htlc_key (not persisted) + # payment_hash -> callback: self.hold_invoice_callbacks = {} # type: Dict[bytes, Callable[[bytes], Awaitable[None]]] self.payment_bundles = [] # lists of hashes. todo:persist @@ -1275,9 +1277,10 @@ class LNWallet(LNWorker): except Exception: raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_NODE_FAILURE, data=b'') # We have been paid and can broadcast - # if broadcasting raise an exception, we should try to rebroadcast + # todo: if broadcasting raise an exception, we should try to rebroadcast await self.network.broadcast_transaction(funding_tx) - return next_chan, funding_tx + htlc_key = serialize_htlc_key(next_chan.get_scid_or_local_alias(), htlc.htlc_id) + return htlc_key @log_exceptions async def open_channel_with_peer( @@ -2267,7 +2270,6 @@ class LNWallet(LNWorker): if pkey in self.received_mpp_htlcs: self.set_mpp_resolution(payment_key=pkey, resolution=mpp_resolution) - self.maybe_cleanup_mpp_status(payment_key, short_channel_id, htlc) return mpp_resolution def update_mpp_with_received_htlc( @@ -2312,20 +2314,31 @@ class LNWallet(LNWorker): return int(time.time()) return min([_htlc.timestamp for scid, _htlc in mpp_status.htlc_set]) - def maybe_cleanup_mpp_status( - self, - payment_key: bytes, - short_channel_id: ShortChannelID, - htlc: UpdateAddHtlc, + def maybe_cleanup_forwarding( + self, + payment_key_hex: str, + short_channel_id: ShortChannelID, + htlc: UpdateAddHtlc, ) -> None: - mpp_status = self.received_mpp_htlcs[payment_key] - if mpp_status.resolution == RecvMPPResolution.WAITING: - return - key = (short_channel_id, htlc) - mpp_status.htlc_set.remove(key) # side-effecting htlc_set - if not mpp_status.htlc_set and payment_key in self.received_mpp_htlcs: + + is_htlc_key = ':' in payment_key_hex + if not is_htlc_key: + payment_key = bytes.fromhex(payment_key_hex) + mpp_status = self.received_mpp_htlcs[payment_key] + if mpp_status.resolution == RecvMPPResolution.WAITING: + # reconstructing the MPP after restart + self.logger.info(f'cannot cleanup mpp, still waiting') + return + htlc_key = (short_channel_id, htlc) + mpp_status.htlc_set.remove(htlc_key) # side-effecting htlc_set + if mpp_status.htlc_set: + return + self.logger.info('cleaning up mpp') self.received_mpp_htlcs.pop(payment_key) + self.active_forwardings.pop(payment_key_hex, None) + self.forwarding_failures.pop(payment_key_hex, None) + def get_payment_status(self, payment_hash: bytes) -> int: info = self.get_payment_info(payment_hash) return info.status if info else PR_UNPAID @@ -2370,16 +2383,23 @@ class LNWallet(LNWorker): info = info._replace(status=status) self.save_payment_info(info) - def is_forwarded_htlc_notify(self, chan: Channel, htlc_id: int) -> bool: + def is_forwarded_htlc_notify( + self, chan: Channel, htlc_id: int, *, + error_bytes: Optional[bytes] = None, + failure_message: Optional['OnionRoutingFailure'] = None + ) -> bool: """Called when an HTLC we offered on chan gets irrevocably fulfilled or failed. If we find this was a forwarded HTLC, the upstream peer is notified. Returns whether this was a forwarded HTLC. """ - fw_info = chan.get_scid_or_local_alias().hex(), htlc_id - upstream_peer_pubkey = self.downstream_htlc_to_upstream_peer_map.get(fw_info) - if not upstream_peer_pubkey: + htlc_key = serialize_htlc_key(chan.get_scid_or_local_alias(), htlc_id) + upstream_key = self.downstream_to_upstream_htlc.pop(htlc_key, None) + if not upstream_key: return False - upstream_peer = self.peers.get(upstream_peer_pubkey) + self.save_forwarding_failure(upstream_key, error_bytes=error_bytes, failure_message=failure_message) + upstream_chan_scid, _ = deserialize_htlc_key(upstream_key) + upstream_chan = self.get_channel_by_short_id(upstream_chan_scid) + upstream_peer = self.peers.get(upstream_chan.node_id) if upstream_chan else None if upstream_peer: upstream_peer.downstream_htlc_resolved_event.set() upstream_peer.downstream_htlc_resolved_event.clear() @@ -2416,7 +2436,7 @@ class LNWallet(LNWorker): failure_message: Optional['OnionRoutingFailure']): util.trigger_callback('htlc_failed', payment_hash, chan, htlc_id) - if self.is_forwarded_htlc_notify(chan=chan, htlc_id=htlc_id): + if self.is_forwarded_htlc_notify(chan=chan, htlc_id=htlc_id, error_bytes=error_bytes, failure_message=failure_message): return if shi := self.sent_htlcs_info.get((payment_hash, chan.short_channel_id, htlc_id)): onion_key = chan.pop_onion_key(htlc_id) @@ -3026,7 +3046,17 @@ class LNWallet(LNWorker): util.trigger_callback('channels_updated', self.wallet) self.lnwatcher.add_channel(cb.funding_outpoint.to_str(), cb.get_funding_address()) - def fail_final_onion_forwarding(self, payment_key): - """ use this to fail htlcs received for hold invoices""" - e = OnionRoutingFailure(code=OnionFailureCode.UNKNOWN_NEXT_PEER, data=b'') - self.final_onion_forwarding_failures[payment_key] = e + def save_forwarding_failure( + self, payment_key:str, *, + error_bytes: Optional[bytes] = None, + failure_message: Optional['OnionRoutingFailure'] = None): + error_hex = error_bytes.hex() if error_bytes else None + failure_hex = failure_message.to_bytes().hex() if failure_message else None + self.forwarding_failures[payment_key] = (error_hex, failure_hex) + + def get_forwarding_failure(self, payment_key: str): + error_hex, failure_hex = self.forwarding_failures.get(payment_key, (None, None)) + error_bytes = bytes.fromhex(error_hex) if error_hex else None + failure_message = OnionRoutingFailure.from_bytes(bytes.fromhex(failure_hex)) if failure_hex else None + return error_bytes, failure_message + diff --git a/electrum/submarine_swaps.py b/electrum/submarine_swaps.py index 35dfff4b6..e23da0660 100644 --- a/electrum/submarine_swaps.py +++ b/electrum/submarine_swaps.py @@ -30,6 +30,7 @@ from .bitcoin import construct_script from .crypto import ripemd from .invoices import Invoice from .network import TxBroadcastServerReturnedError +from .lnonion import OnionRoutingFailure, OnionFailureCode if TYPE_CHECKING: @@ -234,7 +235,8 @@ class SwapManager(Logger): self.lnworker.unregister_hold_invoice(swap.payment_hash) payment_secret = self.lnworker.get_payment_secret(swap.payment_hash) payment_key = swap.payment_hash + payment_secret - self.lnworker.fail_final_onion_forwarding(payment_key) + e = OnionRoutingFailure(code=OnionFailureCode.UNKNOWN_NEXT_PEER, data=b'') + self.lnworker.save_forwarding_failure(payment_key.hex(), failure_message=e) self.lnwatcher.remove_callback(swap.lockup_address) if swap.funding_txid is None: self.swaps.pop(swap.payment_hash.hex()) diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index 2fa1a5c52..3328bf5fd 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -171,12 +171,12 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): self._paysessions = dict() self.sent_htlcs_info = dict() self.sent_buckets = defaultdict(set) - self.final_onion_forwardings = set() - self.final_onion_forwarding_failures = {} + self.active_forwardings = {} + self.forwarding_failures = {} self.inflight_payments = set() self.preimages = {} self.stopping_soon = False - self.downstream_htlc_to_upstream_peer_map = {} + self.downstream_to_upstream_htlc = {} self.hold_invoice_callbacks = {} self.payment_bundles = [] # lists of hashes. todo:persist self.config.INITIAL_TRAMPOLINE_FEE_LEVEL = 0 @@ -298,10 +298,12 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): set_mpp_resolution = LNWallet.set_mpp_resolution is_mpp_amount_reached = LNWallet.is_mpp_amount_reached get_first_timestamp_of_mpp = LNWallet.get_first_timestamp_of_mpp - maybe_cleanup_mpp_status = LNWallet.maybe_cleanup_mpp_status bundle_payments = LNWallet.bundle_payments get_payment_bundle = LNWallet.get_payment_bundle _get_payment_key = LNWallet._get_payment_key + save_forwarding_failure = LNWallet.save_forwarding_failure + get_forwarding_failure = LNWallet.get_forwarding_failure + maybe_cleanup_forwarding = LNWallet.maybe_cleanup_forwarding class MockTransport: @@ -1821,6 +1823,8 @@ class TestPeerForwarding(TestPeer): test_failure=False, attempts=2): + bob_w = graph.workers['bob'] + carol_w = graph.workers['carol'] dave_w = graph.workers['dave'] async def pay(lnaddr, pay_req): @@ -1828,6 +1832,8 @@ class TestPeerForwarding(TestPeer): result, log = await graph.workers['alice'].pay_invoice(pay_req, attempts=attempts) if result: self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash)) + self.assertFalse(bool(bob_w.active_forwardings)) + self.assertFalse(bool(carol_w.active_forwardings)) raise PaymentDone() else: raise NoPathFound()