From 5708f7b1c803cb39e567207d7454d483a584861a Mon Sep 17 00:00:00 2001 From: ThomasV Date: Fri, 7 Jun 2024 11:45:21 +0200 Subject: [PATCH] Persist MPP resolution status in wallet file. If we accept a MPP and we forward the payment (trampoline or swap), we need to persist the payment accepted status, or we might wrongly release htlcs on the next restart. lnworker.received_mpp_htlcs used to be cleaned up in maybe_cleanup_forwarding, which only applies to forwarded payments. However, since we now persist this dict, we need to clean it up also in the case of payments received by us. This part of maybe_cleanup_forwarding has been migrated to lnworker.maybe_cleanup_mpp --- electrum/lnpeer.py | 6 +++- electrum/lnworker.py | 73 +++++++++++++++++++++++++------------------- tests/test_lnpeer.py | 8 +++++ 3 files changed, 54 insertions(+), 33 deletions(-) diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index b84a993a6..ba8d3e5dd 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -2750,6 +2750,7 @@ class Peer(Logger): # return payment_key so this branch will not be executed again return None, payment_key, None elif preimage: + self.lnworker.maybe_cleanup_mpp(chan.get_scid_or_local_alias(), htlc) return preimage, None, None else: # we are waiting for mpp consolidation or preimage @@ -2761,7 +2762,10 @@ class Peer(Logger): preimage = self.lnworker.get_preimage(payment_hash) error_bytes, error_reason = self.lnworker.get_forwarding_failure(payment_key) if error_bytes or error_reason or preimage: - self.lnworker.maybe_cleanup_forwarding(payment_key, chan.get_scid_or_local_alias(), htlc) + cleanup_keys = self.lnworker.maybe_cleanup_mpp(chan.get_scid_or_local_alias(), htlc) + is_htlc_key = ':' in payment_key + if is_htlc_key or payment_key in cleanup_keys: + self.lnworker.maybe_cleanup_forwarding(payment_key) if error_bytes: return None, None, error_bytes if error_reason: diff --git a/electrum/lnworker.py b/electrum/lnworker.py index b0a895f43..750746ff2 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -87,6 +87,7 @@ from .submarine_swaps import HttpSwapManager from .channel_db import ChannelInfo, Policy from .mpp_split import suggest_splits, SplitConfigRating from .trampoline import create_trampoline_route_and_onion, is_legacy_relay +from .json_db import stored_in if TYPE_CHECKING: from .network import Network @@ -169,11 +170,13 @@ class PaymentInfo(NamedTuple): status: int -class RecvMPPResolution(Enum): - WAITING = enum.auto() - EXPIRED = enum.auto() - ACCEPTED = enum.auto() - FAILED = enum.auto() +# Note: these states are persisted in the wallet file. +# Do not modify them without performing a wallet db upgrade +class RecvMPPResolution(IntEnum): + WAITING = 0 + EXPIRED = 1 + ACCEPTED = 2 + FAILED = 3 class ReceivedMPPStatus(NamedTuple): @@ -181,6 +184,13 @@ class ReceivedMPPStatus(NamedTuple): expected_msat: int htlc_set: Set[Tuple[ShortChannelID, UpdateAddHtlc]] + @stored_in('received_mpp_htlcs', tuple) + def from_tuple(resolution, expected_msat, htlc_list) -> 'ReceivedMPPStatus': + htlc_set = set([(ShortChannelID(bytes.fromhex(scid)), UpdateAddHtlc.from_tuple(*x)) for (scid,x) in htlc_list]) + return ReceivedMPPStatus( + resolution=RecvMPPResolution(resolution), + expected_msat=expected_msat, + htlc_set=htlc_set) SentHtlcKey = Tuple[bytes, ShortChannelID, int] # RHASH, scid, htlc_id @@ -851,7 +861,7 @@ class LNWallet(LNWorker): self._paysessions = dict() # type: Dict[bytes, PaySession] self.sent_htlcs_info = dict() # type: Dict[SentHtlcKey, SentHtlcInfo] - self.received_mpp_htlcs = dict() # type: Dict[bytes, ReceivedMPPStatus] # payment_key -> ReceivedMPPStatus + self.received_mpp_htlcs = self.db.get_dict('received_mpp_htlcs') # type: Dict[str, ReceivedMPPStatus] # payment_key -> ReceivedMPPStatus # detect inflight payments self.inflight_payments = set() # (not persisted) keys of invoices that are in PR_INFLIGHT state @@ -2192,7 +2202,7 @@ class LNWallet(LNWorker): payment_keys = [self._get_payment_key(x) for x in hash_list] self.payment_bundles.append(payment_keys) - def get_payment_bundle(self, payment_key): + def get_payment_bundle(self, payment_key: bytes) -> Sequence[bytes]: for key_list in self.payment_bundles: if payment_key in key_list: return key_list @@ -2259,7 +2269,7 @@ class LNWallet(LNWorker): payment_key = payment_hash + payment_secret self.update_mpp_with_received_htlc( payment_key=payment_key, scid=short_channel_id, htlc=htlc, expected_msat=expected_msat) - mpp_resolution = self.received_mpp_htlcs[payment_key].resolution + mpp_resolution = self.received_mpp_htlcs[payment_key.hex()].resolution # if still waiting, calc resolution now: if mpp_resolution == RecvMPPResolution.WAITING: bundle = self.get_payment_bundle(payment_key) @@ -2280,7 +2290,7 @@ class LNWallet(LNWorker): # save resolution, if any. if mpp_resolution != RecvMPPResolution.WAITING: for pkey in payment_keys: - if pkey in self.received_mpp_htlcs: + if pkey.hex() in self.received_mpp_htlcs: self.set_mpp_resolution(payment_key=pkey, resolution=mpp_resolution) return mpp_resolution @@ -2294,7 +2304,7 @@ class LNWallet(LNWorker): expected_msat: int, ): # add new htlc to set - mpp_status = self.received_mpp_htlcs.get(payment_key) + mpp_status = self.received_mpp_htlcs.get(payment_key.hex()) if mpp_status is None: mpp_status = ReceivedMPPStatus( resolution=RecvMPPResolution.WAITING, @@ -2308,47 +2318,46 @@ class LNWallet(LNWorker): key = (scid, htlc) if key not in mpp_status.htlc_set: mpp_status.htlc_set.add(key) # side-effecting htlc_set - self.received_mpp_htlcs[payment_key] = mpp_status + self.received_mpp_htlcs[payment_key.hex()] = mpp_status def set_mpp_resolution(self, *, payment_key: bytes, resolution: RecvMPPResolution): - mpp_status = self.received_mpp_htlcs[payment_key] - self.received_mpp_htlcs[payment_key] = mpp_status._replace(resolution=resolution) + mpp_status = self.received_mpp_htlcs[payment_key.hex()] + self.logger.info(f'set_mpp_resolution {resolution.name} {len(mpp_status.htlc_set)} {payment_key.hex()}') + self.received_mpp_htlcs[payment_key.hex()] = mpp_status._replace(resolution=resolution) def is_mpp_amount_reached(self, payment_key: bytes) -> bool: - mpp_status = self.received_mpp_htlcs.get(payment_key) + mpp_status = self.received_mpp_htlcs.get(payment_key.hex()) if not mpp_status: return False total = sum([_htlc.amount_msat for scid, _htlc in mpp_status.htlc_set]) return total >= mpp_status.expected_msat def get_first_timestamp_of_mpp(self, payment_key: bytes) -> int: - mpp_status = self.received_mpp_htlcs.get(payment_key) + mpp_status = self.received_mpp_htlcs.get(payment_key.hex()) if not mpp_status: return int(time.time()) return min([_htlc.timestamp for scid, _htlc in mpp_status.htlc_set]) - def maybe_cleanup_forwarding( + def maybe_cleanup_mpp( self, - payment_key_hex: str, short_channel_id: ShortChannelID, htlc: UpdateAddHtlc, - ) -> None: - - is_htlc_key = ':' in payment_key_hex - if not is_htlc_key: - payment_key = bytes.fromhex(payment_key_hex) - mpp_status = self.received_mpp_htlcs.get(payment_key) - if not mpp_status or mpp_status.resolution == RecvMPPResolution.WAITING: - # After restart, self.received_mpp_htlcs needs to be reconstructed - self.logger.info(f'maybe_cleanup_forwarding: mpp_status not ready') - return - htlc_key = (short_channel_id, htlc) + ) -> Sequence[str]: + htlc_key = (short_channel_id, htlc) + cleanup_keys = [] + for payment_key_hex, mpp_status in list(self.received_mpp_htlcs.items()): + if htlc_key not in mpp_status.htlc_set: + continue + assert mpp_status.resolution != RecvMPPResolution.WAITING + self.logger.info(f'maybe_cleanup_mpp: removing htlc of MPP {payment_key_hex}') mpp_status.htlc_set.remove(htlc_key) # side-effecting htlc_set - if mpp_status.htlc_set: - return - self.logger.info('cleaning up mpp') - self.received_mpp_htlcs.pop(payment_key) + if len(mpp_status.htlc_set) == 0: + self.logger.info(f'maybe_cleanup_mpp: removing mpp {payment_key_hex}') + self.received_mpp_htlcs.pop(payment_key_hex) + cleanup_keys.append(payment_key_hex) + return cleanup_keys + def maybe_cleanup_forwarding(self, payment_key_hex: str) -> None: self.active_forwardings.pop(payment_key_hex, None) self.forwarding_failures.pop(payment_key_hex, None) diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index dfca4b990..c674e8897 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -316,6 +316,7 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): maybe_cleanup_forwarding = LNWallet.maybe_cleanup_forwarding current_target_feerate_per_kw = LNWallet.current_target_feerate_per_kw current_low_feerate_per_kw = LNWallet.current_low_feerate_per_kw + maybe_cleanup_mpp = LNWallet.maybe_cleanup_mpp class MockTransport: @@ -1741,6 +1742,7 @@ class TestPeerForwarding(TestPeer): ): alice_w = graph.workers['alice'] bob_w = graph.workers['bob'] + carol_w = graph.workers['carol'] dave_w = graph.workers['dave'] if mpp_invoice: dave_w.features |= LnFeatures.BASIC_MPP_OPT @@ -1762,6 +1764,12 @@ class TestPeerForwarding(TestPeer): await asyncio.sleep(2) if result: self.assertEqual(PR_PAID, dave_w.get_payment_status(lnaddr.paymenthash)) + # check mpp is cleaned up + async with OldTaskGroup() as g: + for peer in peers: + await g.spawn(peer.wait_one_htlc_switch_iteration()) + for peer in peers: + self.assertEqual(len(peer.lnworker.received_mpp_htlcs), 0) raise PaymentDone() elif len(log) == 1 and log[0].failure_msg.code == OnionFailureCode.MPP_TIMEOUT: raise PaymentTimeout()