Browse Source

lnworker: add support for hold invoices

(invoices for which we do not have the preimage)

Callbacks and timeouts are registered with lnworker. If the
preimage is not known after the timeout has expired, the payment
is failed with MPP_TIMEOUT.
master
ThomasV 3 years ago
parent
commit
1acf426fa9
  1. 20
      electrum/lnworker.py
  2. 24
      electrum/tests/test_lnpeer.py

20
electrum/lnworker.py

@ -676,6 +676,7 @@ class LNWallet(LNWorker):
self.trampoline_forwarding_failures = {} # todo: should be persisted self.trampoline_forwarding_failures = {} # todo: should be persisted
# map forwarded htlcs (fw_info=(scid_hex, htlc_id)) to originating peer pubkeys # map forwarded htlcs (fw_info=(scid_hex, htlc_id)) to originating peer pubkeys
self.downstream_htlc_to_upstream_peer_map = {} # type: Dict[Tuple[str, int], bytes] self.downstream_htlc_to_upstream_peer_map = {} # type: Dict[Tuple[str, int], bytes]
self.hold_invoice_callbacks = {} # payment_hash -> callback, timeout
def has_deterministic_node_id(self) -> bool: def has_deterministic_node_id(self) -> bool:
return bool(self.db.get('lightning_xprv')) return bool(self.db.get('lightning_xprv'))
@ -1880,6 +1881,14 @@ class LNWallet(LNWorker):
amount_msat, direction, status = self.payment_info[key] amount_msat, direction, status = self.payment_info[key]
return PaymentInfo(payment_hash, amount_msat, direction, status) return PaymentInfo(payment_hash, amount_msat, direction, status)
def add_payment_info_for_hold_invoice(self, payment_hash, lightning_amount_sat):
info = PaymentInfo(payment_hash, lightning_amount_sat * 1000, RECEIVED, PR_UNPAID)
self.save_payment_info(info, write_to_disk=False)
def register_callback_for_hold_invoice(self, payment_hash, cb, timeout: Optional[int] = None):
expiry = int(time.time()) + timeout
self.hold_invoice_callbacks[payment_hash] = cb, expiry
def save_payment_info(self, info: PaymentInfo, *, write_to_disk: bool = True) -> None: def save_payment_info(self, info: PaymentInfo, *, write_to_disk: bool = True) -> None:
key = info.payment_hash.hex() key = info.payment_hash.hex()
assert info.status in SAVED_PR_STATUS assert info.status in SAVED_PR_STATUS
@ -1891,13 +1900,22 @@ class LNWallet(LNWorker):
def check_received_htlc(self, payment_secret, short_channel_id, htlc: UpdateAddHtlc, expected_msat: int) -> Optional[bool]: def check_received_htlc(self, payment_secret, short_channel_id, htlc: UpdateAddHtlc, expected_msat: int) -> Optional[bool]:
""" return MPP status: True (accepted), False (expired) or None (waiting) """ return MPP status: True (accepted), False (expired) or None (waiting)
""" """
payment_hash = htlc.payment_hash
preimage = self.get_preimage(payment_hash)
callback = self.hold_invoice_callbacks.get(payment_hash)
if not preimage and callback:
cb, timeout = callback
if int(time.time()) < timeout:
cb(payment_hash)
return None
else:
return False
amt_to_forward = htlc.amount_msat # check this amt_to_forward = htlc.amount_msat # check this
if amt_to_forward >= expected_msat: if amt_to_forward >= expected_msat:
# not multi-part # not multi-part
return True return True
payment_hash = htlc.payment_hash
is_expired, is_accepted, htlc_set = self.received_mpp_htlcs.get(payment_secret, (False, False, set())) is_expired, is_accepted, htlc_set = self.received_mpp_htlcs.get(payment_secret, (False, False, set()))
if self.get_payment_status(payment_hash) == PR_PAID: if self.get_payment_status(payment_hash) == PR_PAID:
# payment_status is persisted # payment_status is persisted

