diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index a383e9d73..6a77aab7f 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -721,7 +721,7 @@ class TestPeer(ElectrumTestCase): with self.assertRaises(SuccessfulTest): await f() - async def _test_simple_payment(self, trampoline: bool): + async def _test_simple_payment(self, trampoline: bool, test_hold_invoice=False, test_timeout=False): """Alice pays Bob a single HTLC via direct channel.""" alice_channel, bob_channel = create_test_channels() p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) @@ -738,7 +738,17 @@ class TestPeer(ElectrumTestCase): raise PaymentDone() else: raise PaymentFailure() - async def f(test_hold_invoice=False, test_timeout=False): + lnaddr, pay_req = self.prepare_invoice(w2) + if test_hold_invoice: + payment_hash = lnaddr.paymenthash + preimage = bytes.fromhex(w2.preimages.pop(payment_hash.hex())) + def cb(payment_hash): + if not test_timeout: + w2.save_preimage(payment_hash, preimage) + timeout = 1 if test_timeout else 60 + w2.register_callback_for_hold_invoice(payment_hash, cb, timeout) + + async def f(): if trampoline: await turn_on_trampoline_alice() async with OldTaskGroup() as group: @@ -747,15 +757,6 @@ class TestPeer(ElectrumTestCase): await group.spawn(p2._message_loop()) await group.spawn(p2.htlc_switch()) await asyncio.sleep(0.01) - lnaddr, pay_req = self.prepare_invoice(w2) - if test_hold_invoice: - payment_hash = lnaddr.paymenthash - preimage = bytes.fromhex(w2.preimages.pop(payment_hash.hex())) - def cb(payment_hash): - if not test_timeout: - w2.save_preimage(payment_hash, preimage) - timeout = 1 if test_timeout else 60 - w2.register_callback_for_hold_invoice(payment_hash, cb, timeout) invoice_features = lnaddr.get_features() self.assertFalse(invoice_features.supports(LnFeatures.BASIC_MPP_OPT)) await group.spawn(pay(lnaddr, pay_req)) @@ -763,20 +764,23 @@ class TestPeer(ElectrumTestCase): electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = { 'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=w2.node_keypair.pubkey), } - with self.assertRaises(PaymentDone): - await f(test_hold_invoice=False) - with self.assertRaises(PaymentDone): - await f(test_hold_invoice=True, test_timeout=False) - with self.assertRaises(PaymentFailure): - await f(test_hold_invoice=True, test_timeout=True) + await f() @needs_test_with_all_chacha20_implementations async def test_simple_payment(self): - await self._test_simple_payment(trampoline=False) + for trampoline in [False, True]: + with self.assertRaises(PaymentDone): + await self._test_simple_payment(trampoline=trampoline) - @needs_test_with_all_chacha20_implementations - async def test_simple_payment_trampoline(self): - await self._test_simple_payment(trampoline=True) + async def test_simple_payment_with_hold_invoice(self): + for trampoline in [False, True]: + with self.assertRaises(PaymentDone): + await self._test_simple_payment(trampoline=trampoline, test_hold_invoice=True) + + async def test_simple_payment_with_hold_invoice_timing_out(self): + for trampoline in [False, True]: + with self.assertRaises(PaymentFailure): + await self._test_simple_payment(trampoline=trampoline, test_hold_invoice=True, test_timeout=True) @needs_test_with_all_chacha20_implementations async def test_payment_race(self):