diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index 4126b1caf..9224d4923 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -500,6 +500,18 @@ class TestPeer(ElectrumTestCase): await w.network.channel_db.stopped_event.wait() w.network.channel_db = None + def prepare_recipient(self, w2, payment_hash, test_hold_invoice, test_failure): + if not test_hold_invoice and not test_failure: + return + preimage = bytes.fromhex(w2.preimages.pop(payment_hash.hex())) + if test_hold_invoice: + async def cb(payment_hash): + if not test_failure: + w2.save_preimage(payment_hash, preimage) + else: + raise OnionRoutingFailure(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'') + w2.register_hold_invoice(payment_hash, cb) + class TestPeerDirect(TestPeer): @@ -700,8 +712,8 @@ class TestPeerDirect(TestPeer): async def _test_simple_payment( self, test_trampoline: bool, - test_failure:bool=False, test_hold_invoice=False, + test_failure=False, test_bundle=False, test_bundle_timeout=False ): @@ -717,16 +729,7 @@ class TestPeerDirect(TestPeer): else: raise PaymentFailure() lnaddr, pay_req = self.prepare_invoice(w2) - payment_hash = lnaddr.paymenthash - if test_failure or test_hold_invoice: - preimage = bytes.fromhex(w2.preimages.pop(payment_hash.hex())) - if test_hold_invoice: - async def cb(payment_hash): - if not test_failure: - w2.save_preimage(payment_hash, preimage) - else: - raise OnionRoutingFailure(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'') - w2.register_hold_invoice(payment_hash, cb) + self.prepare_recipient(w2, lnaddr.paymenthash, test_hold_invoice, test_failure) if test_bundle: lnaddr2, pay_req2 = self.prepare_invoice(w2) @@ -1681,26 +1684,32 @@ class TestPeerForwarding(TestPeer): bob_forwarding=True, mpp_invoice=True, disable_trampoline_receiving=False, + test_hold_invoice=False, + test_failure=False, ): + alice_w = graph.workers['alice'] + bob_w = graph.workers['bob'] + dave_w = graph.workers['dave'] if mpp_invoice: - graph.workers['dave'].features |= LnFeatures.BASIC_MPP_OPT + dave_w.features |= LnFeatures.BASIC_MPP_OPT if disable_trampoline_receiving: - graph.workers['dave'].features &= ~LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ELECTRUM + dave_w.features &= ~LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ELECTRUM if not bob_forwarding: - graph.workers['bob'].enable_htlc_forwarding = False + bob_w.enable_htlc_forwarding = False if alice_uses_trampoline: - await self._activate_trampoline(graph.workers['alice']) + await self._activate_trampoline(alice_w) else: - assert graph.workers['alice'].network.channel_db is not None - lnaddr, pay_req = self.prepare_invoice(graph.workers['dave'], include_routing_hints=True, amount_msat=amount_to_pay) - self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash)) - result, log = await graph.workers['alice'].pay_invoice(pay_req, attempts=attempts) + assert alice_w.network.channel_db is not None + lnaddr, pay_req = self.prepare_invoice(dave_w, include_routing_hints=True, amount_msat=amount_to_pay) + self.prepare_recipient(dave_w, lnaddr.paymenthash, test_hold_invoice, test_failure) + self.assertEqual(PR_UNPAID, dave_w.get_payment_status(lnaddr.paymenthash)) + result, log = await alice_w.pay_invoice(pay_req, attempts=attempts) if not bob_forwarding: # reset to previous state, sleep 2s so that the second htlc can time out graph.workers['bob'].enable_htlc_forwarding = True await asyncio.sleep(2) if result: - self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash)) + self.assertEqual(PR_PAID, dave_w.get_payment_status(lnaddr.paymenthash)) raise PaymentDone() elif len(log) == 1 and log[0].failure_msg.code == OnionFailureCode.MPP_TIMEOUT: raise PaymentTimeout() @@ -1720,6 +1729,11 @@ class TestPeerForwarding(TestPeer): with self.assertRaises(PaymentDone): await self._run_mpp(graph, {}) + async def test_payment_multipart_with_hold_invoice(self): + graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['square_graph']) + with self.assertRaises(PaymentDone): + await self._run_mpp(graph, {'test_hold_invoice': True}) + async def test_payment_multipart_with_timeout(self): graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['square_graph']) with self.assertRaises(PaymentTimeout): @@ -1801,8 +1815,12 @@ class TestPeerForwarding(TestPeer): async def _run_trampoline_payment( self, graph, *, include_routing_hints=True, + test_hold_invoice=False, + test_failure=False, attempts=2): + dave_w = graph.workers['dave'] + async def pay(lnaddr, pay_req): self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash)) result, log = await graph.workers['alice'].pay_invoice(pay_req, attempts=attempts) @@ -1820,7 +1838,8 @@ class TestPeerForwarding(TestPeer): await group.spawn(peer.htlc_switch()) for peer in peers: await peer.initialized - lnaddr, pay_req = self.prepare_invoice(graph.workers['dave'], include_routing_hints=include_routing_hints) + lnaddr, pay_req = self.prepare_invoice(dave_w, include_routing_hints=include_routing_hints) + self.prepare_recipient(dave_w, lnaddr.paymenthash, test_hold_invoice, test_failure) await group.spawn(pay(lnaddr, pay_req)) peers = graph.peers.values() @@ -1859,6 +1878,16 @@ class TestPeerForwarding(TestPeer): graph = self.create_square_graph(direct=False, test_mpp_consolidation=True, is_legacy=True) await self._run_trampoline_payment(graph) + async def test_trampoline_mpp_consolidation_with_hold_invoice(self): + with self.assertRaises(PaymentDone): + graph = self.create_square_graph(direct=False, test_mpp_consolidation=True, is_legacy=True) + await self._run_trampoline_payment(graph, test_hold_invoice=True) + + async def test_trampoline_mpp_consolidation_with_hold_invoice_failure(self): + with self.assertRaises(NoPathFound): + graph = self.create_square_graph(direct=False, test_mpp_consolidation=True, is_legacy=True) + await self._run_trampoline_payment(graph, test_hold_invoice=True, test_failure=True) + async def test_payment_trampoline_legacy(self): with self.assertRaises(PaymentDone): graph = self.create_square_graph(direct=False, is_legacy=True)