diff --git a/electrum/lnworker.py b/electrum/lnworker.py index b76a50680..15d6bed85 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -677,6 +677,8 @@ class LNWallet(LNWorker): # map forwarded htlcs (fw_info=(scid_hex, htlc_id)) to originating peer pubkeys self.downstream_htlc_to_upstream_peer_map = {} # type: Dict[Tuple[str, int], bytes] self.hold_invoice_callbacks = {} # payment_hash -> callback, timeout + self.payment_bundles = [] # lists of hashes. todo:persist + def has_deterministic_node_id(self) -> bool: return bool(self.db.get('lightning_xprv')) @@ -1862,6 +1864,14 @@ class LNWallet(LNWorker): self.wallet.save_db() return payment_hash + def bundle_payments(self, hash_list): + self.payment_bundles.append(hash_list) + + def get_payment_bundle(self, payment_hash): + for hash_list in self.payment_bundles: + if payment_hash in hash_list: + return hash_list + def save_preimage(self, payment_hash: bytes, preimage: bytes, *, write_to_disk: bool = True): assert sha256(preimage) == payment_hash self.preimages[payment_hash.hex()] = preimage.hex() @@ -1901,45 +1911,87 @@ class LNWallet(LNWorker): """ return MPP status: True (accepted), False (expired) or None (waiting) """ payment_hash = htlc.payment_hash - preimage = self.get_preimage(payment_hash) - callback = self.hold_invoice_callbacks.get(payment_hash) - if not preimage and callback: - cb, timeout = callback - if int(time.time()) < timeout: - cb(payment_hash) - return None - else: - return False - amt_to_forward = htlc.amount_msat # check this - if amt_to_forward >= expected_msat: - # not multi-part - return True + self.update_mpp_with_received_htlc(payment_secret, short_channel_id, htlc, expected_msat) + is_expired, is_accepted = self.get_mpp_status(payment_secret) + if not is_accepted and not is_expired: + bundle = self.get_payment_bundle(payment_hash) + payment_hashes = bundle or [payment_hash] + payment_secrets = [self.get_payment_secret(h) for h in bundle] if bundle else [payment_secret] + first_timestamp = min([self.get_first_timestamp_of_mpp(x) for x in payment_secrets]) + if self.get_payment_status(payment_hash) == PR_PAID: + is_accepted = True + elif self.stopping_soon: + is_expired = True # try to time out pending HTLCs before shutting down + elif time.time() - first_timestamp > self.MPP_EXPIRY: + is_expired = True + elif all([self.is_mpp_amount_reached(x) for x in payment_secrets]): + preimage = self.get_preimage(payment_hash) + hold_invoice_callback = self.hold_invoice_callbacks.get(payment_hash) + if not preimage and hold_invoice_callback: + # for hold invoices, trigger callback + cb, timeout = hold_invoice_callback + if int(time.time()) < timeout: + cb(payment_hash) + else: + is_expired = True + elif bundle is not None: + is_accepted = all([bool(self.get_preimage(x)) for x in bundle]) + else: + # trampoline forwarding needs this to return True + is_accepted = True + + # set status for the bundle + if is_expired or is_accepted: + for x in payment_secrets: + if x in self.received_mpp_htlcs: + self.set_mpp_status(x, is_expired, is_accepted) - is_expired, is_accepted, htlc_set = self.received_mpp_htlcs.get(payment_secret, (False, False, set())) - if self.get_payment_status(payment_hash) == PR_PAID: - # payment_status is persisted - is_accepted = True - is_expired = False + self.maybe_cleanup_mpp_status(payment_secret, short_channel_id, htlc) + return True if is_accepted else (False if is_expired else None) + + def update_mpp_with_received_htlc(self, payment_secret, short_channel_id, htlc, expected_msat): + # add new htlc to set + is_expired, is_accepted, _expected_msat, htlc_set = self.received_mpp_htlcs.get(payment_secret, (False, False, expected_msat, set())) + assert expected_msat == _expected_msat key = (short_channel_id, htlc) if key not in htlc_set: htlc_set.add(key) + self.received_mpp_htlcs[payment_secret] = is_expired, is_accepted, _expected_msat, htlc_set + + def get_mpp_status(self, payment_secret): + is_expired, is_accepted, _expected_msat, htlc_set = self.received_mpp_htlcs[payment_secret] + return is_expired, is_accepted + + def set_mpp_status(self, payment_secret, is_expired, is_accepted): + _is_expired, _is_accepted, _expected_msat, htlc_set = self.received_mpp_htlcs[payment_secret] + self.received_mpp_htlcs[payment_secret] = is_expired, is_accepted, _expected_msat, htlc_set + + def is_mpp_amount_reached(self, payment_secret): + mpp = self.received_mpp_htlcs.get(payment_secret) + if not mpp: + return False + is_expired, is_accepted, _expected_msat, htlc_set = mpp + total = sum([_htlc.amount_msat for scid, _htlc in htlc_set]) + return total >= _expected_msat + + def get_first_timestamp_of_mpp(self, payment_secret): + mpp = self.received_mpp_htlcs.get(payment_secret) + if not mpp: + return int(time.time()) + is_expired, is_accepted, _expected_msat, htlc_set = mpp + return min([_htlc.timestamp for scid, _htlc in htlc_set]) + + def maybe_cleanup_mpp_status(self, payment_secret, short_channel_id, htlc): + is_expired, is_accepted, _expected_msat, htlc_set = self.received_mpp_htlcs[payment_secret] if not is_accepted and not is_expired: - total = sum([_htlc.amount_msat for scid, _htlc in htlc_set]) - first_timestamp = min([_htlc.timestamp for scid, _htlc in htlc_set]) - if self.stopping_soon: - is_expired = True # try to time out pending HTLCs before shutting down - elif time.time() - first_timestamp > self.MPP_EXPIRY: - is_expired = True - elif total == expected_msat: - is_accepted = True - if is_accepted or is_expired: - htlc_set.remove(key) + return + key = (short_channel_id, htlc) + htlc_set.remove(key) if len(htlc_set) > 0: - self.received_mpp_htlcs[payment_secret] = is_expired, is_accepted, htlc_set + self.received_mpp_htlcs[payment_secret] = is_expired, is_accepted, _expected_msat, htlc_set elif payment_secret in self.received_mpp_htlcs: self.received_mpp_htlcs.pop(payment_secret) - return True if is_accepted else (False if is_expired else None) def get_payment_status(self, payment_hash: bytes) -> int: info = self.get_payment_info(payment_hash) diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index 6f9bd7d66..964899b19 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -174,6 +174,7 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): self.stopping_soon = False self.downstream_htlc_to_upstream_peer_map = {} self.hold_invoice_callbacks = {} + self.payment_bundles = [] # lists of hashes. todo:persist self.logger.info(f"created LNWallet[{name}] with nodeID={local_keypair.pubkey.hex()}") @@ -279,6 +280,16 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): register_callback_for_hold_invoice = LNWallet.register_callback_for_hold_invoice add_payment_info_for_hold_invoice = LNWallet.add_payment_info_for_hold_invoice + update_mpp_with_received_htlc = LNWallet.update_mpp_with_received_htlc + get_mpp_status = LNWallet.get_mpp_status + set_mpp_status = LNWallet.set_mpp_status + is_mpp_amount_reached = LNWallet.is_mpp_amount_reached + get_first_timestamp_of_mpp = LNWallet.get_first_timestamp_of_mpp + maybe_cleanup_mpp_status = LNWallet.maybe_cleanup_mpp_status + bundle_payments = LNWallet.bundle_payments + get_payment_bundle = LNWallet.get_payment_bundle + + class MockTransport: def __init__(self, name): self.queue = asyncio.Queue() # incoming messages @@ -727,7 +738,14 @@ class TestPeer(ElectrumTestCase): await w.network.channel_db.stopped_event.wait() w.network.channel_db = None - async def _test_simple_payment(self, trampoline: bool, test_hold_invoice=False, test_timeout=False): + async def _test_simple_payment( + self, + test_trampoline: bool, + test_hold_invoice=False, + test_hold_timeout=False, + test_bundle=False, + test_bundle_timeout=False + ): """Alice pays Bob a single HTLC via direct channel.""" alice_channel, bob_channel = create_test_channels() p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) @@ -746,12 +764,21 @@ class TestPeer(ElectrumTestCase): def cb(payment_hash): if not test_timeout: w2.save_preimage(payment_hash, preimage) - timeout = 1 if test_timeout else 60 + timeout = 1 if test_hold_timeout else 60 w2.register_callback_for_hold_invoice(payment_hash, cb, timeout) + if test_bundle: + lnaddr2, pay_req2 = self.prepare_invoice(w2) + w2.bundle_payments([lnaddr.paymenthash, lnaddr2.paymenthash]) + + if test_trampoline: + await self._activate_trampoline(w1) + # declare bob as trampoline node + electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = { + 'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=w2.node_keypair.pubkey), + } + async def f(): - if trampoline: - await self._activate_trampoline(w1) async with OldTaskGroup() as group: await group.spawn(p1._message_loop()) await group.spawn(p1.htlc_switch()) @@ -761,22 +788,31 @@ class TestPeer(ElectrumTestCase): invoice_features = lnaddr.get_features() self.assertFalse(invoice_features.supports(LnFeatures.BASIC_MPP_OPT)) await group.spawn(pay(lnaddr, pay_req)) - # declare bob as trampoline node - electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = { - 'bob': LNPeerAddr(host="127.0.0.1", port=9735, pubkey=w2.node_keypair.pubkey), - } + if test_bundle and not test_bundle_timeout: + await group.spawn(pay(lnaddr2, pay_req2)) + await f() @needs_test_with_all_chacha20_implementations async def test_simple_payment(self): - for trampoline in [False, True]: + for test_trampoline in [False, True]: + with self.assertRaises(PaymentDone): + await self._test_simple_payment(test_trampoline=test_trampoline) + + async def test_payment_bundle(self): + for test_trampoline in [False, True]: with self.assertRaises(PaymentDone): - await self._test_simple_payment(trampoline=trampoline) + await self._test_simple_payment(test_trampoline=test_trampoline, test_bundle=True) + + async def test_payment_bundle_timeout(self): + for test_trampoline in [False, True]: + with self.assertRaises(PaymentFailure): + await self._test_simple_payment(test_trampoline=test_trampoline, test_bundle=True, test_bundle_timeout=True) async def test_simple_payment_with_hold_invoice(self): - for trampoline in [False, True]: + for test_trampoline in [False, True]: with self.assertRaises(PaymentDone): - await self._test_simple_payment(trampoline=trampoline, test_hold_invoice=True) + await self._test_simple_payment(test_trampoline=test_trampoline, test_hold_invoice=True) async def test_simple_payment_with_hold_invoice_timing_out(self): for trampoline in [False, True]: