From 44bdd20ccc40bb307cc3510d8741af7058e2c6e8 Mon Sep 17 00:00:00 2001 From: SomberNight Date: Fri, 4 Aug 2023 13:27:05 +0000 Subject: [PATCH] lnworker: add RecvMPPResolution with "FAILED" state - add RecvMPPResolution enum for possible states of a pending incoming MPP, and use it in check_mpp_status - new state: "FAILED", to allow nicely failing back the whole MPP set - key more things with payment_hash+payment_secret, for consistency (just payment_hash is insufficient for trampoline forwarding) --- electrum/lnpeer.py | 19 ++++- electrum/lnworker.py | 139 +++++++++++++++++++--------------- electrum/tests/test_lnpeer.py | 4 +- 3 files changed, 94 insertions(+), 68 deletions(-) diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index 62833ac00..11edfe81a 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -1826,14 +1826,25 @@ class Peer(Logger): log_fail_reason(f"'payment_secret' missing from onion") raise exc_incorrect_or_unknown_pd - payment_status = self.lnworker.check_mpp_status(payment_secret_from_onion, chan.short_channel_id, htlc, total_msat) - if payment_status is None: + from .lnworker import RecvMPPResolution + mpp_resolution = self.lnworker.check_mpp_status( + payment_secret=payment_secret_from_onion, + short_channel_id=chan.short_channel_id, + htlc=htlc, + expected_msat=total_msat, + ) + if mpp_resolution == RecvMPPResolution.WAITING: return None, None - elif payment_status is False: + elif mpp_resolution == RecvMPPResolution.EXPIRED: log_fail_reason(f"MPP_TIMEOUT") raise OnionRoutingFailure(code=OnionFailureCode.MPP_TIMEOUT, data=b'') + elif mpp_resolution == RecvMPPResolution.FAILED: + log_fail_reason(f"mpp_resolution is FAILED") + raise exc_incorrect_or_unknown_pd + elif mpp_resolution == RecvMPPResolution.ACCEPTED: + pass # continue else: - assert payment_status is True + raise Exception(f"unexpected {mpp_resolution=}") payment_hash = htlc.payment_hash diff --git a/electrum/lnworker.py b/electrum/lnworker.py index d2f907b6e..53ec4e3e3 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -8,7 +8,8 @@ from decimal import Decimal import random import time import operator -from enum import IntEnum +import enum +from enum import IntEnum, Enum from typing import (Optional, Sequence, Tuple, List, Set, Dict, TYPE_CHECKING, NamedTuple, Union, Mapping, Any, Iterable, AsyncGenerator, DefaultDict, Callable) import threading @@ -167,9 +168,15 @@ class PaymentInfo(NamedTuple): status: int +class RecvMPPResolution(Enum): + WAITING = enum.auto() + EXPIRED = enum.auto() + ACCEPTED = enum.auto() + FAILED = enum.auto() + + class ReceivedMPPStatus(NamedTuple): - is_expired: bool - is_accepted: bool + resolution: RecvMPPResolution expected_msat: int htlc_set: Set[Tuple[ShortChannelID, UpdateAddHtlc]] @@ -673,8 +680,8 @@ class LNWallet(LNWorker): self.sent_htlcs = defaultdict(asyncio.Queue) # type: Dict[bytes, asyncio.Queue[HtlcLog]] self.sent_htlcs_info = dict() # (RHASH, scid, htlc_id) -> route, payment_secret, amount_msat, bucket_msat, trampoline_fee_level - self.sent_buckets = dict() # payment_secret -> (amount_sent, amount_failed) - self.received_mpp_htlcs = dict() # type: Dict[bytes, ReceivedMPPStatus] # payment_secret -> ReceivedMPPStatus + self.sent_buckets = dict() # payment_key -> (amount_sent, amount_failed) + self.received_mpp_htlcs = dict() # type: Dict[bytes, ReceivedMPPStatus] # payment_key -> ReceivedMPPStatus self.swap_manager = SwapManager(wallet=self.wallet, lnworker=self) # detect inflight payments @@ -1418,13 +1425,14 @@ class LNWallet(LNWorker): key = (payment_hash, short_channel_id, htlc.htlc_id) self.sent_htlcs_info[key] = route, payment_secret, amount_msat, total_msat, amount_receiver_msat, trampoline_fee_level, trampoline_route + payment_key = payment_hash + payment_secret # if we sent MPP to a trampoline, add item to sent_buckets if self.uses_trampoline() and amount_msat != total_msat: - if payment_secret not in self.sent_buckets: - self.sent_buckets[payment_secret] = (0, 0) - amount_sent, amount_failed = self.sent_buckets[payment_secret] + if payment_key not in self.sent_buckets: + self.sent_buckets[payment_key] = (0, 0) + amount_sent, amount_failed = self.sent_buckets[payment_key] amount_sent += amount_receiver_msat - self.sent_buckets[payment_secret] = amount_sent, amount_failed + self.sent_buckets[payment_key] = amount_sent, amount_failed if self.network.path_finder: # add inflight htlcs to liquidity hints self.network.path_finder.update_inflight_htlcs(route, add_htlcs=True) @@ -1867,6 +1875,14 @@ class LNWallet(LNWorker): def get_payment_secret(self, payment_hash): return sha256(sha256(self.payment_secret_key) + payment_hash) + def _get_payment_key(self, payment_hash: bytes) -> bytes: + """Return payment bucket key. + We bucket htlcs based on payment_hash+payment_secret. payment_secret is included + as it changes over a trampoline path (in the outer onion), and these paths can overlap. + """ + payment_secret = self.get_payment_secret(payment_hash) + return payment_hash + payment_secret + def create_payment_info(self, *, amount_msat: Optional[int], write_to_disk=True) -> bytes: payment_preimage = os.urandom(32) payment_hash = sha256(payment_preimage) @@ -1923,103 +1939,101 @@ class LNWallet(LNWorker): self.wallet.save_db() def check_mpp_status( - self, payment_secret: bytes, + self, *, + payment_secret: bytes, short_channel_id: ShortChannelID, htlc: UpdateAddHtlc, expected_msat: int, - ) -> Optional[bool]: - """ return MPP status: True (accepted), False (expired) or None (waiting) - """ + ) -> RecvMPPResolution: payment_hash = htlc.payment_hash - self.update_mpp_with_received_htlc(payment_secret, short_channel_id, htlc, expected_msat) - is_expired, is_accepted = self.get_mpp_status(payment_secret) - if not is_accepted and not is_expired: + payment_key = payment_hash + payment_secret + self.update_mpp_with_received_htlc( + payment_key=payment_key, scid=short_channel_id, htlc=htlc, expected_msat=expected_msat) + mpp_resolution = self.received_mpp_htlcs[payment_key].resolution + if mpp_resolution == RecvMPPResolution.WAITING: bundle = self.get_payment_bundle(payment_hash) if bundle: - payment_secrets = [self.get_payment_secret(h) for h in bundle] - if payment_secret not in payment_secrets: + payment_keys = [self._get_payment_key(h) for h in bundle] + if payment_key not in payment_keys: # outer trampoline onion secret differs from inner onion # the latter, not the former, might be part of a bundle - payment_secrets = [payment_secret] + payment_keys = [payment_key] else: - payment_secrets = [payment_secret] - first_timestamp = min([self.get_first_timestamp_of_mpp(x) for x in payment_secrets]) + payment_keys = [payment_key] + first_timestamp = min([self.get_first_timestamp_of_mpp(pkey) for pkey in payment_keys]) if self.get_payment_status(payment_hash) == PR_PAID: - is_accepted = True + mpp_resolution = RecvMPPResolution.ACCEPTED elif self.stopping_soon: - is_expired = True # try to time out pending HTLCs before shutting down - elif all([self.is_mpp_amount_reached(x) for x in payment_secrets]): - is_accepted = True + # try to time out pending HTLCs before shutting down + mpp_resolution = RecvMPPResolution.EXPIRED + elif all([self.is_mpp_amount_reached(pkey) for pkey in payment_keys]): + mpp_resolution = RecvMPPResolution.ACCEPTED elif time.time() - first_timestamp > self.MPP_EXPIRY: - is_expired = True + mpp_resolution = RecvMPPResolution.EXPIRED - if is_accepted or is_expired: - for x in payment_secrets: - if x in self.received_mpp_htlcs: - self.set_mpp_status(x, is_expired, is_accepted) + if mpp_resolution != RecvMPPResolution.WAITING: + for pkey in payment_keys: + if pkey in self.received_mpp_htlcs: + self.set_mpp_resolution(payment_key=pkey, resolution=mpp_resolution) - self.maybe_cleanup_mpp_status(payment_secret, short_channel_id, htlc) - return True if is_accepted else (False if is_expired else None) + self.maybe_cleanup_mpp_status(payment_key, short_channel_id, htlc) + return mpp_resolution def update_mpp_with_received_htlc( self, - payment_secret: bytes, - short_channel_id: ShortChannelID, + *, + payment_key: bytes, + scid: ShortChannelID, htlc: UpdateAddHtlc, expected_msat: int, ): # add new htlc to set - mpp_status = self.received_mpp_htlcs.get(payment_secret) + mpp_status = self.received_mpp_htlcs.get(payment_key) if mpp_status is None: mpp_status = ReceivedMPPStatus( - is_expired=False, - is_accepted=False, + resolution=RecvMPPResolution.WAITING, expected_msat=expected_msat, htlc_set=set(), ) - assert expected_msat == mpp_status.expected_msat - key = (short_channel_id, htlc) + if expected_msat != mpp_status.expected_msat: + self.logger.info( + f"marking received mpp as failed. inconsistent total_msats in bucket. {payment_key.hex()=}") + mpp_status = mpp_status._replace(resolution=RecvMPPResolution.FAILED) + key = (scid, htlc) if key not in mpp_status.htlc_set: mpp_status.htlc_set.add(key) # side-effecting htlc_set - self.received_mpp_htlcs[payment_secret] = mpp_status - - def get_mpp_status(self, payment_secret: bytes) -> Tuple[bool, bool]: - mpp_status = self.received_mpp_htlcs[payment_secret] - return mpp_status.is_expired, mpp_status.is_accepted + self.received_mpp_htlcs[payment_key] = mpp_status - def set_mpp_status(self, payment_secret: bytes, is_expired: bool, is_accepted: bool): - mpp_status = self.received_mpp_htlcs[payment_secret] - self.received_mpp_htlcs[payment_secret] = mpp_status._replace( - is_expired=is_expired, - is_accepted=is_accepted, - ) + def set_mpp_resolution(self, *, payment_key: bytes, resolution: RecvMPPResolution): + mpp_status = self.received_mpp_htlcs[payment_key] + self.received_mpp_htlcs[payment_key] = mpp_status._replace(resolution=resolution) - def is_mpp_amount_reached(self, payment_secret: bytes) -> bool: - mpp_status = self.received_mpp_htlcs.get(payment_secret) + def is_mpp_amount_reached(self, payment_key: bytes) -> bool: + mpp_status = self.received_mpp_htlcs.get(payment_key) if not mpp_status: return False total = sum([_htlc.amount_msat for scid, _htlc in mpp_status.htlc_set]) return total >= mpp_status.expected_msat - def get_first_timestamp_of_mpp(self, payment_secret: bytes) -> int: - mpp_status = self.received_mpp_htlcs.get(payment_secret) + def get_first_timestamp_of_mpp(self, payment_key: bytes) -> int: + mpp_status = self.received_mpp_htlcs.get(payment_key) if not mpp_status: return int(time.time()) return min([_htlc.timestamp for scid, _htlc in mpp_status.htlc_set]) def maybe_cleanup_mpp_status( self, - payment_secret: bytes, + payment_key: bytes, short_channel_id: ShortChannelID, htlc: UpdateAddHtlc, ) -> None: - mpp_status = self.received_mpp_htlcs[payment_secret] - if not mpp_status.is_accepted and not mpp_status.is_expired: + mpp_status = self.received_mpp_htlcs[payment_key] + if mpp_status.resolution == RecvMPPResolution.WAITING: return key = (short_channel_id, htlc) mpp_status.htlc_set.remove(key) # side-effecting htlc_set - if not mpp_status.htlc_set and payment_secret in self.received_mpp_htlcs: - self.received_mpp_htlcs.pop(payment_secret) + if not mpp_status.htlc_set and payment_key in self.received_mpp_htlcs: + self.received_mpp_htlcs.pop(payment_key) def get_payment_status(self, payment_hash: bytes) -> int: info = self.get_payment_info(payment_hash) @@ -2126,10 +2140,11 @@ class LNWallet(LNWorker): self.logger.info(f"htlc_failed {failure_message}") # check sent_buckets if we use trampoline - if self.uses_trampoline() and payment_secret in self.sent_buckets: - amount_sent, amount_failed = self.sent_buckets[payment_secret] + payment_key = payment_hash + payment_secret + if self.uses_trampoline() and payment_key in self.sent_buckets: + amount_sent, amount_failed = self.sent_buckets[payment_key] amount_failed += amount_receiver_msat - self.sent_buckets[payment_secret] = amount_sent, amount_failed + self.sent_buckets[payment_key] = amount_sent, amount_failed if amount_sent != amount_failed: self.logger.info('bucket still active...') return diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index ad072fe44..8c7ecbeee 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -283,13 +283,13 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): add_payment_info_for_hold_invoice = LNWallet.add_payment_info_for_hold_invoice update_mpp_with_received_htlc = LNWallet.update_mpp_with_received_htlc - get_mpp_status = LNWallet.get_mpp_status - set_mpp_status = LNWallet.set_mpp_status + set_mpp_resolution = LNWallet.set_mpp_resolution is_mpp_amount_reached = LNWallet.is_mpp_amount_reached get_first_timestamp_of_mpp = LNWallet.get_first_timestamp_of_mpp maybe_cleanup_mpp_status = LNWallet.maybe_cleanup_mpp_status bundle_payments = LNWallet.bundle_payments get_payment_bundle = LNWallet.get_payment_bundle + _get_payment_key = LNWallet._get_payment_key class MockTransport: