diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index e92a02bea..356a2ed1a 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -9,7 +9,7 @@ from collections import OrderedDict, defaultdict import asyncio import os import time -from typing import Tuple, Dict, TYPE_CHECKING, Optional, Union, Set +from typing import Tuple, Dict, TYPE_CHECKING, Optional, Union, Set, Callable from datetime import datetime import functools @@ -1668,7 +1668,8 @@ class Peer(Logger): next_peer.maybe_send_commitment(next_chan) return next_chan_scid, next_htlc.htlc_id - def maybe_forward_trampoline( + @log_exceptions + async def maybe_forward_trampoline( self, *, payment_hash: bytes, cltv_expiry: int, @@ -1713,48 +1714,34 @@ class Peer(Logger): trampoline_fee = total_msat - amt_to_forward self.logger.info(f'trampoline cltv and fee: {trampoline_cltv_delta, trampoline_fee}') - @log_exceptions - async def forward_trampoline_payment(): - try: - await self.lnworker.pay_to_node( - node_pubkey=outgoing_node_id, - payment_hash=payment_hash, - payment_secret=payment_secret, - amount_to_pay=amt_to_forward, - min_cltv_expiry=cltv_from_onion, - r_tags=[], - invoice_features=invoice_features, - fwd_trampoline_onion=next_trampoline_onion, - fwd_trampoline_fee=trampoline_fee, - fwd_trampoline_cltv_delta=trampoline_cltv_delta, - attempts=1) - except OnionRoutingFailure as e: - # FIXME: cannot use payment_hash as key - self.lnworker.trampoline_forwarding_failures[payment_hash] = e - except PaymentFailure as e: - # FIXME: adapt the error code - error_reason = OnionRoutingFailure(code=OnionFailureCode.UNKNOWN_NEXT_PEER, data=b'') - self.lnworker.trampoline_forwarding_failures[payment_hash] = error_reason - - # remove from list of payments, so that another attempt can be initiated - self.lnworker.trampoline_forwardings.remove(payment_hash) - - # add to list of ongoing payments - self.lnworker.trampoline_forwardings.add(payment_hash) - # clear previous failures - self.lnworker.trampoline_forwarding_failures.pop(payment_hash, None) - # start payment - asyncio.ensure_future(forward_trampoline_payment()) + try: + await self.lnworker.pay_to_node( + node_pubkey=outgoing_node_id, + payment_hash=payment_hash, + payment_secret=payment_secret, + amount_to_pay=amt_to_forward, + min_cltv_expiry=cltv_from_onion, + r_tags=[], + invoice_features=invoice_features, + fwd_trampoline_onion=next_trampoline_onion, + fwd_trampoline_fee=trampoline_fee, + fwd_trampoline_cltv_delta=trampoline_cltv_delta, + attempts=1) + except OnionRoutingFailure as e: + raise + except PaymentFailure as e: + # FIXME: adapt the error code + raise OnionRoutingFailure(code=OnionFailureCode.UNKNOWN_NEXT_PEER, data=b'') def maybe_fulfill_htlc( self, *, chan: Channel, htlc: UpdateAddHtlc, processed_onion: ProcessedOnionPacket, - is_trampoline: bool = False) -> Optional[bytes]: - + onion_packet_bytes: bytes, + is_trampoline: bool = False) -> Tuple[Optional[bytes], Optional[Callable]]: """As a final recipient of an HTLC, decide if we should fulfill it. - Return preimage or None + Return (preimage, forwarding_callback) with at most a single element not None """ def log_fail_reason(reason: str): self.logger.info(f"maybe_fulfill_htlc. will FAIL HTLC: chan {chan.short_channel_id}. " @@ -1810,19 +1797,55 @@ class Peer(Logger): log_fail_reason(f"'payment_secret' missing from onion") raise exc_incorrect_or_unknown_pd - payment_status = self.lnworker.check_received_htlc(payment_secret_from_onion, chan.short_channel_id, htlc, total_msat) + payment_status = self.lnworker.check_mpp_status(payment_secret_from_onion, chan.short_channel_id, htlc, total_msat) if payment_status is None: - return None + return None, None elif payment_status is False: log_fail_reason(f"MPP_TIMEOUT") raise OnionRoutingFailure(code=OnionFailureCode.MPP_TIMEOUT, data=b'') else: assert payment_status is True + payment_hash = htlc.payment_hash + preimage = self.lnworker.get_preimage(payment_hash) + hold_invoice_callback = self.lnworker.hold_invoice_callbacks.get(payment_hash) + if not preimage and hold_invoice_callback: + if preimage: + return preimage, None + else: + # for hold invoices, trigger callback + cb, timeout = hold_invoice_callback + if int(time.time()) < timeout: + return None, lambda: cb(payment_hash) + else: + raise OnionRoutingFailure(code=OnionFailureCode.MPP_TIMEOUT, data=b'') + # if there is a trampoline_onion, maybe_fulfill_htlc will be called again if processed_onion.trampoline_onion_packet: # TODO: we should check that all trampoline_onions are the same - return None + + trampoline_onion = self.process_onion_packet( + processed_onion.trampoline_onion_packet, + payment_hash=payment_hash, + onion_packet_bytes=onion_packet_bytes, + is_trampoline=True) + if trampoline_onion.are_we_final: + # trampoline- we are final recipient of HTLC + preimage, cb = self.maybe_fulfill_htlc( + chan=chan, + htlc=htlc, + processed_onion=trampoline_onion, + onion_packet_bytes=onion_packet_bytes, + is_trampoline=True) + assert cb is None + return preimage, None + else: + callback = lambda: self.maybe_forward_trampoline( + payment_hash=payment_hash, + cltv_expiry=htlc.cltv_expiry, # TODO: use max or enforce same value across mpp parts + outer_onion=processed_onion, + trampoline_onion=trampoline_onion) + return None, callback # TODO don't accept payments twice for same invoice # TODO check invoice expiry @@ -1845,7 +1868,7 @@ class Peer(Logger): if preimage: self.logger.info(f"maybe_fulfill_htlc. will FULFILL HTLC: chan {chan.short_channel_id}. htlc={str(htlc)}") self.lnworker.set_request_status(htlc.payment_hash, PR_PAID) - return preimage + return preimage, None def fulfill_htlc(self, chan: Channel, htlc_id: int, preimage: bytes): self.logger.info(f"_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}") @@ -2340,42 +2363,36 @@ class Peer(Logger): onion_packet_bytes=onion_packet_bytes) if processed_onion.are_we_final: # either we are final recipient; or if trampoline, see cases below - preimage = self.maybe_fulfill_htlc( + preimage, forwarding_callback = self.maybe_fulfill_htlc( chan=chan, htlc=htlc, - processed_onion=processed_onion) + processed_onion=processed_onion, + onion_packet_bytes=onion_packet_bytes) - if processed_onion.trampoline_onion_packet: - # trampoline- recipient or forwarding + if forwarding_callback: if not forwarding_info: - trampoline_onion = self.process_onion_packet( - processed_onion.trampoline_onion_packet, - payment_hash=payment_hash, - onion_packet_bytes=onion_packet_bytes, - is_trampoline=True) - if trampoline_onion.are_we_final: - # trampoline- we are final recipient of HTLC - preimage = self.maybe_fulfill_htlc( - chan=chan, - htlc=htlc, - processed_onion=trampoline_onion, - is_trampoline=True) + # trampoline- HTLC we are supposed to forward, but haven't forwarded yet + if not self.lnworker.enable_htlc_forwarding: + pass + elif payment_hash in self.lnworker.trampoline_forwardings: + # we are already forwarding this payment + self.logger.info(f"we are already forwarding this.") else: - # trampoline- HTLC we are supposed to forward, but haven't forwarded yet - if not self.lnworker.enable_htlc_forwarding: - return None, None, None - - if payment_hash in self.lnworker.trampoline_forwardings: - self.logger.info(f"we are already forwarding this.") - # we are already forwarding this payment - return None, True, None - - self.maybe_forward_trampoline( - payment_hash=payment_hash, - cltv_expiry=htlc.cltv_expiry, # TODO: use max or enforce same value across mpp parts - outer_onion=processed_onion, - trampoline_onion=trampoline_onion) - # return True so that this code gets executed only once + # add to list of ongoing payments + self.lnworker.trampoline_forwardings.add(payment_hash) + # clear previous failures + self.lnworker.trampoline_forwarding_failures.pop(payment_hash, None) + async def wrapped_callback(): + forwarding_coro = forwarding_callback() + try: + await forwarding_coro + except Exception as e: + # FIXME: cannot use payment_hash as key + self.lnworker.trampoline_forwarding_failures[payment_hash] = e + finally: + # remove from list of payments, so that another attempt can be initiated + self.lnworker.trampoline_forwardings.remove(payment_hash) + asyncio.ensure_future(wrapped_callback()) return None, True, None else: # trampoline- HTLC we are supposed to forward, and have already forwarded diff --git a/electrum/lnworker.py b/electrum/lnworker.py index fcb142b72..caa4bf852 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -1922,16 +1922,16 @@ class LNWallet(LNWorker): if write_to_disk: self.wallet.save_db() - def check_received_htlc( - self, payment_secret: bytes, - short_channel_id: ShortChannelID, - htlc: UpdateAddHtlc, - expected_msat: int, + + def check_mpp_status( + self, payment_secret: bytes, + short_channel_id: ShortChannelID, + htlc: UpdateAddHtlc, + expected_msat: int, ) -> Optional[bool]: """ return MPP status: True (accepted), False (expired) or None (waiting) """ payment_hash = htlc.payment_hash - self.update_mpp_with_received_htlc(payment_secret, short_channel_id, htlc, expected_msat) is_expired, is_accepted = self.get_mpp_status(payment_secret) if not is_accepted and not is_expired: @@ -1944,19 +1944,7 @@ class LNWallet(LNWorker): elif self.stopping_soon: is_expired = True # try to time out pending HTLCs before shutting down elif all([self.is_mpp_amount_reached(x) for x in payment_secrets]): - preimage = self.get_preimage(payment_hash) - hold_invoice_callback = self.hold_invoice_callbacks.get(payment_hash) - if not preimage and hold_invoice_callback: - # for hold invoices, trigger callback - cb, timeout = hold_invoice_callback - if int(time.time()) < timeout: - cb(payment_hash) - else: - is_expired = True - else: - # note: preimage will be None for outer trampoline onion - is_accepted = True - + is_accepted = True elif time.time() - first_timestamp > self.MPP_EXPIRY: is_expired = True diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index 86c1c0e42..7b8a805d7 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -251,7 +251,7 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): set_request_status = LNWallet.set_request_status set_payment_status = LNWallet.set_payment_status get_payment_status = LNWallet.get_payment_status - check_received_htlc = LNWallet.check_received_htlc + check_mpp_status = LNWallet.check_mpp_status htlc_fulfilled = LNWallet.htlc_fulfilled htlc_failed = LNWallet.htlc_failed save_preimage = LNWallet.save_preimage @@ -764,7 +764,7 @@ class TestPeer(ElectrumTestCase): if test_hold_invoice: payment_hash = lnaddr.paymenthash preimage = bytes.fromhex(w2.preimages.pop(payment_hash.hex())) - def cb(payment_hash): + async def cb(payment_hash): if not test_hold_timeout: w2.save_preimage(payment_hash, preimage) timeout = 1 if test_hold_timeout else 60