Browse Source

test_lnpeer: add more tests, combining forwarding and hold invoices

master
ThomasV 2 years ago
parent
commit
4a6c55ef7b
  1. 71
      electrum/tests/test_lnpeer.py

71
electrum/tests/test_lnpeer.py

@ -500,6 +500,18 @@ class TestPeer(ElectrumTestCase):
await w.network.channel_db.stopped_event.wait()
w.network.channel_db = None
def prepare_recipient(self, w2, payment_hash, test_hold_invoice, test_failure):
if not test_hold_invoice and not test_failure:
return
preimage = bytes.fromhex(w2.preimages.pop(payment_hash.hex()))
if test_hold_invoice:
async def cb(payment_hash):
if not test_failure:
w2.save_preimage(payment_hash, preimage)
else:
raise OnionRoutingFailure(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'')
w2.register_hold_invoice(payment_hash, cb)
class TestPeerDirect(TestPeer):
@ -700,8 +712,8 @@ class TestPeerDirect(TestPeer):
async def _test_simple_payment(
self,
test_trampoline: bool,
test_failure:bool=False,
test_hold_invoice=False,
test_failure=False,
test_bundle=False,
test_bundle_timeout=False
):
@ -717,16 +729,7 @@ class TestPeerDirect(TestPeer):
else:
raise PaymentFailure()
lnaddr, pay_req = self.prepare_invoice(w2)
payment_hash = lnaddr.paymenthash
if test_failure or test_hold_invoice:
preimage = bytes.fromhex(w2.preimages.pop(payment_hash.hex()))
if test_hold_invoice:
async def cb(payment_hash):
if not test_failure:
w2.save_preimage(payment_hash, preimage)
else:
raise OnionRoutingFailure(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'')
w2.register_hold_invoice(payment_hash, cb)
self.prepare_recipient(w2, lnaddr.paymenthash, test_hold_invoice, test_failure)
if test_bundle:
lnaddr2, pay_req2 = self.prepare_invoice(w2)
@ -1681,26 +1684,32 @@ class TestPeerForwarding(TestPeer):
bob_forwarding=True,
mpp_invoice=True,
disable_trampoline_receiving=False,
test_hold_invoice=False,
test_failure=False,
):
alice_w = graph.workers['alice']
bob_w = graph.workers['bob']
dave_w = graph.workers['dave']
if mpp_invoice:
graph.workers['dave'].features |= LnFeatures.BASIC_MPP_OPT
dave_w.features |= LnFeatures.BASIC_MPP_OPT
if disable_trampoline_receiving:
graph.workers['dave'].features &= ~LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ELECTRUM
dave_w.features &= ~LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ELECTRUM
if not bob_forwarding:
graph.workers['bob'].enable_htlc_forwarding = False
bob_w.enable_htlc_forwarding = False
if alice_uses_trampoline:
await self._activate_trampoline(graph.workers['alice'])
await self._activate_trampoline(alice_w)
else:
assert graph.workers['alice'].network.channel_db is not None
lnaddr, pay_req = self.prepare_invoice(graph.workers['dave'], include_routing_hints=True, amount_msat=amount_to_pay)
self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash))
result, log = await graph.workers['alice'].pay_invoice(pay_req, attempts=attempts)
assert alice_w.network.channel_db is not None
lnaddr, pay_req = self.prepare_invoice(dave_w, include_routing_hints=True, amount_msat=amount_to_pay)
self.prepare_recipient(dave_w, lnaddr.paymenthash, test_hold_invoice, test_failure)
self.assertEqual(PR_UNPAID, dave_w.get_payment_status(lnaddr.paymenthash))
result, log = await alice_w.pay_invoice(pay_req, attempts=attempts)
if not bob_forwarding:
# reset to previous state, sleep 2s so that the second htlc can time out
graph.workers['bob'].enable_htlc_forwarding = True
await asyncio.sleep(2)
if result:
self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash))
self.assertEqual(PR_PAID, dave_w.get_payment_status(lnaddr.paymenthash))
raise PaymentDone()
elif len(log) == 1 and log[0].failure_msg.code == OnionFailureCode.MPP_TIMEOUT:
raise PaymentTimeout()
@ -1720,6 +1729,11 @@ class TestPeerForwarding(TestPeer):
with self.assertRaises(PaymentDone):
await self._run_mpp(graph, {})
async def test_payment_multipart_with_hold_invoice(self):
graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['square_graph'])
with self.assertRaises(PaymentDone):
await self._run_mpp(graph, {'test_hold_invoice': True})
async def test_payment_multipart_with_timeout(self):
graph = self.prepare_chans_and_peers_in_graph(self.GRAPH_DEFINITIONS['square_graph'])
with self.assertRaises(PaymentTimeout):
@ -1801,8 +1815,12 @@ class TestPeerForwarding(TestPeer):
async def _run_trampoline_payment(
self, graph, *,
include_routing_hints=True,
test_hold_invoice=False,
test_failure=False,
attempts=2):
dave_w = graph.workers['dave']
async def pay(lnaddr, pay_req):
self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash))
result, log = await graph.workers['alice'].pay_invoice(pay_req, attempts=attempts)
@ -1820,7 +1838,8 @@ class TestPeerForwarding(TestPeer):
await group.spawn(peer.htlc_switch())
for peer in peers:
await peer.initialized
lnaddr, pay_req = self.prepare_invoice(graph.workers['dave'], include_routing_hints=include_routing_hints)
lnaddr, pay_req = self.prepare_invoice(dave_w, include_routing_hints=include_routing_hints)
self.prepare_recipient(dave_w, lnaddr.paymenthash, test_hold_invoice, test_failure)
await group.spawn(pay(lnaddr, pay_req))
peers = graph.peers.values()
@ -1859,6 +1878,16 @@ class TestPeerForwarding(TestPeer):
graph = self.create_square_graph(direct=False, test_mpp_consolidation=True, is_legacy=True)
await self._run_trampoline_payment(graph)
async def test_trampoline_mpp_consolidation_with_hold_invoice(self):
with self.assertRaises(PaymentDone):
graph = self.create_square_graph(direct=False, test_mpp_consolidation=True, is_legacy=True)
await self._run_trampoline_payment(graph, test_hold_invoice=True)
async def test_trampoline_mpp_consolidation_with_hold_invoice_failure(self):
with self.assertRaises(NoPathFound):
graph = self.create_square_graph(direct=False, test_mpp_consolidation=True, is_legacy=True)
await self._run_trampoline_payment(graph, test_hold_invoice=True, test_failure=True)
async def test_payment_trampoline_legacy(self):
with self.assertRaises(PaymentDone):
graph = self.create_square_graph(direct=False, is_legacy=True)

Loading…
Cancel
Save