From b645da65148c45cf791bfef27db88d29b33496ba Mon Sep 17 00:00:00 2001 From: ThomasV Date: Wed, 11 Oct 2023 09:33:00 +0200 Subject: [PATCH] test_lnpeer: separate graph creation from run_trampoline --- electrum/tests/test_lnpeer.py | 50 +++++++++++++++++------------------ 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index a118f44eb..f77f90801 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -1799,13 +1799,9 @@ class TestPeerForwarding(TestPeer): await f() async def _run_trampoline_payment( - self, *, - is_legacy=False, - direct=False, - test_mpp_consolidation=False, - include_routing_hints=True, # only relevant if is_legacy is True - attempts=2, - ): + self, graph, *, + include_routing_hints=True, + attempts=2): async def pay(lnaddr, pay_req): self.assertEqual(PR_UNPAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash)) @@ -1827,6 +1823,17 @@ class TestPeerForwarding(TestPeer): lnaddr, pay_req = self.prepare_invoice(graph.workers['dave'], include_routing_hints=include_routing_hints) await group.spawn(pay(lnaddr, pay_req)) + peers = graph.peers.values() + + # declare routing nodes as trampoline nodes + electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = { + graph.workers['bob'].name: LNPeerAddr(host="127.0.0.1", port=9735, pubkey=graph.workers['bob'].node_keypair.pubkey), + graph.workers['carol'].name: LNPeerAddr(host="127.0.0.1", port=9735, pubkey=graph.workers['carol'].node_keypair.pubkey), + } + + await f() + + def create_square_graph(self, *, direct=False, test_mpp_consolidation=False, is_legacy=False): graph_definition = self.GRAPH_DEFINITIONS['square_graph'] if not direct: # deplete channel from alice to carol and from bob to dave @@ -1835,45 +1842,38 @@ class TestPeerForwarding(TestPeer): # insert a channel from bob to carol graph_definition['bob']['channels']['carol'] = low_fee_channel # now the only route possible is alice -> bob -> carol -> dave - if test_mpp_consolidation: # deplete alice to carol so that all htlcs go through bob graph_definition['alice']['channels']['carol'] = depleted_channel - graph = self.prepare_chans_and_peers_in_graph(graph_definition) - if test_mpp_consolidation: graph.workers['dave'].features |= LnFeatures.BASIC_MPP_OPT graph.workers['alice'].network.config.TEST_FORCE_MPP = True - - peers = graph.peers.values() if is_legacy: # turn off trampoline features in invoice graph.workers['dave'].features = graph.workers['dave'].features ^ LnFeatures.OPTION_TRAMPOLINE_ROUTING_OPT_ELECTRUM - - # declare routing nodes as trampoline nodes - electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = { - graph.workers['bob'].name: LNPeerAddr(host="127.0.0.1", port=9735, pubkey=graph.workers['bob'].node_keypair.pubkey), - graph.workers['carol'].name: LNPeerAddr(host="127.0.0.1", port=9735, pubkey=graph.workers['carol'].node_keypair.pubkey), - } - - await f() + return graph async def test_trampoline_mpp_consolidation(self): with self.assertRaises(PaymentDone): - await self._run_trampoline_payment(is_legacy=True, direct=False, test_mpp_consolidation=True) + graph = self.create_square_graph(direct=False, test_mpp_consolidation=True, is_legacy=True) + await self._run_trampoline_payment(graph) async def test_payment_trampoline_legacy(self): with self.assertRaises(PaymentDone): - await self._run_trampoline_payment(is_legacy=True, direct=False, include_routing_hints=True) + graph = self.create_square_graph(direct=False, is_legacy=True) + await self._run_trampoline_payment(graph, include_routing_hints=True) with self.assertRaises(NoPathFound): - await self._run_trampoline_payment(is_legacy=True, direct=False, include_routing_hints=False) + graph = self.create_square_graph(direct=False, is_legacy=True) + await self._run_trampoline_payment(graph, include_routing_hints=False) async def test_payment_trampoline_e2e_direct(self): with self.assertRaises(PaymentDone): - await self._run_trampoline_payment(is_legacy=False, direct=True) + graph = self.create_square_graph(direct=True, is_legacy=False) + await self._run_trampoline_payment(graph) async def test_payment_trampoline_e2e_indirect(self): # must use two trampolines with self.assertRaises(PaymentDone): - await self._run_trampoline_payment(is_legacy=False, direct=False) + graph = self.create_square_graph(direct=False, is_legacy=False) + await self._run_trampoline_payment(graph)