Browse Source

Persist MPP resolution status in wallet file.

If we accept a MPP and we forward the payment (trampoline or swap),
we need to persist the payment accepted status, or we might wrongly
release htlcs on the next restart.

lnworker.received_mpp_htlcs used to be cleaned up in maybe_cleanup_forwarding,
which only applies to forwarded payments. However, since we now
persist this dict, we need to clean it up also in the case of
payments received by us. This part of maybe_cleanup_forwarding has
been migrated to lnworker.maybe_cleanup_mpp
master
ThomasV 2 years ago
parent
commit
5708f7b1c8
  1. 6
      electrum/lnpeer.py
  2. 73
      electrum/lnworker.py
  3. 8
      tests/test_lnpeer.py

6
electrum/lnpeer.py

@ -2750,6 +2750,7 @@ class Peer(Logger):
# return payment_key so this branch will not be executed again
return None, payment_key, None
elif preimage:
self.lnworker.maybe_cleanup_mpp(chan.get_scid_or_local_alias(), htlc)
return preimage, None, None
else:
# we are waiting for mpp consolidation or preimage
@ -2761,7 +2762,10 @@ class Peer(Logger):
preimage = self.lnworker.get_preimage(payment_hash)
error_bytes, error_reason = self.lnworker.get_forwarding_failure(payment_key)
if error_bytes or error_reason or preimage:
self.lnworker.maybe_cleanup_forwarding(payment_key, chan.get_scid_or_local_alias(), htlc)
cleanup_keys = self.lnworker.maybe_cleanup_mpp(chan.get_scid_or_local_alias(), htlc)
is_htlc_key = ':' in payment_key
if is_htlc_key or payment_key in cleanup_keys:
self.lnworker.maybe_cleanup_forwarding(payment_key)
if error_bytes:
return None, None, error_bytes
if error_reason:

73
electrum/lnworker.py

