Browse Source

test_lnpeer: refactor tests for hold invoices

master
ThomasV 2 years ago
parent
commit
14efb401d6
  1. 46
      electrum/tests/test_lnpeer.py

46
electrum/tests/test_lnpeer.py

@ -721,7 +721,7 @@ class TestPeer(ElectrumTestCase):
with self.assertRaises(SuccessfulTest):
await f()
async def _test_simple_payment(self, trampoline: bool):
async def _test_simple_payment(self, trampoline: bool, test_hold_invoice=False, test_timeout=False):
"""Alice pays Bob a single HTLC via direct channel."""
alice_channel, bob_channel = create_test_channels()
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
@ -738,7 +738,17 @@ class TestPeer(ElectrumTestCase):
raise PaymentDone()
else:
raise PaymentFailure()
async def f(test_hold_invoice=False, test_timeout=False):
lnaddr, pay_req = self.prepare_invoice(w2)
if test_hold_invoice:
payment_hash = lnaddr.paymenthash
preimage = bytes.fromhex(w2.preimages.pop(payment_hash.hex()))
def cb(payment_hash):
if not test_timeout:
w2.save_preimage(payment_hash, preimage)
timeout = 1 if test_timeout else 60
w2.register_callback_for_hold_invoice(payment_hash, cb, timeout)
async def f():
if trampoline:
await turn_on_trampoline_alice()
async with OldTaskGroup() as group:
@ -747,15 +757,6 @@ class TestPeer(ElectrumTestCase):
await group.spawn(p2._message_loop())
await group.spawn(p2.htlc_switch())
await asyncio.sleep(0.01)
lnaddr, pay_req = self.prepare_invoice(w2)
if test_hold_invoice:
payment_hash = lnaddr.paymenthash
preimage = bytes.fromhex(w2.preimages.pop(payment_hash.hex()))
def cb(payment_hash):
if not test_timeout:
w2.save_preimage(payment_hash, preimage)
timeout = 1 if test_timeout else 60
w2.register_callback_for_hold_invoice(payment_hash, cb, timeout)
invoice_features = lnaddr.get_features()
self.assertFalse(invoice_features.supports(LnFeatures.BASIC_MPP_OPT))
await group.spawn(pay(lnaddr, pay_req))
@ -763,20 +764,23 @@ class TestPeer(ElectrumTestCase):
electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = {
'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=w2.node_keypair.pubkey),
}
with self.assertRaises(PaymentDone):
await f(test_hold_invoice=False)
with self.assertRaises(PaymentDone):
await f(test_hold_invoice=True, test_timeout=False)
with self.assertRaises(PaymentFailure):
await f(test_hold_invoice=True, test_timeout=True)
await f()
@needs_test_with_all_chacha20_implementations
async def test_simple_payment(self):
await self._test_simple_payment(trampoline=False)
for trampoline in [False, True]:
with self.assertRaises(PaymentDone):
await self._test_simple_payment(trampoline=trampoline)
@needs_test_with_all_chacha20_implementations
async def test_simple_payment_trampoline(self):
await self._test_simple_payment(trampoline=True)
async def test_simple_payment_with_hold_invoice(self):
for trampoline in [False, True]:
with self.assertRaises(PaymentDone):
await self._test_simple_payment(trampoline=trampoline, test_hold_invoice=True)
async def test_simple_payment_with_hold_invoice_timing_out(self):
for trampoline in [False, True]:
with self.assertRaises(PaymentFailure):
await self._test_simple_payment(trampoline=trampoline, test_hold_invoice=True, test_timeout=True)
@needs_test_with_all_chacha20_implementations
async def test_payment_race(self):

Loading…
Cancel
Save