From 1acf426fa9077e4537177938bd5ba8f7f6f77b4e Mon Sep 17 00:00:00 2001 From: ThomasV Date: Thu, 15 Jun 2023 12:13:35 +0200 Subject: [PATCH] lnworker: add support for hold invoices (invoices for which we do not have the preimage) Callbacks and timeouts are registered with lnworker. If the preimage is not known after the timeout has expired, the payment is failed with MPP_TIMEOUT. --- electrum/lnworker.py | 20 +++++++++++++++++++- electrum/tests/test_lnpeer.py | 28 ++++++++++++++++++++++------ 2 files changed, 41 insertions(+), 7 deletions(-) diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 6e6f72b97..b76a50680 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -676,6 +676,7 @@ class LNWallet(LNWorker): self.trampoline_forwarding_failures = {} # todo: should be persisted # map forwarded htlcs (fw_info=(scid_hex, htlc_id)) to originating peer pubkeys self.downstream_htlc_to_upstream_peer_map = {} # type: Dict[Tuple[str, int], bytes] + self.hold_invoice_callbacks = {} # payment_hash -> callback, timeout def has_deterministic_node_id(self) -> bool: return bool(self.db.get('lightning_xprv')) @@ -1880,6 +1881,14 @@ class LNWallet(LNWorker): amount_msat, direction, status = self.payment_info[key] return PaymentInfo(payment_hash, amount_msat, direction, status) + def add_payment_info_for_hold_invoice(self, payment_hash, lightning_amount_sat): + info = PaymentInfo(payment_hash, lightning_amount_sat * 1000, RECEIVED, PR_UNPAID) + self.save_payment_info(info, write_to_disk=False) + + def register_callback_for_hold_invoice(self, payment_hash, cb, timeout: Optional[int] = None): + expiry = int(time.time()) + timeout + self.hold_invoice_callbacks[payment_hash] = cb, expiry + def save_payment_info(self, info: PaymentInfo, *, write_to_disk: bool = True) -> None: key = info.payment_hash.hex() assert info.status in SAVED_PR_STATUS @@ -1891,13 +1900,22 @@ class LNWallet(LNWorker): def check_received_htlc(self, payment_secret, short_channel_id, htlc: UpdateAddHtlc, expected_msat: int) -> Optional[bool]: """ return MPP status: True (accepted), False (expired) or None (waiting) """ + payment_hash = htlc.payment_hash + preimage = self.get_preimage(payment_hash) + callback = self.hold_invoice_callbacks.get(payment_hash) + if not preimage and callback: + cb, timeout = callback + if int(time.time()) < timeout: + cb(payment_hash) + return None + else: + return False amt_to_forward = htlc.amount_msat # check this if amt_to_forward >= expected_msat: # not multi-part return True - payment_hash = htlc.payment_hash is_expired, is_accepted, htlc_set = self.received_mpp_htlcs.get(payment_secret, (False, False, set())) if self.get_payment_status(payment_hash) == PR_PAID: # payment_status is persisted diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index 9a927dd7e..a383e9d73 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -173,6 +173,7 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): self.preimages = {} self.stopping_soon = False self.downstream_htlc_to_upstream_peer_map = {} + self.hold_invoice_callbacks = {} self.logger.info(f"created LNWallet[{name}] with nodeID={local_keypair.pubkey.hex()}") @@ -275,7 +276,8 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): _on_maybe_forwarded_htlc_resolved = LNWallet._on_maybe_forwarded_htlc_resolved _force_close_channel = LNWallet._force_close_channel suggest_splits = LNWallet.suggest_splits - + register_callback_for_hold_invoice = LNWallet.register_callback_for_hold_invoice + add_payment_info_for_hold_invoice = LNWallet.add_payment_info_for_hold_invoice class MockTransport: def __init__(self, name): @@ -731,10 +733,12 @@ class TestPeer(ElectrumTestCase): async def pay(lnaddr, pay_req): self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash)) result, log = await w1.pay_invoice(pay_req) - self.assertTrue(result) - self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash)) - raise PaymentDone() - async def f(): + if result is True: + self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash)) + raise PaymentDone() + else: + raise PaymentFailure() + async def f(test_hold_invoice=False, test_timeout=False): if trampoline: await turn_on_trampoline_alice() async with OldTaskGroup() as group: @@ -744,6 +748,14 @@ class TestPeer(ElectrumTestCase): await group.spawn(p2.htlc_switch()) await asyncio.sleep(0.01) lnaddr, pay_req = self.prepare_invoice(w2) + if test_hold_invoice: + payment_hash = lnaddr.paymenthash + preimage = bytes.fromhex(w2.preimages.pop(payment_hash.hex())) + def cb(payment_hash): + if not test_timeout: + w2.save_preimage(payment_hash, preimage) + timeout = 1 if test_timeout else 60 + w2.register_callback_for_hold_invoice(payment_hash, cb, timeout) invoice_features = lnaddr.get_features() self.assertFalse(invoice_features.supports(LnFeatures.BASIC_MPP_OPT)) await group.spawn(pay(lnaddr, pay_req)) @@ -752,7 +764,11 @@ class TestPeer(ElectrumTestCase): 'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=w2.node_keypair.pubkey), } with self.assertRaises(PaymentDone): - await f() + await f(test_hold_invoice=False) + with self.assertRaises(PaymentDone): + await f(test_hold_invoice=True, test_timeout=False) + with self.assertRaises(PaymentFailure): + await f(test_hold_invoice=True, test_timeout=True) @needs_test_with_all_chacha20_implementations async def test_simple_payment(self):