@ -87,6 +87,7 @@ from .submarine_swaps import HttpSwapManager
from .channel_db import ChannelInfo, Policy
from .mpp_split import suggest_splits, SplitConfigRating
from .trampoline import create_trampoline_route_and_onion, is_legacy_relay
from .json_db import stored_in
if TYPE_CHECKING:
from .network import Network
@ -169,11 +170,13 @@ class PaymentInfo(NamedTuple):
status: int
class RecvMPPResolution(Enum):
WAITING = enum.auto()
EXPIRED = enum.auto()
ACCEPTED = enum.auto()
FAILED = enum.auto()
# Note: these states are persisted in the wallet file.
# Do not modify them without performing a wallet db upgrade
class RecvMPPResolution(IntEnum):
WAITING = 0
EXPIRED = 1
ACCEPTED = 2
FAILED = 3
class ReceivedMPPStatus(NamedTuple):
@ -181,6 +184,13 @@ class ReceivedMPPStatus(NamedTuple):
expected_msat: int
htlc_set: Set[Tuple[ShortChannelID, UpdateAddHtlc]]
@stored_in('received_mpp_htlcs', tuple)
def from_tuple(resolution, expected_msat, htlc_list) -> 'ReceivedMPPStatus':
htlc_set = set([(ShortChannelID(bytes.fromhex(scid)), UpdateAddHtlc.from_tuple(*x)) for (scid,x) in htlc_list])
return ReceivedMPPStatus(
resolution=RecvMPPResolution(resolution),
expected_msat=expected_msat,
htlc_set=htlc_set)
SentHtlcKey = Tuple[bytes, ShortChannelID, int] # RHASH, scid, htlc_id
@ -851,7 +861,7 @@ class LNWallet(LNWorker):
self._paysessions = dict() # type: Dict[bytes, PaySession]
self.sent_htlcs_info = dict() # type: Dict[SentHtlcKey, SentHtlcInfo]
self.received_mpp_htlcs = dict() # type: Dict[bytes, ReceivedMPPStatus] # payment_key -> ReceivedMPPStatus
self.received_mpp_htlcs = self.db.get_dict('received_mpp_htlcs') # type: Dict[str, ReceivedMPPStatus] # payment_key -> ReceivedMPPStatus
# detect inflight payments
self.inflight_payments = set() # (not persisted) keys of invoices that are in PR_INFLIGHT state
@ -2192,7 +2202,7 @@ class LNWallet(LNWorker):
payment_keys = [self._get_payment_key(x) for x in hash_list]
self.payment_bundles.append(payment_keys)
def get_payment_bundle(self, payment_key):
def get_payment_bundle(self, payment_key: bytes) -> Sequence[bytes]:
for key_list in self.payment_bundles:
if payment_key in key_list:
return key_list
@ -2259,7 +2269,7 @@ class LNWallet(LNWorker):
payment_key = payment_hash + payment_secret
self.update_mpp_with_received_htlc(
payment_key=payment_key, scid=short_channel_id, htlc=htlc, expected_msat=expected_msat)
mpp_resolution = self.received_mpp_htlcs[payment_key].resolution
mpp_resolution = self.received_mpp_htlcs[payment_key.hex()].resolution
# if still waiting, calc resolution now:
if mpp_resolution == RecvMPPResolution.WAITING:
bundle = self.get_payment_bundle(payment_key)
@ -2280,7 +2290,7 @@ class LNWallet(LNWorker):
# save resolution, if any.
if mpp_resolution != RecvMPPResolution.WAITING:
for pkey in payment_keys:
if pkey in self.received_mpp_htlcs:
if pkey.hex() in self.received_mpp_htlcs:
self.set_mpp_resolution(payment_key=pkey, resolution=mpp_resolution)
return mpp_resolution
@ -2294,7 +2304,7 @@ class LNWallet(LNWorker):
expected_msat: int,
):
# add new htlc to set
mpp_status = self.received_mpp_htlcs.get(payment_key)
mpp_status = self.received_mpp_htlcs.get(payment_key.hex())
if mpp_status is None:
mpp_status = ReceivedMPPStatus(
resolution=RecvMPPResolution.WAITING,
@ -2308,47 +2318,46 @@ class LNWallet(LNWorker):
key = (scid, htlc)
if key not in mpp_status.htlc_set:
mpp_status.htlc_set.add(key) # side-effecting htlc_set
self.received_mpp_htlcs[payment_key] = mpp_status
self.received_mpp_htlcs[payment_key.hex()] = mpp_status
def set_mpp_resolution(self, *, payment_key: bytes, resolution: RecvMPPResolution):
mpp_status = self.received_mpp_htlcs[payment_key]
self.received_mpp_htlcs[payment_key] = mpp_status._replace(resolution=resolution)
mpp_status = self.received_mpp_htlcs[payment_key.hex()]
self.logger.info(f'set_mpp_resolution {resolution.name} {len(mpp_status.htlc_set)} {payment_key.hex()}')
self.received_mpp_htlcs[payment_key.hex()] = mpp_status._replace(resolution=resolution)
def is_mpp_amount_reached(self, payment_key: bytes) -> bool:
mpp_status = self.received_mpp_htlcs.get(payment_key)
mpp_status = self.received_mpp_htlcs.get(payment_key.hex())
if not mpp_status:
return False
total = sum([_htlc.amount_msat for scid, _htlc in mpp_status.htlc_set])
return total >= mpp_status.expected_msat
def get_first_timestamp_of_mpp(self, payment_key: bytes) -> int:
mpp_status = self.received_mpp_htlcs.get(payment_key)
mpp_status = self.received_mpp_htlcs.get(payment_key.hex())
if not mpp_status:
return int(time.time())
return min([_htlc.timestamp for scid, _htlc in mpp_status.htlc_set])
def maybe_cleanup_forwarding(
def maybe_cleanup_mpp(
self,
payment_key_hex: str,
short_channel_id: ShortChannelID,
htlc: UpdateAddHtlc,
) -> None:
is_htlc_key = ':' in payment_key_hex
if not is_htlc_key:
payment_key = bytes.fromhex(payment_key_hex)
mpp_status = self.received_mpp_htlcs.get(payment_key)
if not mpp_status or mpp_status.resolution == RecvMPPResolution.WAITING:
# After restart, self.received_mpp_htlcs needs to be reconstructed
self.logger.info(f'maybe_cleanup_forwarding: mpp_status not ready')
return
htlc_key = (short_channel_id, htlc)
) -> Sequence[str]:
htlc_key = (short_channel_id, htlc)
cleanup_keys = []
for payment_key_hex, mpp_status in list(self.received_mpp_htlcs.items()):
if htlc_key not in mpp_status.htlc_set:
continue
assert mpp_status.resolution != RecvMPPResolution.WAITING
self.logger.info(f'maybe_cleanup_mpp: removing htlc of MPP {payment_key_hex}')
mpp_status.htlc_set.remove(htlc_key) # side-effecting htlc_set
if mpp_status.htlc_set:
return
self.logger.info('cleaning up mpp')
self.received_mpp_htlcs.pop(payment_key)
if len(mpp_status.htlc_set) == 0:
self.logger.info(f'maybe_cleanup_mpp: removing mpp {payment_key_hex}')
self.received_mpp_htlcs.pop(payment_key_hex)
cleanup_keys.append(payment_key_hex)
return cleanup_keys
def maybe_cleanup_forwarding(self, payment_key_hex: str) -> None:
self.active_forwardings.pop(payment_key_hex, None)
self.forwarding_failures.pop(payment_key_hex, None)

8
tests/test_lnpeer.py

@ -316,6 +316,7 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
maybe_cleanup_forwarding = LNWallet.maybe_cleanup_forwarding
current_target_feerate_per_kw = LNWallet.current_target_feerate_per_kw
current_low_feerate_per_kw = LNWallet.current_low_feerate_per_kw
maybe_cleanup_mpp = LNWallet.maybe_cleanup_mpp
class MockTransport:
@ -1741,6 +1742,7 @@ class TestPeerForwarding(TestPeer):
):
alice_w = graph.workers['alice']
bob_w = graph.workers['bob']
carol_w = graph.workers['carol']
dave_w = graph.workers['dave']
if mpp_invoice:
dave_w.features |= LnFeatures.BASIC_MPP_OPT
@ -1762,6 +1764,12 @@ class TestPeerForwarding(TestPeer):
await asyncio.sleep(2)
if result:
self.assertEqual(PR_PAID, dave_w.get_payment_status(lnaddr.paymenthash))
# check mpp is cleaned up
async with OldTaskGroup() as g:
for peer in peers:
await g.spawn(peer.wait_one_htlc_switch_iteration())
for peer in peers:
self.assertEqual(len(peer.lnworker.received_mpp_htlcs), 0)
raise PaymentDone()
elif len(log) == 1 and log[0].failure_msg.code == OnionFailureCode.MPP_TIMEOUT:
raise PaymentTimeout()

Loading…
Cancel
Save