Browse Source

lnworker: add RecvMPPResolution with "FAILED" state

- add RecvMPPResolution enum for possible states of a pending incoming MPP,
  and use it in check_mpp_status
  - new state: "FAILED", to allow nicely failing back the whole MPP set
- key more things with payment_hash+payment_secret, for consistency
  (just payment_hash is insufficient for trampoline forwarding)
master
SomberNight 2 years ago
parent
commit
44bdd20ccc
No known key found for this signature in database
GPG Key ID: B33B5F232C6271E9
  1. 19
      electrum/lnpeer.py
  2. 139
      electrum/lnworker.py
  3. 4
      electrum/tests/test_lnpeer.py

19
electrum/lnpeer.py

@ -1826,14 +1826,25 @@ class Peer(Logger):
log_fail_reason(f"'payment_secret' missing from onion")
raise exc_incorrect_or_unknown_pd
payment_status = self.lnworker.check_mpp_status(payment_secret_from_onion, chan.short_channel_id, htlc, total_msat)
if payment_status is None:
from .lnworker import RecvMPPResolution
mpp_resolution = self.lnworker.check_mpp_status(
payment_secret=payment_secret_from_onion,
short_channel_id=chan.short_channel_id,
htlc=htlc,
expected_msat=total_msat,
)
if mpp_resolution == RecvMPPResolution.WAITING:
return None, None
elif payment_status is False:
elif mpp_resolution == RecvMPPResolution.EXPIRED:
log_fail_reason(f"MPP_TIMEOUT")
raise OnionRoutingFailure(code=OnionFailureCode.MPP_TIMEOUT, data=b'')
elif mpp_resolution == RecvMPPResolution.FAILED:
log_fail_reason(f"mpp_resolution is FAILED")
raise exc_incorrect_or_unknown_pd
elif mpp_resolution == RecvMPPResolution.ACCEPTED:
pass # continue
else:
assert payment_status is True
raise Exception(f"unexpected {mpp_resolution=}")
payment_hash = htlc.payment_hash

139
electrum/lnworker.py

