diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 6e6f72b97..b76a50680 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -676,6 +676,7 @@ class LNWallet(LNWorker): self.trampoline_forwarding_failures = {} # todo: should be persisted # map forwarded htlcs (fw_info=(scid_hex, htlc_id)) to originating peer pubkeys self.downstream_htlc_to_upstream_peer_map = {} # type: Dict[Tuple[str, int], bytes] + self.hold_invoice_callbacks = {} # payment_hash -> callback, timeout def has_deterministic_node_id(self) -> bool: return bool(self.db.get('lightning_xprv')) @@ -1880,6 +1881,14 @@ class LNWallet(LNWorker): amount_msat, direction, status = self.payment_info[key] return PaymentInfo(payment_hash, amount_msat, direction, status) + def add_payment_info_for_hold_invoice(self, payment_hash, lightning_amount_sat): + info = PaymentInfo(payment_hash, lightning_amount_sat * 1000, RECEIVED, PR_UNPAID) + self.save_payment_info(info, write_to_disk=False) + + def register_callback_for_hold_invoice(self, payment_hash, cb, timeout: Optional[int] = None): + expiry = int(time.time()) + timeout + self.hold_invoice_callbacks[payment_hash] = cb, expiry + def save_payment_info(self, info: PaymentInfo, *, write_to_disk: bool = True) -> None: key = info.payment_hash.hex() assert info.status in SAVED_PR_STATUS @@ -1891,13 +1900,22 @@ class LNWallet(LNWorker): def check_received_htlc(self, payment_secret, short_channel_id, htlc: UpdateAddHtlc, expected_msat: int) -> Optional[bool]: """ return MPP status: True (accepted), False (expired) or None (waiting) """ + payment_hash = htlc.payment_hash + preimage = self.get_preimage(payment_hash) + callback = self.hold_invoice_callbacks.get(payment_hash) + if not preimage and callback: + cb, timeout = callback + if int(time.time()) < timeout: + cb(payment_hash) + return None + else: + return False amt_to_forward = htlc.amount_msat # check this if amt_to_forward >= expected_msat: # not multi-part return True - payment_hash = htlc.payment_hash is_expired, is_accepted, htlc_set = self.received_mpp_htlcs.get(payment_secret, (False, False, set())) if self.get_payment_status(payment_hash) == PR_PAID: # payment_status is persisted diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index 9a927dd7e..a383e9d73 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -173,6 +173,7 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): self.preimages = {} self.stopping_soon = False self.downstream_htlc_to_upstream_peer_map = {} + self.hold_invoice_callbacks = {} self.logger.info(f"created LNWallet[{name}] with nodeID={local_keypair.pubkey.hex()}") @@ -275,7 +276,8 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): _on_maybe_forwarded_htlc_resolved = LNWallet._on_maybe_forwarded_htlc_resolved _force_close_channel = LNWallet._force_close_channel suggest_splits = LNWallet.suggest_splits - + register_callback_for_hold_invoice = LNWallet.register_callback_for_hold_invoice + add_payment_info_for_hold_invoice = LNWallet.add_payment_info_for_hold_invoice class MockTransport: def __init__(self, name): @@ -731,10 +733,12 @@ class TestPeer(ElectrumTestCase): async def pay(lnaddr, pay_req): self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash)) result, log = await w1.pay_invoice(pay_req) - self.assertTrue(result) - self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash)) - raise PaymentDone() - async def f(): + if result is True: + self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash)) + raise PaymentDone() + else: + raise PaymentFailure() + async def f(test_hold_invoice=False, test_timeout=False): if trampoline: await turn_on_trampoline_alice() async with OldTaskGroup() as group: @@ -744,6 +748,14 @@ class TestPeer(ElectrumTestCase): await group.spawn(p2.htlc_switch()) await asyncio.sleep(0.01) lnaddr, pay_req = self.prepare_invoice(w2) + if test_hold_invoice: + payment_hash = lnaddr.paymenthash + preimage = bytes.fromhex(w2.preimages.pop(payment_hash.hex())) + def cb(payment_hash): + if not test_timeout: + w2.save_preimage(payment_hash, preimage) + timeout = 1 if test_timeout else 60 + w2.register_callback_for_hold_invoice(payment_hash, cb, timeout) invoice_features = lnaddr.get_features() self.assertFalse(invoice_features.supports(LnFeatures.BASIC_MPP_OPT)) await group.spawn(pay(lnaddr, pay_req)) @@ -752,7 +764,11 @@ class TestPeer(ElectrumTestCase): 'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=w2.node_keypair.pubkey), } with self.assertRaises(PaymentDone): - await f() + await f(test_hold_invoice=False) + with self.assertRaises(PaymentDone): + await f(test_hold_invoice=True, test_timeout=False) + with self.assertRaises(PaymentFailure): + await f(test_hold_invoice=True, test_timeout=True) @needs_test_with_all_chacha20_implementations async def test_simple_payment(self):