24
electrum/tests/test_lnpeer.py

@ -173,6 +173,7 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
self.preimages = {} self.preimages = {}
self.stopping_soon = False self.stopping_soon = False
self.downstream_htlc_to_upstream_peer_map = {} self.downstream_htlc_to_upstream_peer_map = {}
self.hold_invoice_callbacks = {}
self.logger.info(f"created LNWallet[{name}] with nodeID={local_keypair.pubkey.hex()}") self.logger.info(f"created LNWallet[{name}] with nodeID={local_keypair.pubkey.hex()}")
@ -275,7 +276,8 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
_on_maybe_forwarded_htlc_resolved = LNWallet._on_maybe_forwarded_htlc_resolved _on_maybe_forwarded_htlc_resolved = LNWallet._on_maybe_forwarded_htlc_resolved
_force_close_channel = LNWallet._force_close_channel _force_close_channel = LNWallet._force_close_channel
suggest_splits = LNWallet.suggest_splits suggest_splits = LNWallet.suggest_splits
register_callback_for_hold_invoice = LNWallet.register_callback_for_hold_invoice
add_payment_info_for_hold_invoice = LNWallet.add_payment_info_for_hold_invoice
class MockTransport: class MockTransport:
def __init__(self, name): def __init__(self, name):
@ -731,10 +733,12 @@ class TestPeer(ElectrumTestCase):
async def pay(lnaddr, pay_req): async def pay(lnaddr, pay_req):
self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash)) self.assertEqual(PR_UNPAID, w2.get_payment_status(lnaddr.paymenthash))
result, log = await w1.pay_invoice(pay_req) result, log = await w1.pay_invoice(pay_req)
self.assertTrue(result) if result is True:
self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash)) self.assertEqual(PR_PAID, w2.get_payment_status(lnaddr.paymenthash))
raise PaymentDone() raise PaymentDone()
async def f(): else:
raise PaymentFailure()
async def f(test_hold_invoice=False, test_timeout=False):
if trampoline: if trampoline:
await turn_on_trampoline_alice() await turn_on_trampoline_alice()
async with OldTaskGroup() as group: async with OldTaskGroup() as group:
@ -744,6 +748,14 @@ class TestPeer(ElectrumTestCase):
await group.spawn(p2.htlc_switch()) await group.spawn(p2.htlc_switch())
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
lnaddr, pay_req = self.prepare_invoice(w2) lnaddr, pay_req = self.prepare_invoice(w2)
if test_hold_invoice:
payment_hash = lnaddr.paymenthash
preimage = bytes.fromhex(w2.preimages.pop(payment_hash.hex()))
def cb(payment_hash):
if not test_timeout:
w2.save_preimage(payment_hash, preimage)
timeout = 1 if test_timeout else 60
w2.register_callback_for_hold_invoice(payment_hash, cb, timeout)
invoice_features = lnaddr.get_features() invoice_features = lnaddr.get_features()
self.assertFalse(invoice_features.supports(LnFeatures.BASIC_MPP_OPT)) self.assertFalse(invoice_features.supports(LnFeatures.BASIC_MPP_OPT))
await group.spawn(pay(lnaddr, pay_req)) await group.spawn(pay(lnaddr, pay_req))
@ -752,7 +764,11 @@ class TestPeer(ElectrumTestCase):
'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=w2.node_keypair.pubkey), 'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=w2.node_keypair.pubkey),
} }
with self.assertRaises(PaymentDone): with self.assertRaises(PaymentDone):
await f() await f(test_hold_invoice=False)
with self.assertRaises(PaymentDone):
await f(test_hold_invoice=True, test_timeout=False)
with self.assertRaises(PaymentFailure):
await f(test_hold_invoice=True, test_timeout=True)
@needs_test_with_all_chacha20_implementations @needs_test_with_all_chacha20_implementations
async def test_simple_payment(self): async def test_simple_payment(self):

Loading…
Cancel
Save