diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index 69ab48c02..54c4846c4 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -2605,6 +2605,12 @@ class Peer(Logger, EventListener): return closing_tx.txid() async def htlc_switch(self): + # In this loop, an item of chan.unfulfilled_htlcs may go through 4 stages: + # - 1. not forwarded yet: (None, onion_packet_hex) + # - 2. forwarded: (forwarding_key, onion_packet_hex) + # - 3. processed: (forwarding_key, None), not irrevocably removed yet + # - 4. done: (forwarding_key, None), irrevocably removed + await self.initialized while True: await self.ping_if_required() @@ -2632,6 +2638,16 @@ class Peer(Logger, EventListener): if not chan.hm.is_htlc_irrevocably_added_yet(htlc_proposer=REMOTE, htlc_id=htlc_id): continue htlc = chan.hm.get_htlc_by_id(REMOTE, htlc_id) + if chan.hm.is_htlc_irrevocably_removed_yet(htlc_proposer=REMOTE, htlc_id=htlc_id): + assert onion_packet_hex is None + self.lnworker.maybe_cleanup_mpp(chan.get_scid_or_local_alias(), htlc) + if forwarding_key: + self.lnworker.maybe_cleanup_forwarding(forwarding_key) + done.add(htlc_id) + continue + if onion_packet_hex is None: + # has been processed already + continue error_reason = None # type: Optional[OnionRoutingFailure] error_bytes = None # type: Optional[bytes] preimage = None @@ -2673,7 +2689,9 @@ class Peer(Logger, EventListener): chan=chan, htlc_id=htlc.htlc_id, reason=error_reason) - done.add(htlc_id) + # blank onion field to mark it as processed + unfulfilled[htlc_id] = None, forwarding_key + # cleanup for htlc_id in done: unfulfilled.pop(htlc_id) @@ -2760,7 +2778,6 @@ class Peer(Logger, EventListener): # return payment_key so this branch will not be executed again return None, payment_key, None elif preimage: - self.lnworker.maybe_cleanup_mpp(chan.get_scid_or_local_alias(), htlc) return preimage, None, None else: # we are waiting for mpp consolidation or preimage @@ -2771,11 +2788,6 @@ class Peer(Logger, EventListener): payment_key = forwarding_key preimage = self.lnworker.get_preimage(payment_hash) error_bytes, error_reason = self.lnworker.get_forwarding_failure(payment_key) - if error_bytes or error_reason or preimage: - cleanup_keys = self.lnworker.maybe_cleanup_mpp(chan.get_scid_or_local_alias(), htlc) - is_htlc_key = ':' in payment_key - if is_htlc_key or payment_key in cleanup_keys: - self.lnworker.maybe_cleanup_forwarding(payment_key) if error_bytes: return None, None, error_bytes if error_reason: diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 8b1d36097..5ffc6cc25 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -2348,9 +2348,9 @@ class LNWallet(LNWorker): self, short_channel_id: ShortChannelID, htlc: UpdateAddHtlc, - ) -> Sequence[str]: + ) -> None: + htlc_key = (short_channel_id, htlc) - cleanup_keys = [] for payment_key_hex, mpp_status in list(self.received_mpp_htlcs.items()): if htlc_key not in mpp_status.htlc_set: continue @@ -2360,8 +2360,7 @@ class LNWallet(LNWorker): if len(mpp_status.htlc_set) == 0: self.logger.info(f'maybe_cleanup_mpp: removing mpp {payment_key_hex}') self.received_mpp_htlcs.pop(payment_key_hex) - cleanup_keys.append(payment_key_hex) - return cleanup_keys + self.maybe_cleanup_forwarding(payment_key_hex) def maybe_cleanup_forwarding(self, payment_key_hex: str) -> None: self.active_forwardings.pop(payment_key_hex, None) diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index c674e8897..529d03a5d 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -1879,23 +1879,25 @@ class TestPeerForwarding(TestPeer): test_failure=False, attempts=2): - bob_w = graph.workers['bob'] - carol_w = graph.workers['carol'] + alice_w = graph.workers['alice'] dave_w = graph.workers['dave'] async def pay(lnaddr, pay_req): - self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash)) - result, log = await graph.workers['alice'].pay_invoice(pay_req, attempts=attempts) + self.assertEqual(PR_UNPAID, dave_w.get_payment_status(lnaddr.paymenthash)) + result, log = await alice_w.pay_invoice(pay_req, attempts=attempts) + async with OldTaskGroup() as g: + for peer in peers: + await g.spawn(peer.wait_one_htlc_switch_iteration()) + for peer in peers: + self.assertEqual(len(peer.lnworker.active_forwardings), 0) if result: - self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash)) - self.assertFalse(bool(bob_w.active_forwardings)) - self.assertFalse(bool(carol_w.active_forwardings)) + self.assertEqual(PR_PAID, dave_w.get_payment_status(lnaddr.paymenthash)) raise PaymentDone() else: raise NoPathFound() async def f(): - await self._activate_trampoline(graph.workers['alice']) + await self._activate_trampoline(alice_w) async with OldTaskGroup() as group: for peer in peers: await group.spawn(peer._message_loop())