Browse Source

lnworker: add support for hold invoices

(invoices for which we do not have the preimage)

Callbacks and timeouts are registered with lnworker. If the
preimage is not known after the timeout has expired, the payment
is failed with MPP_TIMEOUT.
master
ThomasV 3 years ago
parent
commit
1acf426fa9
  1. 20
      electrum/lnworker.py
  2. 28
      electrum/tests/test_lnpeer.py

20
electrum/lnworker.py

@ -676,6 +676,7 @@ class LNWallet(LNWorker):
self.trampoline_forwarding_failures = {} # todo: should be persisted
# map forwarded htlcs (fw_info=(scid_hex, htlc_id)) to originating peer pubkeys
self.downstream_htlc_to_upstream_peer_map = {} # type: Dict[Tuple[str, int], bytes]
self.hold_invoice_callbacks = {} # payment_hash -> callback, timeout
def has_deterministic_node_id(self) -> bool:
return bool(self.db.get('lightning_xprv'))
@ -1880,6 +1881,14 @@ class LNWallet(LNWorker):
amount_msat, direction, status = self.payment_info[key]
return PaymentInfo(payment_hash, amount_msat, direction, status)
def add_payment_info_for_hold_invoice(self, payment_hash, lightning_amount_sat):
info = PaymentInfo(payment_hash, lightning_amount_sat * 1000, RECEIVED, PR_UNPAID)
self.save_payment_info(info, write_to_disk=False)
def register_callback_for_hold_invoice(self, payment_hash, cb, timeout: Optional[int] = None):
expiry = int(time.time()) + timeout
self.hold_invoice_callbacks[payment_hash] = cb, expiry
def save_payment_info(self, info: PaymentInfo, *, write_to_disk: bool = True) -> None:
key = info.payment_hash.hex()
assert info.status in SAVED_PR_STATUS
@ -1891,13 +1900,22 @@ class LNWallet(LNWorker):
def check_received_htlc(self, payment_secret, short_channel_id, htlc: UpdateAddHtlc, expected_msat: int) -> Optional[bool]:
""" return MPP status: True (accepted), False (expired) or None (waiting)
"""
payment_hash = htlc.payment_hash
preimage = self.get_preimage(payment_hash)
callback = self.hold_invoice_callbacks.get(payment_hash)
if not preimage and callback:
cb, timeout = callback
if int(time.time()) < timeout:
cb(payment_hash)
return None
else:
return False
amt_to_forward = htlc.amount_msat # check this
if amt_to_forward >= expected_msat:
# not multi-part
return True
payment_hash = htlc.payment_hash
is_expired, is_accepted, htlc_set = self.received_mpp_htlcs.get(payment_secret, (False, False, set()))
if self.get_payment_status(payment_hash) == PR_PAID:
# payment_status is persisted

28
electrum/tests/test_lnpeer.py

@ -173,6 +173,7 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
self.preimages = {}
self.stopping_soon = False
self.downstream_htlc_to_upstream_peer_map = {}
self.hold_invoice_callbacks = {}
self.logger.info(f"created LNWallet[{name}] with nodeID={local_keypair.pubkey.hex()}")
@ -275,7 +276,8 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
_on_maybe_forwarded_htlc_resolved = LNWallet._on_maybe_forwarded_htlc_resolved
_force_close_channel = LNWallet._force_close_channel
suggest_splits = LNWallet.suggest_splits
register_callback_for_hold_invoice = LNWallet.register_callback_for_hold_invoice
add_payment_info_for_hold_invoice = LNWallet.add_payment_info_for_hold_invoice
class MockTransport:
def __init__(self, name):
@ -731,10 +733,12 @@ class TestPeer(ElectrumTestCase):
async def pay(lnaddr, pay_req):
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash))
result, log = await w1.pay_invoice(pay_req)
self.assertTrue(result)
self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash))
raise PaymentDone()
async def f():
if result is True:
self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash))
raise PaymentDone()
else:
raise PaymentFailure()
async def f(test_hold_invoice=False, test_timeout=False):
if trampoline:
await turn_on_trampoline_alice()
async with OldTaskGroup() as group:
@ -744,6 +748,14 @@ class TestPeer(ElectrumTestCase):
await group.spawn(p2.htlc_switch())
await asyncio.sleep(0.01)
lnaddr, pay_req = self.prepare_invoice(w2)
if test_hold_invoice:
payment_hash = lnaddr.paymenthash
preimage = bytes.fromhex(w2.preimages.pop(payment_hash.hex()))
def cb(payment_hash):
if not test_timeout:
w2.save_preimage(payment_hash, preimage)
timeout = 1 if test_timeout else 60
w2.register_callback_for_hold_invoice(payment_hash, cb, timeout)
invoice_features = lnaddr.get_features()
self.assertFalse(invoice_features.supports(LnFeatures.BASIC_MPP_OPT))
await group.spawn(pay(lnaddr, pay_req))
@ -752,7 +764,11 @@ class TestPeer(ElectrumTestCase):
'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=w2.node_keypair.pubkey),
}
with self.assertRaises(PaymentDone):
await f()
await f(test_hold_invoice=False)
with self.assertRaises(PaymentDone):
await f(test_hold_invoice=True, test_timeout=False)
with self.assertRaises(PaymentFailure):
await f(test_hold_invoice=True, test_timeout=True)
@needs_test_with_all_chacha20_implementations
async def test_simple_payment(self):

Loading…
Cancel
Save