Browse Source

Persist MPP resolution status in wallet file.

If we accept a MPP and we forward the payment (trampoline or swap),
we need to persist the payment accepted status, or we might wrongly
release htlcs on the next restart.

lnworker.received_mpp_htlcs used to be cleaned up in maybe_cleanup_forwarding,
which only applies to forwarded payments. However, since we now
persist this dict, we need to clean it up also in the case of
payments received by us. This part of maybe_cleanup_forwarding has
been migrated to lnworker.maybe_cleanup_mpp
master
ThomasV 2 years ago
parent
commit
5708f7b1c8
  1. 6
      electrum/lnpeer.py
  2. 73
      electrum/lnworker.py
  3. 8
      tests/test_lnpeer.py

6
electrum/lnpeer.py

@ -2750,6 +2750,7 @@ class Peer(Logger):
# return payment_key so this branch will not be executed again # return payment_key so this branch will not be executed again
return None, payment_key, None return None, payment_key, None
elif preimage: elif preimage:
self.lnworker.maybe_cleanup_mpp(chan.get_scid_or_local_alias(), htlc)
return preimage, None, None return preimage, None, None
else: else:
# we are waiting for mpp consolidation or preimage # we are waiting for mpp consolidation or preimage
@ -2761,7 +2762,10 @@ class Peer(Logger):
preimage = self.lnworker.get_preimage(payment_hash) preimage = self.lnworker.get_preimage(payment_hash)
error_bytes, error_reason = self.lnworker.get_forwarding_failure(payment_key) error_bytes, error_reason = self.lnworker.get_forwarding_failure(payment_key)
if error_bytes or error_reason or preimage: if error_bytes or error_reason or preimage:
self.lnworker.maybe_cleanup_forwarding(payment_key, chan.get_scid_or_local_alias(), htlc) cleanup_keys = self.lnworker.maybe_cleanup_mpp(chan.get_scid_or_local_alias(), htlc)
is_htlc_key = ':' in payment_key
if is_htlc_key or payment_key in cleanup_keys:
self.lnworker.maybe_cleanup_forwarding(payment_key)
if error_bytes: if error_bytes:
return None, None, error_bytes return None, None, error_bytes
if error_reason: if error_reason:

73
electrum/lnworker.py

