diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 5e155b49e..811abe1c6 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -578,6 +578,7 @@ class LNWallet(LNWorker): lnwatcher: Optional['LNWalletWatcher'] MPP_EXPIRY = 120 + TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 3 # seconds def __init__(self, wallet: 'Abstract_Wallet', xprv): self.wallet = wallet @@ -713,11 +714,12 @@ class LNWallet(LNWorker): self.stopping_soon = True if self.listen_server: # stop accepting new peers self.listen_server.close() - async with ignore_after(3): + async with ignore_after(self.TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS): await self.wait_for_received_pending_htlcs_to_get_removed() - await super().stop() - await self.lnwatcher.stop() - self.lnwatcher = None + await LNWorker.stop(self) + if self.lnwatcher: + await self.lnwatcher.stop() + self.lnwatcher = None async def wait_for_received_pending_htlcs_to_get_removed(self): assert self.stopping_soon is True diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index 2c0005bd5..f7386fbbc 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -10,7 +10,7 @@ from concurrent import futures import unittest from typing import Iterable, NamedTuple, Tuple, List -from aiorpcx import TaskGroup +from aiorpcx import TaskGroup, timeout_after, TaskTimeout from electrum import bitcoin from electrum import constants @@ -113,7 +113,8 @@ class MockWallet: class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]): - MPP_EXPIRY = 2 # HTLC timestamps are cast to int, so this cannot be 1 + MPP_EXPIRY = 2 # HTLC timestamps are cast to int, so this cannot be 1 + TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 0 def __init__(self, *, local_keypair: Keypair, chans: Iterable['Channel'], tx_queue, name): self.name = name @@ -121,6 +122,9 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]): NetworkRetryManager.__init__(self, max_retry_delay_normal=1, init_retry_delay_normal=1) self.node_keypair = local_keypair self.network = MockNetwork(tx_queue) + self.taskgroup = TaskGroup() + self.lnwatcher = None + self.listen_server = None self._channels = {chan.channel_id: chan for chan in chans} self.payments = {} self.logs = defaultdict(list) @@ -184,6 +188,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]): return self.name async def stop(self): + await LNWallet.stop(self) if self.channel_db: self.channel_db.stop() await self.channel_db.stopped_event.wait() @@ -216,6 +221,8 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]): _calc_routing_hints_for_invoice = LNWallet._calc_routing_hints_for_invoice handle_error_code_from_failed_htlc = LNWallet.handle_error_code_from_failed_htlc is_trampoline_peer = LNWallet.is_trampoline_peer + wait_for_received_pending_htlcs_to_get_removed = LNWallet.wait_for_received_pending_htlcs_to_get_removed + on_proxy_changed = LNWallet.on_proxy_changed class MockTransport: @@ -291,13 +298,9 @@ class SquareGraph(NamedTuple): def all_lnworkers(self) -> Iterable[MockLNWallet]: return self.w_a, self.w_b, self.w_c, self.w_d - async def stop_and_cleanup(self): - async with TaskGroup() as group: - for lnworker in self.all_lnworkers(): - await group.spawn(lnworker.stop()) - class PaymentDone(Exception): pass +class TestSuccess(Exception): pass class TestPeer(ElectrumTestCase): @@ -837,6 +840,50 @@ class TestPeer(ElectrumTestCase): graph = self.prepare_chans_and_peers_in_square() self._run_mpp(graph, {'alice_uses_trampoline':True, 'attempts':1}, {'alice_uses_trampoline':True, 'attempts':3}) + @needs_test_with_all_chacha20_implementations + def test_fail_pending_htlcs_on_shutdown(self): + """Alice tries to pay Dave via MPP. Dave receives some HTLCs but not all. + Dave shuts down (stops wallet). + We test if Dave fails the pending HTLCs during shutdown. + """ + graph = self.prepare_chans_and_peers_in_square() + self.assertEqual(500_000_000_000, graph.chan_ab.balance(LOCAL)) + self.assertEqual(500_000_000_000, graph.chan_ac.balance(LOCAL)) + amount_to_pay = 600_000_000_000 + peers = graph.all_peers() + graph.w_d.MPP_EXPIRY = 120 + graph.w_d.TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 3 + async def pay(): + graph.w_d.features |= LnFeatures.BASIC_MPP_OPT + graph.w_b.enable_htlc_forwarding.clear() # Bob will hold forwarded HTLCs + assert graph.w_a.network.channel_db is not None + lnaddr, pay_req = await self.prepare_invoice(graph.w_d, include_routing_hints=True, amount_msat=amount_to_pay) + try: + async with timeout_after(0.5): + result, log = await graph.w_a.pay_invoice(pay_req, attempts=1) + except TaskTimeout: + # by now Dave hopefully received some HTLCs: + self.assertTrue(len(graph.chan_dc.hm.htlcs(LOCAL)) > 0) + self.assertTrue(len(graph.chan_dc.hm.htlcs(REMOTE)) > 0) + else: + self.fail(f"pay_invoice finished but was not supposed to. result={result}") + await graph.w_d.stop() + # Dave is supposed to have failed the pending incomplete MPP HTLCs + self.assertEqual(0, len(graph.chan_dc.hm.htlcs(LOCAL))) + self.assertEqual(0, len(graph.chan_dc.hm.htlcs(REMOTE))) + raise TestSuccess() + + async def f(): + async with TaskGroup() as group: + for peer in peers: + await group.spawn(peer._message_loop()) + await group.spawn(peer.htlc_switch()) + await asyncio.sleep(0.2) + await group.spawn(pay()) + + with self.assertRaises(TestSuccess): + run(f()) + @needs_test_with_all_chacha20_implementations def test_close(self): alice_channel, bob_channel = create_test_channels()