@ -8,7 +8,8 @@ from decimal import Decimal
import random
import time
import operator
from enum import IntEnum
import enum
from enum import IntEnum, Enum
from typing import (Optional, Sequence, Tuple, List, Set, Dict, TYPE_CHECKING,
NamedTuple, Union, Mapping, Any, Iterable, AsyncGenerator, DefaultDict, Callable)
import threading
@ -167,9 +168,15 @@ class PaymentInfo(NamedTuple):
status: int
class RecvMPPResolution(Enum):
WAITING = enum.auto()
EXPIRED = enum.auto()
ACCEPTED = enum.auto()
FAILED = enum.auto()
class ReceivedMPPStatus(NamedTuple):
is_expired: bool
is_accepted: bool
resolution: RecvMPPResolution
expected_msat: int
htlc_set: Set[Tuple[ShortChannelID, UpdateAddHtlc]]
@ -673,8 +680,8 @@ class LNWallet(LNWorker):
self.sent_htlcs = defaultdict(asyncio.Queue) # type: Dict[bytes, asyncio.Queue[HtlcLog]]
self.sent_htlcs_info = dict() # (RHASH, scid, htlc_id) -> route, payment_secret, amount_msat, bucket_msat, trampoline_fee_level
self.sent_buckets = dict() # payment_secret -> (amount_sent, amount_failed)
self.received_mpp_htlcs = dict() # type: Dict[bytes, ReceivedMPPStatus] # payment_secret -> ReceivedMPPStatus
self.sent_buckets = dict() # payment_key -> (amount_sent, amount_failed)
self.received_mpp_htlcs = dict() # type: Dict[bytes, ReceivedMPPStatus] # payment_key -> ReceivedMPPStatus
self.swap_manager = SwapManager(wallet=self.wallet, lnworker=self)
# detect inflight payments
@ -1418,13 +1425,14 @@ class LNWallet(LNWorker):
key = (payment_hash, short_channel_id, htlc.htlc_id)
self.sent_htlcs_info[key] = route, payment_secret, amount_msat, total_msat, amount_receiver_msat, trampoline_fee_level, trampoline_route
payment_key = payment_hash + payment_secret
# if we sent MPP to a trampoline, add item to sent_buckets
if self.uses_trampoline() and amount_msat != total_msat:
if payment_secret not in self.sent_buckets:
self.sent_buckets[payment_secret] = (0, 0)
amount_sent, amount_failed = self.sent_buckets[payment_secret]
if payment_key not in self.sent_buckets:
self.sent_buckets[payment_key] = (0, 0)
amount_sent, amount_failed = self.sent_buckets[payment_key]
amount_sent += amount_receiver_msat
self.sent_buckets[payment_secret] = amount_sent, amount_failed
self.sent_buckets[payment_key] = amount_sent, amount_failed
if self.network.path_finder:
# add inflight htlcs to liquidity hints
self.network.path_finder.update_inflight_htlcs(route, add_htlcs=True)
@ -1867,6 +1875,14 @@ class LNWallet(LNWorker):
def get_payment_secret(self, payment_hash):
return sha256(sha256(self.payment_secret_key) + payment_hash)
def _get_payment_key(self, payment_hash: bytes) -> bytes:
"""Return payment bucket key.
We bucket htlcs based on payment_hash+payment_secret. payment_secret is included
as it changes over a trampoline path (in the outer onion), and these paths can overlap.
"""
payment_secret = self.get_payment_secret(payment_hash)
return payment_hash + payment_secret
def create_payment_info(self, *, amount_msat: Optional[int], write_to_disk=True) -> bytes:
payment_preimage = os.urandom(32)
payment_hash = sha256(payment_preimage)
@ -1923,103 +1939,101 @@ class LNWallet(LNWorker):
self.wallet.save_db()
def check_mpp_status(
self, payment_secret: bytes,
self, *,
payment_secret: bytes,
short_channel_id: ShortChannelID,
htlc: UpdateAddHtlc,
expected_msat: int,
) -> Optional[bool]:
""" return MPP status: True (accepted), False (expired) or None (waiting)
"""
) -> RecvMPPResolution:
payment_hash = htlc.payment_hash
self.update_mpp_with_received_htlc(payment_secret, short_channel_id, htlc, expected_msat)
is_expired, is_accepted = self.get_mpp_status(payment_secret)
if not is_accepted and not is_expired:
payment_key = payment_hash + payment_secret
self.update_mpp_with_received_htlc(
payment_key=payment_key, scid=short_channel_id, htlc=htlc, expected_msat=expected_msat)
mpp_resolution = self.received_mpp_htlcs[payment_key].resolution
if mpp_resolution == RecvMPPResolution.WAITING:
bundle = self.get_payment_bundle(payment_hash)
if bundle:
payment_secrets = [self.get_payment_secret(h) for h in bundle]
if payment_secret not in payment_secrets:
payment_keys = [self._get_payment_key(h) for h in bundle]
if payment_key not in payment_keys:
# outer trampoline onion secret differs from inner onion
# the latter, not the former, might be part of a bundle
payment_secrets = [payment_secret]
payment_keys = [payment_key]
else:
payment_secrets = [payment_secret]
first_timestamp = min([self.get_first_timestamp_of_mpp(x) for x in payment_secrets])
payment_keys = [payment_key]
first_timestamp = min([self.get_first_timestamp_of_mpp(pkey) for pkey in payment_keys])
if self.get_payment_status(payment_hash) == PR_PAID:
is_accepted = True
mpp_resolution = RecvMPPResolution.ACCEPTED
elif self.stopping_soon:
is_expired = True # try to time out pending HTLCs before shutting down
elif all([self.is_mpp_amount_reached(x) for x in payment_secrets]):
is_accepted = True
# try to time out pending HTLCs before shutting down
mpp_resolution = RecvMPPResolution.EXPIRED
elif all([self.is_mpp_amount_reached(pkey) for pkey in payment_keys]):
mpp_resolution = RecvMPPResolution.ACCEPTED
elif time.time() - first_timestamp > self.MPP_EXPIRY:
is_expired = True
mpp_resolution = RecvMPPResolution.EXPIRED
if is_accepted or is_expired:
for x in payment_secrets:
if x in self.received_mpp_htlcs:
self.set_mpp_status(x, is_expired, is_accepted)
if mpp_resolution != RecvMPPResolution.WAITING:
for pkey in payment_keys:
if pkey in self.received_mpp_htlcs:
self.set_mpp_resolution(payment_key=pkey, resolution=mpp_resolution)
self.maybe_cleanup_mpp_status(payment_secret, short_channel_id, htlc)
return True if is_accepted else (False if is_expired else None)
self.maybe_cleanup_mpp_status(payment_key, short_channel_id, htlc)
return mpp_resolution
def update_mpp_with_received_htlc(
self,
payment_secret: bytes,
short_channel_id: ShortChannelID,
*,
payment_key: bytes,
scid: ShortChannelID,
htlc: UpdateAddHtlc,
expected_msat: int,
):
# add new htlc to set
mpp_status = self.received_mpp_htlcs.get(payment_secret)
mpp_status = self.received_mpp_htlcs.get(payment_key)
if mpp_status is None:
mpp_status = ReceivedMPPStatus(
is_expired=False,
is_accepted=False,
resolution=RecvMPPResolution.WAITING,
expected_msat=expected_msat,
htlc_set=set(),
)
assert expected_msat == mpp_status.expected_msat
key = (short_channel_id, htlc)
if expected_msat != mpp_status.expected_msat:
self.logger.info(
f"marking received mpp as failed. inconsistent total_msats in bucket. {payment_key.hex()=}")
mpp_status = mpp_status._replace(resolution=RecvMPPResolution.FAILED)
key = (scid, htlc)
if key not in mpp_status.htlc_set:
mpp_status.htlc_set.add(key) # side-effecting htlc_set
self.received_mpp_htlcs[payment_secret] = mpp_status
def get_mpp_status(self, payment_secret: bytes) -> Tuple[bool, bool]:
mpp_status = self.received_mpp_htlcs[payment_secret]
return mpp_status.is_expired, mpp_status.is_accepted
self.received_mpp_htlcs[payment_key] = mpp_status
def set_mpp_status(self, payment_secret: bytes, is_expired: bool, is_accepted: bool):
mpp_status = self.received_mpp_htlcs[payment_secret]
self.received_mpp_htlcs[payment_secret] = mpp_status._replace(
is_expired=is_expired,
is_accepted=is_accepted,
)
def set_mpp_resolution(self, *, payment_key: bytes, resolution: RecvMPPResolution):
mpp_status = self.received_mpp_htlcs[payment_key]
self.received_mpp_htlcs[payment_key] = mpp_status._replace(resolution=resolution)
def is_mpp_amount_reached(self, payment_secret: bytes) -> bool:
mpp_status = self.received_mpp_htlcs.get(payment_secret)
def is_mpp_amount_reached(self, payment_key: bytes) -> bool:
mpp_status = self.received_mpp_htlcs.get(payment_key)
if not mpp_status:
return False
total = sum([_htlc.amount_msat for scid, _htlc in mpp_status.htlc_set])
return total >= mpp_status.expected_msat
def get_first_timestamp_of_mpp(self, payment_secret: bytes) -> int:
mpp_status = self.received_mpp_htlcs.get(payment_secret)
def get_first_timestamp_of_mpp(self, payment_key: bytes) -> int:
mpp_status = self.received_mpp_htlcs.get(payment_key)
if not mpp_status:
return int(time.time())
return min([_htlc.timestamp for scid, _htlc in mpp_status.htlc_set])
def maybe_cleanup_mpp_status(
self,
payment_secret: bytes,
payment_key: bytes,
short_channel_id: ShortChannelID,
htlc: UpdateAddHtlc,
) -> None:
mpp_status = self.received_mpp_htlcs[payment_secret]
if not mpp_status.is_accepted and not mpp_status.is_expired:
mpp_status = self.received_mpp_htlcs[payment_key]
if mpp_status.resolution == RecvMPPResolution.WAITING:
return
key = (short_channel_id, htlc)
mpp_status.htlc_set.remove(key) # side-effecting htlc_set
if not mpp_status.htlc_set and payment_secret in self.received_mpp_htlcs:
self.received_mpp_htlcs.pop(payment_secret)
if not mpp_status.htlc_set and payment_key in self.received_mpp_htlcs:
self.received_mpp_htlcs.pop(payment_key)
def get_payment_status(self, payment_hash: bytes) -> int:
info = self.get_payment_info(payment_hash)
@ -2126,10 +2140,11 @@ class LNWallet(LNWorker):
self.logger.info(f"htlc_failed {failure_message}")
# check sent_buckets if we use trampoline
if self.uses_trampoline() and payment_secret in self.sent_buckets:
amount_sent, amount_failed = self.sent_buckets[payment_secret]
payment_key = payment_hash + payment_secret
if self.uses_trampoline() and payment_key in self.sent_buckets:
amount_sent, amount_failed = self.sent_buckets[payment_key]
amount_failed += amount_receiver_msat
self.sent_buckets[payment_secret] = amount_sent, amount_failed
self.sent_buckets[payment_key] = amount_sent, amount_failed
if amount_sent != amount_failed:
self.logger.info('bucket still active...')
return

4
electrum/tests/test_lnpeer.py

@ -283,13 +283,13 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
add_payment_info_for_hold_invoice = LNWallet.add_payment_info_for_hold_invoice
update_mpp_with_received_htlc = LNWallet.update_mpp_with_received_htlc
get_mpp_status = LNWallet.get_mpp_status
set_mpp_status = LNWallet.set_mpp_status
set_mpp_resolution = LNWallet.set_mpp_resolution
is_mpp_amount_reached = LNWallet.is_mpp_amount_reached
get_first_timestamp_of_mpp = LNWallet.get_first_timestamp_of_mpp
maybe_cleanup_mpp_status = LNWallet.maybe_cleanup_mpp_status
bundle_payments = LNWallet.bundle_payments
get_payment_bundle = LNWallet.get_payment_bundle
_get_payment_key = LNWallet._get_payment_key
class MockTransport:

Loading…
Cancel
Save