From bf86cd67616ac78573e8d4af2adc657a38a7f2d4 Mon Sep 17 00:00:00 2001 From: ThomasV Date: Wed, 9 Aug 2023 11:44:41 +0200 Subject: [PATCH] lnpeer and lnworker cleanup: - rename trampoline_forwardings -> final_onion_forwardings, because this dict is used for both trampoline and hold invoices - remove timeout from hold_invoice_callbacks (redundant with invoice) - add test_failure boolean parameter to TestPeer._test_simple_payment, in order to test correct propagation of OnionRoutingFailures. - maybe_fulfill_htlc: raise an OnionRoutingFailure if we do not have the preimage for a payment that does not have a hold invoice callback. Without this, the above unit tests stall when we use test_failure=True --- electrum/lnpeer.py | 23 +++++++++-------------- electrum/lnworker.py | 15 ++++++--------- electrum/submarine_swaps.py | 4 ++-- electrum/tests/test_lnpeer.py | 35 +++++++++++++++++++++++++---------- 4 files changed, 42 insertions(+), 35 deletions(-) diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index b5f7843ff..b3f058cc5 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -1887,12 +1887,7 @@ class Peer(Logger): if preimage: return preimage, None else: - # for hold invoices, trigger callback - cb, timeout = hold_invoice_callback - if int(time.time()) < timeout: - return None, lambda: cb(payment_hash) - else: - raise exc_incorrect_or_unknown_pd + return None, lambda: hold_invoice_callback(payment_hash) # TODO don't accept payments twice for same invoice # TODO check invoice expiry @@ -1903,8 +1898,8 @@ class Peer(Logger): preimage = self.lnworker.get_preimage(payment_hash) if not preimage: - self.logger.info(f"missing callback {payment_hash.hex()}") - return None, None + self.logger.info(f"missing preimage and no hold invoice callback {payment_hash.hex()}") + raise exc_incorrect_or_unknown_pd expected_payment_secrets = [self.lnworker.get_payment_secret(htlc.payment_hash)] expected_payment_secrets.append(derive_payment_secret_from_payment_preimage(preimage)) # legacy secret for old invoices @@ -2424,23 +2419,23 @@ class Peer(Logger): # trampoline- HTLC we are supposed to forward, but haven't forwarded yet if not self.lnworker.enable_htlc_forwarding: pass - elif payment_key in self.lnworker.trampoline_forwardings: + elif payment_key in self.lnworker.final_onion_forwardings: # we are already forwarding this payment self.logger.info(f"we are already forwarding this.") else: # add to list of ongoing payments - self.lnworker.trampoline_forwardings.add(payment_key) + self.lnworker.final_onion_forwardings.add(payment_key) # clear previous failures - self.lnworker.trampoline_forwarding_failures.pop(payment_key, None) + self.lnworker.final_onion_forwarding_failures.pop(payment_key, None) async def wrapped_callback(): forwarding_coro = forwarding_callback() try: await forwarding_coro except OnionRoutingFailure as e: - self.lnworker.trampoline_forwarding_failures[payment_key] = e + self.lnworker.final_onion_forwarding_failures[payment_key] = e finally: # remove from list of payments, so that another attempt can be initiated - self.lnworker.trampoline_forwardings.remove(payment_key) + self.lnworker.final_onion_forwardings.remove(payment_key) asyncio.ensure_future(wrapped_callback()) fw_info = payment_key.hex() return None, fw_info, None @@ -2449,7 +2444,7 @@ class Peer(Logger): payment_key = bytes.fromhex(forwarding_info) preimage = self.lnworker.get_preimage(payment_hash) # get (and not pop) failure because the incoming payment might be multi-part - error_reason = self.lnworker.trampoline_forwarding_failures.get(payment_key) + error_reason = self.lnworker.final_onion_forwarding_failures.get(payment_key) if error_reason: self.logger.info(f'trampoline forwarding failure: {error_reason.code_name()}') raise error_reason diff --git a/electrum/lnworker.py b/electrum/lnworker.py index eb33114c3..44170d4d4 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -703,8 +703,8 @@ class LNWallet(LNWorker): for payment_hash in self.get_payments(status='inflight').keys(): self.set_invoice_status(payment_hash.hex(), PR_INFLIGHT) - self.trampoline_forwardings = set() - self.trampoline_forwarding_failures = {} # todo: should be persisted + self.final_onion_forwardings = set() + self.final_onion_forwarding_failures = {} # todo: should be persisted # map forwarded htlcs (fw_info=(scid_hex, htlc_id)) to originating peer pubkeys self.downstream_htlc_to_upstream_peer_map = {} # type: Dict[Tuple[str, int], bytes] # payment_hash -> callback, timeout: @@ -1954,11 +1954,8 @@ class LNWallet(LNWorker): info = PaymentInfo(payment_hash, lightning_amount_sat * 1000, RECEIVED, PR_UNPAID) self.save_payment_info(info, write_to_disk=False) - def register_callback_for_hold_invoice( - self, payment_hash: bytes, cb: Callable[[bytes], None], timeout: int, - ): - expiry = int(time.time()) + timeout - self.hold_invoice_callbacks[payment_hash] = cb, expiry + def register_callback_for_hold_invoice(self, payment_hash: bytes, cb: Callable[[bytes], None]): + self.hold_invoice_callbacks[payment_hash] = cb def save_payment_info(self, info: PaymentInfo, *, write_to_disk: bool = True) -> None: key = info.payment_hash.hex() @@ -2758,7 +2755,7 @@ class LNWallet(LNWorker): util.trigger_callback('channels_updated', self.wallet) self.lnwatcher.add_channel(cb.funding_outpoint.to_str(), cb.get_funding_address()) - def fail_trampoline_forwarding(self, payment_key): + def fail_final_onion_forwarding(self, payment_key): """ use this to fail htlcs received for hold invoices""" e = OnionRoutingFailure(code=OnionFailureCode.UNKNOWN_NEXT_PEER, data=b'') - self.trampoline_forwarding_failures[payment_key] = e + self.final_onion_forwarding_failures[payment_key] = e diff --git a/electrum/submarine_swaps.py b/electrum/submarine_swaps.py index 5d40a72fb..17ea412e2 100644 --- a/electrum/submarine_swaps.py +++ b/electrum/submarine_swaps.py @@ -247,7 +247,7 @@ class SwapManager(Logger): self.logger.info(f'found confirmed refund') payment_secret = self.lnworker.get_payment_secret(swap.payment_hash) payment_key = swap.payment_hash + payment_secret - self.lnworker.fail_trampoline_forwarding(payment_key) + self.lnworker.fail_final_onion_forwarding(payment_key) if delta < 0: # too early for refund @@ -343,7 +343,7 @@ class SwapManager(Logger): ) # add payment info to lnworker self.lnworker.add_payment_info_for_hold_invoice(payment_hash, main_amount_sat) - self.lnworker.register_callback_for_hold_invoice(payment_hash, self.hold_invoice_callback, 60*60*24) + self.lnworker.register_callback_for_hold_invoice(payment_hash, self.hold_invoice_callback) prepay_hash = self.lnworker.create_payment_info(amount_msat=prepay_amount_sat*1000) _, prepay_invoice = self.lnworker.get_bolt11_invoice( payment_hash=prepay_hash, diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index 0bd463fcc..544e0d9ed 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -36,7 +36,7 @@ from electrum.lnmsg import encode_msg, decode_msg from electrum import lnmsg from electrum.logging import console_stderr_handler, Logger from electrum.lnworker import PaymentInfo, RECEIVED -from electrum.lnonion import OnionFailureCode +from electrum.lnonion import OnionFailureCode, OnionRoutingFailure from electrum.lnutil import UpdateAddHtlc from electrum.lnutil import LOCAL, REMOTE from electrum.invoices import PR_PAID, PR_UNPAID @@ -169,8 +169,8 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): self.sent_htlcs_q = defaultdict(asyncio.Queue) self.sent_htlcs_info = dict() self.sent_buckets = defaultdict(set) - self.trampoline_forwardings = set() - self.trampoline_forwarding_failures = {} + self.final_onion_forwardings = set() + self.final_onion_forwarding_failures = {} self.inflight_payments = set() self.preimages = {} self.stopping_soon = False @@ -749,6 +749,7 @@ class TestPeer(ElectrumTestCase): async def _test_simple_payment( self, test_trampoline: bool, + test_failure:bool=False, test_hold_invoice=False, test_bundle=False, test_bundle_timeout=False @@ -765,12 +766,16 @@ class TestPeer(ElectrumTestCase): else: raise PaymentFailure() lnaddr, pay_req = self.prepare_invoice(w2) - if test_hold_invoice: - payment_hash = lnaddr.paymenthash + payment_hash = lnaddr.paymenthash + if test_failure or test_hold_invoice: preimage = bytes.fromhex(w2.preimages.pop(payment_hash.hex())) - async def cb(payment_hash): - w2.save_preimage(payment_hash, preimage) - w2.register_callback_for_hold_invoice(payment_hash, cb, 60) + if test_hold_invoice: + async def cb(payment_hash): + if not test_failure: + w2.save_preimage(payment_hash, preimage) + else: + raise OnionRoutingFailure(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'') + w2.register_callback_for_hold_invoice(payment_hash, cb) if test_bundle: lnaddr2, pay_req2 = self.prepare_invoice(w2) @@ -799,11 +804,16 @@ class TestPeer(ElectrumTestCase): await f() @needs_test_with_all_chacha20_implementations - async def test_simple_payment(self): + async def test_simple_payment_success(self): for test_trampoline in [False, True]: with self.assertRaises(PaymentDone): await self._test_simple_payment(test_trampoline=test_trampoline) + async def test_simple_payment_failure(self): + for test_trampoline in [False, True]: + with self.assertRaises(PaymentFailure): + await self._test_simple_payment(test_trampoline=test_trampoline, test_failure=True) + async def test_payment_bundle(self): for test_trampoline in [False, True]: with self.assertRaises(PaymentDone): @@ -814,11 +824,16 @@ class TestPeer(ElectrumTestCase): with self.assertRaises(PaymentFailure): await self._test_simple_payment(test_trampoline=test_trampoline, test_bundle=True, test_bundle_timeout=True) - async def test_simple_payment_with_hold_invoice(self): + async def test_simple_payment_success_with_hold_invoice(self): for test_trampoline in [False, True]: with self.assertRaises(PaymentDone): await self._test_simple_payment(test_trampoline=test_trampoline, test_hold_invoice=True) + async def test_simple_payment_failure_with_hold_invoice(self): + for test_trampoline in [False, True]: + with self.assertRaises(PaymentFailure): + await self._test_simple_payment(test_trampoline=test_trampoline, test_hold_invoice=True, test_failure=True) + @needs_test_with_all_chacha20_implementations async def test_payment_race(self): """Alice and Bob pay each other simultaneously.