@ -87,6 +87,7 @@ from .submarine_swaps import HttpSwapManager
from .channel_db import ChannelInfo, Policy from .channel_db import ChannelInfo, Policy
from .mpp_split import suggest_splits, SplitConfigRating from .mpp_split import suggest_splits, SplitConfigRating
from .trampoline import create_trampoline_route_and_onion, is_legacy_relay from .trampoline import create_trampoline_route_and_onion, is_legacy_relay
from .json_db import stored_in
if TYPE_CHECKING: if TYPE_CHECKING:
from .network import Network from .network import Network
@ -169,11 +170,13 @@ class PaymentInfo(NamedTuple):
status: int status: int
class RecvMPPResolution(Enum): # Note: these states are persisted in the wallet file.
WAITING = enum.auto() # Do not modify them without performing a wallet db upgrade
EXPIRED = enum.auto() class RecvMPPResolution(IntEnum):
ACCEPTED = enum.auto() WAITING = 0
FAILED = enum.auto() EXPIRED = 1
ACCEPTED = 2
FAILED = 3
class ReceivedMPPStatus(NamedTuple): class ReceivedMPPStatus(NamedTuple):
@ -181,6 +184,13 @@ class ReceivedMPPStatus(NamedTuple):
expected_msat: int expected_msat: int
htlc_set: Set[Tuple[ShortChannelID, UpdateAddHtlc]] htlc_set: Set[Tuple[ShortChannelID, UpdateAddHtlc]]
@stored_in('received_mpp_htlcs', tuple)
def from_tuple(resolution, expected_msat, htlc_list) -> 'ReceivedMPPStatus':
htlc_set = set([(ShortChannelID(bytes.fromhex(scid)), UpdateAddHtlc.from_tuple(*x)) for (scid,x) in htlc_list])
return ReceivedMPPStatus(
resolution=RecvMPPResolution(resolution),
expected_msat=expected_msat,
htlc_set=htlc_set)
SentHtlcKey = Tuple[bytes, ShortChannelID, int] # RHASH, scid, htlc_id SentHtlcKey = Tuple[bytes, ShortChannelID, int] # RHASH, scid, htlc_id
@ -851,7 +861,7 @@ class LNWallet(LNWorker):
self._paysessions = dict() # type: Dict[bytes, PaySession] self._paysessions = dict() # type: Dict[bytes, PaySession]
self.sent_htlcs_info = dict() # type: Dict[SentHtlcKey, SentHtlcInfo] self.sent_htlcs_info = dict() # type: Dict[SentHtlcKey, SentHtlcInfo]
self.received_mpp_htlcs = dict() # type: Dict[bytes, ReceivedMPPStatus] # payment_key -> ReceivedMPPStatus self.received_mpp_htlcs = self.db.get_dict('received_mpp_htlcs') # type: Dict[str, ReceivedMPPStatus] # payment_key -> ReceivedMPPStatus
# detect inflight payments # detect inflight payments
self.inflight_payments = set() # (not persisted) keys of invoices that are in PR_INFLIGHT state self.inflight_payments = set() # (not persisted) keys of invoices that are in PR_INFLIGHT state
@ -2192,7 +2202,7 @@ class LNWallet(LNWorker):
payment_keys = [self._get_payment_key(x) for x in hash_list] payment_keys = [self._get_payment_key(x) for x in hash_list]
self.payment_bundles.append(payment_keys) self.payment_bundles.append(payment_keys)
def get_payment_bundle(self, payment_key): def get_payment_bundle(self, payment_key: bytes) -> Sequence[bytes]:
for key_list in self.payment_bundles: for key_list in self.payment_bundles:
if payment_key in key_list: if payment_key in key_list:
return key_list return key_list
@ -2259,7 +2269,7 @@ class LNWallet(LNWorker):
payment_key = payment_hash + payment_secret payment_key = payment_hash + payment_secret
self.update_mpp_with_received_htlc( self.update_mpp_with_received_htlc(
payment_key=payment_key, scid=short_channel_id, htlc=htlc, expected_msat=expected_msat) payment_key=payment_key, scid=short_channel_id, htlc=htlc, expected_msat=expected_msat)
mpp_resolution = self.received_mpp_htlcs[payment_key].resolution mpp_resolution = self.received_mpp_htlcs[payment_key.hex()].resolution
# if still waiting, calc resolution now: # if still waiting, calc resolution now:
if mpp_resolution == RecvMPPResolution.WAITING: if mpp_resolution == RecvMPPResolution.WAITING:
bundle = self.get_payment_bundle(payment_key) bundle = self.get_payment_bundle(payment_key)
@ -2280,7 +2290,7 @@ class LNWallet(LNWorker):
# save resolution, if any. # save resolution, if any.
if mpp_resolution != RecvMPPResolution.WAITING: if mpp_resolution != RecvMPPResolution.WAITING:
for pkey in payment_keys: for pkey in payment_keys:
if pkey in self.received_mpp_htlcs: if pkey.hex() in self.received_mpp_htlcs:
self.set_mpp_resolution(payment_key=pkey, resolution=mpp_resolution) self.set_mpp_resolution(payment_key=pkey, resolution=mpp_resolution)
return mpp_resolution return mpp_resolution
@ -2294,7 +2304,7 @@ class LNWallet(LNWorker):
expected_msat: int, expected_msat: int,
): ):
# add new htlc to set # add new htlc to set
mpp_status = self.received_mpp_htlcs.get(payment_key) mpp_status = self.received_mpp_htlcs.get(payment_key.hex())
if mpp_status is None: if mpp_status is None:
mpp_status = ReceivedMPPStatus( mpp_status = ReceivedMPPStatus(
resolution=RecvMPPResolution.WAITING, resolution=RecvMPPResolution.WAITING,
@ -2308,47 +2318,46 @@ class LNWallet(LNWorker):
key = (scid, htlc) key = (scid, htlc)
if key not in mpp_status.htlc_set: if key not in mpp_status.htlc_set:
mpp_status.htlc_set.add(key) # side-effecting htlc_set mpp_status.htlc_set.add(key) # side-effecting htlc_set
self.received_mpp_htlcs[payment_key] = mpp_status self.received_mpp_htlcs[payment_key.hex()] = mpp_status
def set_mpp_resolution(self, *, payment_key: bytes, resolution: RecvMPPResolution): def set_mpp_resolution(self, *, payment_key: bytes, resolution: RecvMPPResolution):
mpp_status = self.received_mpp_htlcs[payment_key] mpp_status = self.received_mpp_htlcs[payment_key.hex()]
self.received_mpp_htlcs[payment_key] = mpp_status._replace(resolution=resolution) self.logger.info(f'set_mpp_resolution {resolution.name} {len(mpp_status.htlc_set)} {payment_key.hex()}')
self.received_mpp_htlcs[payment_key.hex()] = mpp_status._replace(resolution=resolution)
def is_mpp_amount_reached(self, payment_key: bytes) -> bool: def is_mpp_amount_reached(self, payment_key: bytes) -> bool:
mpp_status = self.received_mpp_htlcs.get(payment_key) mpp_status = self.received_mpp_htlcs.get(payment_key.hex())
if not mpp_status: if not mpp_status:
return False return False
total = sum([_htlc.amount_msat for scid, _htlc in mpp_status.htlc_set]) total = sum([_htlc.amount_msat for scid, _htlc in mpp_status.htlc_set])
return total >= mpp_status.expected_msat return total >= mpp_status.expected_msat
def get_first_timestamp_of_mpp(self, payment_key: bytes) -> int: def get_first_timestamp_of_mpp(self, payment_key: bytes) -> int:
mpp_status = self.received_mpp_htlcs.get(payment_key) mpp_status = self.received_mpp_htlcs.get(payment_key.hex())
if not mpp_status: if not mpp_status:
return int(time.time()) return int(time.time())
return min([_htlc.timestamp for scid, _htlc in mpp_status.htlc_set]) return min([_htlc.timestamp for scid, _htlc in mpp_status.htlc_set])
def maybe_cleanup_forwarding( def maybe_cleanup_mpp(
self, self,
payment_key_hex: str,
short_channel_id: ShortChannelID, short_channel_id: ShortChannelID,
htlc: UpdateAddHtlc, htlc: UpdateAddHtlc,
) -> None: ) -> Sequence[str]:
htlc_key = (short_channel_id, htlc)
is_htlc_key = ':' in payment_key_hex cleanup_keys = []
if not is_htlc_key: for payment_key_hex, mpp_status in list(self.received_mpp_htlcs.items()):
payment_key = bytes.fromhex(payment_key_hex) if htlc_key not in mpp_status.htlc_set:
mpp_status = self.received_mpp_htlcs.get(payment_key) continue
if not mpp_status or mpp_status.resolution == RecvMPPResolution.WAITING: assert mpp_status.resolution != RecvMPPResolution.WAITING
# After restart, self.received_mpp_htlcs needs to be reconstructed self.logger.info(f'maybe_cleanup_mpp: removing htlc of MPP {payment_key_hex}')
self.logger.info(f'maybe_cleanup_forwarding: mpp_status not ready')
return
htlc_key = (short_channel_id, htlc)
mpp_status.htlc_set.remove(htlc_key) # side-effecting htlc_set mpp_status.htlc_set.remove(htlc_key) # side-effecting htlc_set
if mpp_status.htlc_set: if len(mpp_status.htlc_set) == 0:
return self.logger.info(f'maybe_cleanup_mpp: removing mpp {payment_key_hex}')
self.logger.info('cleaning up mpp') self.received_mpp_htlcs.pop(payment_key_hex)
self.received_mpp_htlcs.pop(payment_key) cleanup_keys.append(payment_key_hex)
return cleanup_keys
def maybe_cleanup_forwarding(self, payment_key_hex: str) -> None:
self.active_forwardings.pop(payment_key_hex, None) self.active_forwardings.pop(payment_key_hex, None)
self.forwarding_failures.pop(payment_key_hex, None) self.forwarding_failures.pop(payment_key_hex, None)

8
tests/test_lnpeer.py

@ -316,6 +316,7 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
maybe_cleanup_forwarding = LNWallet.maybe_cleanup_forwarding maybe_cleanup_forwarding = LNWallet.maybe_cleanup_forwarding
current_target_feerate_per_kw = LNWallet.current_target_feerate_per_kw current_target_feerate_per_kw = LNWallet.current_target_feerate_per_kw
current_low_feerate_per_kw = LNWallet.current_low_feerate_per_kw current_low_feerate_per_kw = LNWallet.current_low_feerate_per_kw
maybe_cleanup_mpp = LNWallet.maybe_cleanup_mpp
class MockTransport: class MockTransport:
@ -1741,6 +1742,7 @@ class TestPeerForwarding(TestPeer):
): ):
alice_w = graph.workers['alice'] alice_w = graph.workers['alice']
bob_w = graph.workers['bob'] bob_w = graph.workers['bob']
carol_w = graph.workers['carol']
dave_w = graph.workers['dave'] dave_w = graph.workers['dave']
if mpp_invoice: if mpp_invoice:
dave_w.features |= LnFeatures.BASIC_MPP_OPT dave_w.features |= LnFeatures.BASIC_MPP_OPT
@ -1762,6 +1764,12 @@ class TestPeerForwarding(TestPeer):
await asyncio.sleep(2) await asyncio.sleep(2)
if result: if result:
self.assertEqual(PR_PAID, dave_w.get_payment_status(lnaddr.paymenthash)) self.assertEqual(PR_PAID, dave_w.get_payment_status(lnaddr.paymenthash))
# check mpp is cleaned up
async with OldTaskGroup() as g:
for peer in peers:
await g.spawn(peer.wait_one_htlc_switch_iteration())
for peer in peers:
self.assertEqual(len(peer.lnworker.received_mpp_htlcs), 0)
raise PaymentDone() raise PaymentDone()
elif len(log) == 1 and log[0].failure_msg.code == OnionFailureCode.MPP_TIMEOUT: elif len(log) == 1 and log[0].failure_msg.code == OnionFailureCode.MPP_TIMEOUT:
raise PaymentTimeout() raise PaymentTimeout()

Loading…
Cancel
Save