diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index b84a993a6..ba8d3e5dd 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -2750,6 +2750,7 @@ class Peer(Logger): # return payment_key so this branch will not be executed again return None, payment_key, None elif preimage: + self.lnworker.maybe_cleanup_mpp(chan.get_scid_or_local_alias(), htlc) return preimage, None, None else: # we are waiting for mpp consolidation or preimage @@ -2761,7 +2762,10 @@ class Peer(Logger): preimage = self.lnworker.get_preimage(payment_hash) error_bytes, error_reason = self.lnworker.get_forwarding_failure(payment_key) if error_bytes or error_reason or preimage: - self.lnworker.maybe_cleanup_forwarding(payment_key, chan.get_scid_or_local_alias(), htlc) + cleanup_keys = self.lnworker.maybe_cleanup_mpp(chan.get_scid_or_local_alias(), htlc) + is_htlc_key = ':' in payment_key + if is_htlc_key or payment_key in cleanup_keys: + self.lnworker.maybe_cleanup_forwarding(payment_key) if error_bytes: return None, None, error_bytes if error_reason: diff --git a/electrum/lnworker.py b/electrum/lnworker.py index b0a895f43..750746ff2 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -87,6 +87,7 @@ from .submarine_swaps import HttpSwapManager from .channel_db import ChannelInfo, Policy from .mpp_split import suggest_splits, SplitConfigRating from .trampoline import create_trampoline_route_and_onion, is_legacy_relay +from .json_db import stored_in if TYPE_CHECKING: from .network import Network @@ -169,11 +170,13 @@ class PaymentInfo(NamedTuple): status: int -class RecvMPPResolution(Enum): - WAITING = enum.auto() - EXPIRED = enum.auto() - ACCEPTED = enum.auto() - FAILED = enum.auto() +# Note: these states are persisted in the wallet file. +# Do not modify them without performing a wallet db upgrade +class RecvMPPResolution(IntEnum): + WAITING = 0 + EXPIRED = 1 + ACCEPTED = 2 + FAILED = 3 class ReceivedMPPStatus(NamedTuple): @@ -181,6 +184,13 @@ class ReceivedMPPStatus(NamedTuple): expected_msat: int htlc_set: Set[Tuple[ShortChannelID, UpdateAddHtlc]] + @stored_in('received_mpp_htlcs', tuple) + def from_tuple(resolution, expected_msat, htlc_list) -> 'ReceivedMPPStatus': + htlc_set = set([(ShortChannelID(bytes.fromhex(scid)), UpdateAddHtlc.from_tuple(*x)) for (scid,x) in htlc_list]) + return ReceivedMPPStatus( + resolution=RecvMPPResolution(resolution), + expected_msat=expected_msat, + htlc_set=htlc_set) SentHtlcKey = Tuple[bytes, ShortChannelID, int] # RHASH, scid, htlc_id @@ -851,7 +861,7 @@ class LNWallet(LNWorker): self._paysessions = dict() # type: Dict[bytes, PaySession] self.sent_htlcs_info = dict() # type: Dict[SentHtlcKey, SentHtlcInfo] - self.received_mpp_htlcs = dict() # type: Dict[bytes, ReceivedMPPStatus] # payment_key -> ReceivedMPPStatus + self.received_mpp_htlcs = self.db.get_dict('received_mpp_htlcs') # type: Dict[str, ReceivedMPPStatus] # payment_key -> ReceivedMPPStatus # detect inflight payments self.inflight_payments = set() # (not persisted) keys of invoices that are in PR_INFLIGHT state @@ -2192,7 +2202,7 @@ class LNWallet(LNWorker): payment_keys = [self._get_payment_key(x) for x in hash_list] self.payment_bundles.append(payment_keys) - def get_payment_bundle(self, payment_key): + def get_payment_bundle(self, payment_key: bytes) -> Sequence[bytes]: for key_list in self.payment_bundles: if payment_key in key_list: return key_list @@ -2259,7 +2269,7 @@ class LNWallet(LNWorker): payment_key = payment_hash + payment_secret self.update_mpp_with_received_htlc( payment_key=payment_key, scid=short_channel_id, htlc=htlc, expected_msat=expected_msat) - mpp_resolution = self.received_mpp_htlcs[payment_key].resolution + mpp_resolution = self.received_mpp_htlcs[payment_key.hex()].resolution # if still waiting, calc resolution now: if mpp_resolution == RecvMPPResolution.WAITING: bundle = self.get_payment_bundle(payment_key) @@ -2280,7 +2290,7 @@ class LNWallet(LNWorker): # save resolution, if any. if mpp_resolution != RecvMPPResolution.WAITING: for pkey in payment_keys: - if pkey in self.received_mpp_htlcs: + if pkey.hex() in self.received_mpp_htlcs: self.set_mpp_resolution(payment_key=pkey, resolution=mpp_resolution) return mpp_resolution @@ -2294,7 +2304,7 @@ class LNWallet(LNWorker): expected_msat: int, ): # add new htlc to set - mpp_status = self.received_mpp_htlcs.get(payment_key) + mpp_status = self.received_mpp_htlcs.get(payment_key.hex()) if mpp_status is None: mpp_status = ReceivedMPPStatus( resolution=RecvMPPResolution.WAITING, @@ -2308,47 +2318,46 @@ class LNWallet(LNWorker): key = (scid, htlc) if key not in mpp_status.htlc_set: mpp_status.htlc_set.add(key) # side-effecting htlc_set - self.received_mpp_htlcs[payment_key] = mpp_status + self.received_mpp_htlcs[payment_key.hex()] = mpp_status def set_mpp_resolution(self, *, payment_key: bytes, resolution: RecvMPPResolution): - mpp_status = self.received_mpp_htlcs[payment_key] - self.received_mpp_htlcs[payment_key] = mpp_status._replace(resolution=resolution) + mpp_status = self.received_mpp_htlcs[payment_key.hex()] + self.logger.info(f'set_mpp_resolution {resolution.name} {len(mpp_status.htlc_set)} {payment_key.hex()}') + self.received_mpp_htlcs[payment_key.hex()] = mpp_status._replace(resolution=resolution) def is_mpp_amount_reached(self, payment_key: bytes) -> bool: - mpp_status = self.received_mpp_htlcs.get(payment_key) + mpp_status = self.received_mpp_htlcs.get(payment_key.hex()) if not mpp_status: return False total = sum([_htlc.amount_msat for scid, _htlc in mpp_status.htlc_set]) return total >= mpp_status.expected_msat def get_first_timestamp_of_mpp(self, payment_key: bytes) -> int: - mpp_status = self.received_mpp_htlcs.get(payment_key) + mpp_status = self.received_mpp_htlcs.get(payment_key.hex()) if not mpp_status: return int(time.time()) return min([_htlc.timestamp for scid, _htlc in mpp_status.htlc_set]) - def maybe_cleanup_forwarding( + def maybe_cleanup_mpp( self, - payment_key_hex: str, short_channel_id: ShortChannelID, htlc: UpdateAddHtlc, - ) -> None: - - is_htlc_key = ':' in payment_key_hex - if not is_htlc_key: - payment_key = bytes.fromhex(payment_key_hex) - mpp_status = self.received_mpp_htlcs.get(payment_key) - if not mpp_status or mpp_status.resolution == RecvMPPResolution.WAITING: - # After restart, self.received_mpp_htlcs needs to be reconstructed - self.logger.info(f'maybe_cleanup_forwarding: mpp_status not ready') - return - htlc_key = (short_channel_id, htlc) + ) -> Sequence[str]: + htlc_key = (short_channel_id, htlc) + cleanup_keys = [] + for payment_key_hex, mpp_status in list(self.received_mpp_htlcs.items()): + if htlc_key not in mpp_status.htlc_set: + continue + assert mpp_status.resolution != RecvMPPResolution.WAITING + self.logger.info(f'maybe_cleanup_mpp: removing htlc of MPP {payment_key_hex}') mpp_status.htlc_set.remove(htlc_key) # side-effecting htlc_set - if mpp_status.htlc_set: - return - self.logger.info('cleaning up mpp') - self.received_mpp_htlcs.pop(payment_key) + if len(mpp_status.htlc_set) == 0: + self.logger.info(f'maybe_cleanup_mpp: removing mpp {payment_key_hex}') + self.received_mpp_htlcs.pop(payment_key_hex) + cleanup_keys.append(payment_key_hex) + return cleanup_keys + def maybe_cleanup_forwarding(self, payment_key_hex: str) -> None: self.active_forwardings.pop(payment_key_hex, None) self.forwarding_failures.pop(payment_key_hex, None) diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index dfca4b990..c674e8897 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -316,6 +316,7 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): maybe_cleanup_forwarding = LNWallet.maybe_cleanup_forwarding current_target_feerate_per_kw = LNWallet.current_target_feerate_per_kw current_low_feerate_per_kw = LNWallet.current_low_feerate_per_kw + maybe_cleanup_mpp = LNWallet.maybe_cleanup_mpp class MockTransport: @@ -1741,6 +1742,7 @@ class TestPeerForwarding(TestPeer): ): alice_w = graph.workers['alice'] bob_w = graph.workers['bob'] + carol_w = graph.workers['carol'] dave_w = graph.workers['dave'] if mpp_invoice: dave_w.features |= LnFeatures.BASIC_MPP_OPT @@ -1762,6 +1764,12 @@ class TestPeerForwarding(TestPeer): await asyncio.sleep(2) if result: self.assertEqual(PR_PAID, dave_w.get_payment_status(lnaddr.paymenthash)) + # check mpp is cleaned up + async with OldTaskGroup() as g: + for peer in peers: + await g.spawn(peer.wait_one_htlc_switch_iteration()) + for peer in peers: + self.assertEqual(len(peer.lnworker.received_mpp_htlcs), 0) raise PaymentDone() elif len(log) == 1 and log[0].failure_msg.code == OnionFailureCode.MPP_TIMEOUT: raise PaymentTimeout()