Browse Source

lnworker: bundled payments

- htlcs of bundled payments must arrive in the same MPP_TIMEOUT
window, or they will be failed
- add correspoding tests
master
ThomasV 3 years ago
parent
commit
c4eb7d8321
  1. 112
      electrum/lnworker.py
  2. 60
      electrum/tests/test_lnpeer.py

112
electrum/lnworker.py

@ -677,6 +677,8 @@ class LNWallet(LNWorker):
# map forwarded htlcs (fw_info=(scid_hex, htlc_id)) to originating peer pubkeys
self.downstream_htlc_to_upstream_peer_map = {} # type: Dict[Tuple[str, int], bytes]
self.hold_invoice_callbacks = {} # payment_hash -> callback, timeout
self.payment_bundles = [] # lists of hashes. todo:persist
def has_deterministic_node_id(self) -> bool:
return bool(self.db.get('lightning_xprv'))
@ -1862,6 +1864,14 @@ class LNWallet(LNWorker):
self.wallet.save_db()
return payment_hash
def bundle_payments(self, hash_list):
self.payment_bundles.append(hash_list)
def get_payment_bundle(self, payment_hash):
for hash_list in self.payment_bundles:
if payment_hash in hash_list:
return hash_list
def save_preimage(self, payment_hash: bytes, preimage: bytes, *, write_to_disk: bool = True):
assert sha256(preimage) == payment_hash
self.preimages[payment_hash.hex()] = preimage.hex()
@ -1901,45 +1911,87 @@ class LNWallet(LNWorker):
""" return MPP status: True (accepted), False (expired) or None (waiting)
"""
payment_hash = htlc.payment_hash
preimage = self.get_preimage(payment_hash)
callback = self.hold_invoice_callbacks.get(payment_hash)
if not preimage and callback:
cb, timeout = callback
if int(time.time()) < timeout:
cb(payment_hash)
return None
else:
return False
amt_to_forward = htlc.amount_msat # check this
if amt_to_forward >= expected_msat:
# not multi-part
return True
self.update_mpp_with_received_htlc(payment_secret, short_channel_id, htlc, expected_msat)
is_expired, is_accepted = self.get_mpp_status(payment_secret)
if not is_accepted and not is_expired:
bundle = self.get_payment_bundle(payment_hash)
payment_hashes = bundle or [payment_hash]
payment_secrets = [self.get_payment_secret(h) for h in bundle] if bundle else [payment_secret]
first_timestamp = min([self.get_first_timestamp_of_mpp(x) for x in payment_secrets])
if self.get_payment_status(payment_hash) == PR_PAID:
is_accepted = True
elif self.stopping_soon:
is_expired = True # try to time out pending HTLCs before shutting down
elif time.time() - first_timestamp > self.MPP_EXPIRY:
is_expired = True
elif all([self.is_mpp_amount_reached(x) for x in payment_secrets]):
preimage = self.get_preimage(payment_hash)
hold_invoice_callback = self.hold_invoice_callbacks.get(payment_hash)
if not preimage and hold_invoice_callback:
# for hold invoices, trigger callback
cb, timeout = hold_invoice_callback
if int(time.time()) < timeout:
cb(payment_hash)
else:
is_expired = True
elif bundle is not None:
is_accepted = all([bool(self.get_preimage(x)) for x in bundle])
else:
# trampoline forwarding needs this to return True
is_accepted = True
# set status for the bundle
if is_expired or is_accepted:
for x in payment_secrets:
if x in self.received_mpp_htlcs:
self.set_mpp_status(x, is_expired, is_accepted)
is_expired, is_accepted, htlc_set = self.received_mpp_htlcs.get(payment_secret, (False, False, set()))
if self.get_payment_status(payment_hash) == PR_PAID:
# payment_status is persisted
is_accepted = True
is_expired = False
self.maybe_cleanup_mpp_status(payment_secret, short_channel_id, htlc)
return True if is_accepted else (False if is_expired else None)
def update_mpp_with_received_htlc(self, payment_secret, short_channel_id, htlc, expected_msat):
# add new htlc to set
is_expired, is_accepted, _expected_msat, htlc_set = self.received_mpp_htlcs.get(payment_secret, (False, False, expected_msat, set()))
assert expected_msat == _expected_msat
key = (short_channel_id, htlc)
if key not in htlc_set:
htlc_set.add(key)
self.received_mpp_htlcs[payment_secret] = is_expired, is_accepted, _expected_msat, htlc_set
def get_mpp_status(self, payment_secret):
is_expired, is_accepted, _expected_msat, htlc_set = self.received_mpp_htlcs[payment_secret]
return is_expired, is_accepted
def set_mpp_status(self, payment_secret, is_expired, is_accepted):
_is_expired, _is_accepted, _expected_msat, htlc_set = self.received_mpp_htlcs[payment_secret]
self.received_mpp_htlcs[payment_secret] = is_expired, is_accepted, _expected_msat, htlc_set
def is_mpp_amount_reached(self, payment_secret):
mpp = self.received_mpp_htlcs.get(payment_secret)
if not mpp:
return False
is_expired, is_accepted, _expected_msat, htlc_set = mpp
total = sum([_htlc.amount_msat for scid, _htlc in htlc_set])
return total >= _expected_msat
def get_first_timestamp_of_mpp(self, payment_secret):
mpp = self.received_mpp_htlcs.get(payment_secret)
if not mpp:
return int(time.time())
is_expired, is_accepted, _expected_msat, htlc_set = mpp
return min([_htlc.timestamp for scid, _htlc in htlc_set])
def maybe_cleanup_mpp_status(self, payment_secret, short_channel_id, htlc):
is_expired, is_accepted, _expected_msat, htlc_set = self.received_mpp_htlcs[payment_secret]
if not is_accepted and not is_expired:
total = sum([_htlc.amount_msat for scid, _htlc in htlc_set])
first_timestamp = min([_htlc.timestamp for scid, _htlc in htlc_set])
if self.stopping_soon:
is_expired = True # try to time out pending HTLCs before shutting down
elif time.time() - first_timestamp > self.MPP_EXPIRY:
is_expired = True
elif total == expected_msat:
is_accepted = True
if is_accepted or is_expired:
htlc_set.remove(key)
return
key = (short_channel_id, htlc)
htlc_set.remove(key)
if len(htlc_set) > 0:
self.received_mpp_htlcs[payment_secret] = is_expired, is_accepted, htlc_set
self.received_mpp_htlcs[payment_secret] = is_expired, is_accepted, _expected_msat, htlc_set
elif payment_secret in self.received_mpp_htlcs:
self.received_mpp_htlcs.pop(payment_secret)
return True if is_accepted else (False if is_expired else None)
def get_payment_status(self, payment_hash: bytes) -> int:
info = self.get_payment_info(payment_hash)

