From 6819f685d7ad74ee660b7fe9c35a1255ee38cc38 Mon Sep 17 00:00:00 2001 From: SomberNight Date: Mon, 7 Aug 2023 14:12:17 +0000 Subject: [PATCH] tests: add tests for "recv mpp confusion" bug see https://github.com/spesmilo/electrum/security/advisories/GHSA-8r85-vp7r-hjxf --- electrum/tests/test_lnpeer.py | 152 ++++++++++++++++++++++++++++++++++ 1 file changed, 152 insertions(+) diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index 75f5c5394..0ac776d87 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -1194,6 +1194,158 @@ class TestPeer(ElectrumTestCase): with self.assertRaises(PaymentDone): await f() + async def test_payment_recv_mpp_confusion1(self): + """Regression test for https://github.com/spesmilo/electrum/security/advisories/GHSA-8r85-vp7r-hjxf""" + # This test checks that the following attack does not work: + # - Bob creates invoice1: 1 BTC, H1, S1 + # - Bob creates invoice2: 1 BTC, H2, S2; both given to attacker to pay + # - Alice sends htlc1: 0.1 BTC, H1, S1 (total_msat=1 BTC) + # - Alice sends htlc2: 0.9 BTC, H2, S1 (total_msat=1 BTC) + # - Bob(victim) reveals preimage for H1 and fulfills htlc1 (fails other) + alice_channel, bob_channel = create_test_channels() + p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + async def pay(): + self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr1.paymenthash)) + self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr2.paymenthash)) + + route = (await w1.create_routes_from_invoice(amount_msat=1000, decoded_invoice=lnaddr1))[0][0].route + p1.pay( + route=route, + chan=alice_channel, + amount_msat=1000, + total_msat=lnaddr1.get_amount_msat(), + payment_hash=lnaddr1.paymenthash, + min_final_cltv_expiry=lnaddr1.get_min_final_cltv_expiry(), + payment_secret=lnaddr1.payment_secret, + ) + p1.pay( + route=route, + chan=alice_channel, + amount_msat=lnaddr1.get_amount_msat() - 1000, + total_msat=lnaddr1.get_amount_msat(), + payment_hash=lnaddr2.paymenthash, + min_final_cltv_expiry=lnaddr1.get_min_final_cltv_expiry(), + payment_secret=lnaddr1.payment_secret, + ) + + while nhtlc_success + nhtlc_failed < 2: + await htlc_resolved.wait() + self.assertEqual(0, nhtlc_success) + self.assertEqual(2, nhtlc_failed) + raise SuccessfulTest() + + w2.features |= LnFeatures.BASIC_MPP_OPT + lnaddr1, _pay_req = self.prepare_invoice(w2, amount_msat=100_000_000) + lnaddr2, _pay_req = self.prepare_invoice(w2, amount_msat=100_000_000) + self.assertTrue(lnaddr1.get_features().supports(LnFeatures.BASIC_MPP_OPT)) + self.assertTrue(lnaddr2.get_features().supports(LnFeatures.BASIC_MPP_OPT)) + + async def f(): + async with OldTaskGroup() as group: + await group.spawn(p1._message_loop()) + await group.spawn(p1.htlc_switch()) + await group.spawn(p2._message_loop()) + await group.spawn(p2.htlc_switch()) + await asyncio.sleep(0.01) + await group.spawn(pay()) + + htlc_resolved = asyncio.Event() + nhtlc_success = 0 + nhtlc_failed = 0 + async def on_htlc_fulfilled(*args): + htlc_resolved.set() + htlc_resolved.clear() + nonlocal nhtlc_success + nhtlc_success += 1 + async def on_htlc_failed(*args): + htlc_resolved.set() + htlc_resolved.clear() + nonlocal nhtlc_failed + nhtlc_failed += 1 + util.register_callback(on_htlc_fulfilled, ["htlc_fulfilled"]) + util.register_callback(on_htlc_failed, ["htlc_failed"]) + + try: + with self.assertRaises(SuccessfulTest): + await f() + finally: + util.unregister_callback(on_htlc_fulfilled) + util.unregister_callback(on_htlc_failed) + + async def test_payment_recv_mpp_confusion2(self): + """Regression test for https://github.com/spesmilo/electrum/security/advisories/GHSA-8r85-vp7r-hjxf""" + # This test checks that the following attack does not work: + # - Bob creates invoice: 1 BTC + # - Alice sends htlc1: 0.1 BTC (total_msat=0.2 BTC) + # - Alice sends htlc2: 0.1 BTC (total_msat=1 BTC) + # - Bob(victim) reveals preimage and fulfills htlc2 (fails other) + alice_channel, bob_channel = create_test_channels() + p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + async def pay(): + self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr1.paymenthash)) + + route = (await w1.create_routes_from_invoice(amount_msat=1000, decoded_invoice=lnaddr1))[0][0].route + p1.pay( + route=route, + chan=alice_channel, + amount_msat=1000, + total_msat=2000, + payment_hash=lnaddr1.paymenthash, + min_final_cltv_expiry=lnaddr1.get_min_final_cltv_expiry(), + payment_secret=lnaddr1.payment_secret, + ) + p1.pay( + route=route, + chan=alice_channel, + amount_msat=1000, + total_msat=lnaddr1.get_amount_msat(), + payment_hash=lnaddr1.paymenthash, + min_final_cltv_expiry=lnaddr1.get_min_final_cltv_expiry(), + payment_secret=lnaddr1.payment_secret, + ) + + while nhtlc_success + nhtlc_failed < 2: + await htlc_resolved.wait() + self.assertEqual(0, nhtlc_success) + self.assertEqual(2, nhtlc_failed) + raise SuccessfulTest() + + w2.features |= LnFeatures.BASIC_MPP_OPT + lnaddr1, _pay_req = self.prepare_invoice(w2, amount_msat=100_000_000) + self.assertTrue(lnaddr1.get_features().supports(LnFeatures.BASIC_MPP_OPT)) + + async def f(): + async with OldTaskGroup() as group: + await group.spawn(p1._message_loop()) + await group.spawn(p1.htlc_switch()) + await group.spawn(p2._message_loop()) + await group.spawn(p2.htlc_switch()) + await asyncio.sleep(0.01) + await group.spawn(pay()) + + htlc_resolved = asyncio.Event() + nhtlc_success = 0 + nhtlc_failed = 0 + async def on_htlc_fulfilled(*args): + htlc_resolved.set() + htlc_resolved.clear() + nonlocal nhtlc_success + nhtlc_success += 1 + async def on_htlc_failed(*args): + htlc_resolved.set() + htlc_resolved.clear() + nonlocal nhtlc_failed + nhtlc_failed += 1 + util.register_callback(on_htlc_fulfilled, ["htlc_fulfilled"]) + util.register_callback(on_htlc_failed, ["htlc_failed"]) + + try: + with self.assertRaises(SuccessfulTest): + await f() + finally: + util.unregister_callback(on_htlc_fulfilled) + util.unregister_callback(on_htlc_failed) + async def _run_mpp(self, graph, fail_kwargs, success_kwargs): """Tests a multipart payment scenario for failing and successful cases.""" self.assertEqual(500_000_000_000, graph.channels[('alice', 'bob')].balance(LOCAL))