Browse Source

test_lnpeer: refactor tests for hold invoices

master
ThomasV 3 years ago
parent
commit
14efb401d6
  1. 44
      electrum/tests/test_lnpeer.py

44
electrum/tests/test_lnpeer.py

@ -721,7 +721,7 @@ class TestPeer(ElectrumTestCase):
with self.assertRaises(SuccessfulTest):
await f()
async def _test_simple_payment(self, trampoline: bool):
async def _test_simple_payment(self, trampoline: bool, test_hold_invoice=False, test_timeout=False):
"""Alice pays Bob a single HTLC via direct channel."""
alice_channel, bob_channel = create_test_channels()
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
@ -738,15 +738,6 @@ class TestPeer(ElectrumTestCase):
raise PaymentDone()
else:
raise PaymentFailure()
async def f(test_hold_invoice=False, test_timeout=False):
if trampoline:
await turn_on_trampoline_alice()
async with OldTaskGroup() as group:
await group.spawn(p1._message_loop())
await group.spawn(p1.htlc_switch())
await group.spawn(p2._message_loop())
await group.spawn(p2.htlc_switch())
await asyncio.sleep(0.01)
lnaddr, pay_req = self.prepare_invoice(w2)
if test_hold_invoice:
payment_hash = lnaddr.paymenthash
@ -756,6 +747,16 @@ class TestPeer(ElectrumTestCase):
w2.save_preimage(payment_hash, preimage)
timeout = 1 if test_timeout else 60
w2.register_callback_for_hold_invoice(payment_hash, cb, timeout)
async def f():
if trampoline:
await turn_on_trampoline_alice()
async with OldTaskGroup() as group:
await group.spawn(p1._message_loop())
await group.spawn(p1.htlc_switch())
await group.spawn(p2._message_loop())
await group.spawn(p2.htlc_switch())
await asyncio.sleep(0.01)
invoice_features = lnaddr.get_features()
self.assertFalse(invoice_features.supports(LnFeatures.BASIC_MPP_OPT))
await group.spawn(pay(lnaddr, pay_req))
@ -763,20 +764,23 @@ class TestPeer(ElectrumTestCase):
electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = {
'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=w2.node_keypair.pubkey),
}
with self.assertRaises(PaymentDone):
await f(test_hold_invoice=False)
with self.assertRaises(PaymentDone):
await f(test_hold_invoice=True, test_timeout=False)
with self.assertRaises(PaymentFailure):
await f(test_hold_invoice=True, test_timeout=True)
await f()
@needs_test_with_all_chacha20_implementations
async def test_simple_payment(self):
await self._test_simple_payment(trampoline=False)
for trampoline in [False, True]:
with self.assertRaises(PaymentDone):
await self._test_simple_payment(trampoline=trampoline)
@needs_test_with_all_chacha20_implementations
async def test_simple_payment_trampoline(self):
await self._test_simple_payment(trampoline=True)
async def test_simple_payment_with_hold_invoice(self):
for trampoline in [False, True]:
with self.assertRaises(PaymentDone):
await self._test_simple_payment(trampoline=trampoline, test_hold_invoice=True)
async def test_simple_payment_with_hold_invoice_timing_out(self):
for trampoline in [False, True]:
with self.assertRaises(PaymentFailure):
await self._test_simple_payment(trampoline=trampoline, test_hold_invoice=True, test_timeout=True)
@needs_test_with_all_chacha20_implementations
async def test_payment_race(self):

Loading…
Cancel
Save