60
electrum/tests/test_lnpeer.py

@ -174,6 +174,7 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
self.stopping_soon = False
self.downstream_htlc_to_upstream_peer_map = {}
self.hold_invoice_callbacks = {}
self.payment_bundles = [] # lists of hashes. todo:persist
self.logger.info(f"created LNWallet[{name}] with nodeID={local_keypair.pubkey.hex()}")
@ -279,6 +280,16 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
register_callback_for_hold_invoice = LNWallet.register_callback_for_hold_invoice
add_payment_info_for_hold_invoice = LNWallet.add_payment_info_for_hold_invoice
update_mpp_with_received_htlc = LNWallet.update_mpp_with_received_htlc
get_mpp_status = LNWallet.get_mpp_status
set_mpp_status = LNWallet.set_mpp_status
is_mpp_amount_reached = LNWallet.is_mpp_amount_reached
get_first_timestamp_of_mpp = LNWallet.get_first_timestamp_of_mpp
maybe_cleanup_mpp_status = LNWallet.maybe_cleanup_mpp_status
bundle_payments = LNWallet.bundle_payments
get_payment_bundle = LNWallet.get_payment_bundle
class MockTransport:
def __init__(self, name):
self.queue = asyncio.Queue() # incoming messages
@ -727,7 +738,14 @@ class TestPeer(ElectrumTestCase):
await w.network.channel_db.stopped_event.wait()
w.network.channel_db = None
async def _test_simple_payment(self, trampoline: bool, test_hold_invoice=False, test_timeout=False):
async def _test_simple_payment(
self,
test_trampoline: bool,
test_hold_invoice=False,
test_hold_timeout=False,
test_bundle=False,
test_bundle_timeout=False
):
"""Alice pays Bob a single HTLC via direct channel."""
alice_channel, bob_channel = create_test_channels()
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel)
@ -746,12 +764,21 @@ class TestPeer(ElectrumTestCase):
def cb(payment_hash):
if not test_timeout:
w2.save_preimage(payment_hash, preimage)
timeout = 1 if test_timeout else 60
timeout = 1 if test_hold_timeout else 60
w2.register_callback_for_hold_invoice(payment_hash, cb, timeout)
if test_bundle:
lnaddr2, pay_req2 = self.prepare_invoice(w2)
w2.bundle_payments([lnaddr.paymenthash, lnaddr2.paymenthash])
if test_trampoline:
await self._activate_trampoline(w1)
# declare bob as trampoline node
electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = {
'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=w2.node_keypair.pubkey),
}
async def f():
if trampoline:
await self._activate_trampoline(w1)
async with OldTaskGroup() as group:
await group.spawn(p1._message_loop())
await group.spawn(p1.htlc_switch())
@ -761,22 +788,31 @@ class TestPeer(ElectrumTestCase):
invoice_features = lnaddr.get_features()
self.assertFalse(invoice_features.supports(LnFeatures.BASIC_MPP_OPT))
await group.spawn(pay(lnaddr, pay_req))
# declare bob as trampoline node
electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = {
'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=w2.node_keypair.pubkey),
}
if test_bundle and not test_bundle_timeout:
await group.spawn(pay(lnaddr2, pay_req2))
await f()
@needs_test_with_all_chacha20_implementations
async def test_simple_payment(self):
for trampoline in [False, True]:
for test_trampoline in [False, True]:
with self.assertRaises(PaymentDone):
await self._test_simple_payment(test_trampoline=test_trampoline)
async def test_payment_bundle(self):
for test_trampoline in [False, True]:
with self.assertRaises(PaymentDone):
await self._test_simple_payment(trampoline=trampoline)
await self._test_simple_payment(test_trampoline=test_trampoline, test_bundle=True)
async def test_payment_bundle_timeout(self):
for test_trampoline in [False, True]:
with self.assertRaises(PaymentFailure):
await self._test_simple_payment(test_trampoline=test_trampoline, test_bundle=True, test_bundle_timeout=True)
async def test_simple_payment_with_hold_invoice(self):
for trampoline in [False, True]:
for test_trampoline in [False, True]:
with self.assertRaises(PaymentDone):
await self._test_simple_payment(trampoline=trampoline, test_hold_invoice=True)
await self._test_simple_payment(test_trampoline=test_trampoline, test_hold_invoice=True)
async def test_simple_payment_with_hold_invoice_timing_out(self):
for trampoline in [False, True]:

Loading…
Cancel
Save