diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 468596ace..c8b62d988 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -10,7 +10,7 @@ import time import operator from enum import IntEnum from typing import (Optional, Sequence, Tuple, List, Set, Dict, TYPE_CHECKING, - NamedTuple, Union, Mapping, Any, Iterable, AsyncGenerator, DefaultDict) + NamedTuple, Union, Mapping, Any, Iterable, AsyncGenerator, DefaultDict, Callable) import threading import socket import aiohttp @@ -167,6 +167,13 @@ class PaymentInfo(NamedTuple): status: int +class ReceivedMPPStatus(NamedTuple): + is_expired: bool + is_accepted: bool + expected_msat: int + htlc_set: Set[Tuple[ShortChannelID, UpdateAddHtlc]] + + class ErrorAddingPeer(Exception): pass @@ -665,7 +672,7 @@ class LNWallet(LNWorker): self.sent_htlcs = defaultdict(asyncio.Queue) # type: Dict[bytes, asyncio.Queue[HtlcLog]] self.sent_htlcs_info = dict() # (RHASH, scid, htlc_id) -> route, payment_secret, amount_msat, bucket_msat, trampoline_fee_level self.sent_buckets = dict() # payment_secret -> (amount_sent, amount_failed) - self.received_mpp_htlcs = dict() # RHASH -> mpp_status, htlc_set + self.received_mpp_htlcs = dict() # type: Dict[bytes, ReceivedMPPStatus] # payment_secret -> ReceivedMPPStatus self.swap_manager = SwapManager(wallet=self.wallet, lnworker=self) # detect inflight payments @@ -676,7 +683,8 @@ class LNWallet(LNWorker): self.trampoline_forwarding_failures = {} # todo: should be persisted # map forwarded htlcs (fw_info=(scid_hex, htlc_id)) to originating peer pubkeys self.downstream_htlc_to_upstream_peer_map = {} # type: Dict[Tuple[str, int], bytes] - self.hold_invoice_callbacks = {} # payment_hash -> callback, timeout + # payment_hash -> callback, timeout: + self.hold_invoice_callbacks = {} # type: Dict[bytes, Tuple[Callable[[bytes], None], int]] self.payment_bundles = [] # lists of hashes. todo:persist @@ -1891,11 +1899,13 @@ class LNWallet(LNWorker): amount_msat, direction, status = self.payment_info[key] return PaymentInfo(payment_hash, amount_msat, direction, status) - def add_payment_info_for_hold_invoice(self, payment_hash, lightning_amount_sat): + def add_payment_info_for_hold_invoice(self, payment_hash: bytes, lightning_amount_sat: int): info = PaymentInfo(payment_hash, lightning_amount_sat * 1000, RECEIVED, PR_UNPAID) self.save_payment_info(info, write_to_disk=False) - def register_callback_for_hold_invoice(self, payment_hash, cb, timeout: Optional[int] = None): + def register_callback_for_hold_invoice( + self, payment_hash: bytes, cb: Callable[[bytes], None], timeout: int, + ): expiry = int(time.time()) + timeout self.hold_invoice_callbacks[payment_hash] = cb, expiry @@ -1907,7 +1917,12 @@ class LNWallet(LNWorker): if write_to_disk: self.wallet.save_db() - def check_received_htlc(self, payment_secret, short_channel_id, htlc: UpdateAddHtlc, expected_msat: int) -> Optional[bool]: + def check_received_htlc( + self, payment_secret: bytes, + short_channel_id: ShortChannelID, + htlc: UpdateAddHtlc, + expected_msat: int, + ) -> Optional[bool]: """ return MPP status: True (accepted), False (expired) or None (waiting) """ payment_hash = htlc.payment_hash @@ -1952,47 +1967,64 @@ class LNWallet(LNWorker): self.maybe_cleanup_mpp_status(payment_secret, short_channel_id, htlc) return True if is_accepted else (False if is_expired else None) - def update_mpp_with_received_htlc(self, payment_secret, short_channel_id, htlc, expected_msat): + def update_mpp_with_received_htlc( + self, + payment_secret: bytes, + short_channel_id: ShortChannelID, + htlc: UpdateAddHtlc, + expected_msat: int, + ): # add new htlc to set - is_expired, is_accepted, _expected_msat, htlc_set = self.received_mpp_htlcs.get(payment_secret, (False, False, expected_msat, set())) - assert expected_msat == _expected_msat + mpp_status = self.received_mpp_htlcs.get(payment_secret) + if mpp_status is None: + mpp_status = ReceivedMPPStatus( + is_expired=False, + is_accepted=False, + expected_msat=expected_msat, + htlc_set=set(), + ) + assert expected_msat == mpp_status.expected_msat key = (short_channel_id, htlc) - if key not in htlc_set: - htlc_set.add(key) - self.received_mpp_htlcs[payment_secret] = is_expired, is_accepted, _expected_msat, htlc_set - - def get_mpp_status(self, payment_secret): - is_expired, is_accepted, _expected_msat, htlc_set = self.received_mpp_htlcs[payment_secret] - return is_expired, is_accepted - - def set_mpp_status(self, payment_secret, is_expired, is_accepted): - _is_expired, _is_accepted, _expected_msat, htlc_set = self.received_mpp_htlcs[payment_secret] - self.received_mpp_htlcs[payment_secret] = is_expired, is_accepted, _expected_msat, htlc_set + if key not in mpp_status.htlc_set: + mpp_status.htlc_set.add(key) # side-effecting htlc_set + self.received_mpp_htlcs[payment_secret] = mpp_status + + def get_mpp_status(self, payment_secret: bytes) -> Tuple[bool, bool]: + mpp_status = self.received_mpp_htlcs[payment_secret] + return mpp_status.is_expired, mpp_status.is_accepted + + def set_mpp_status(self, payment_secret: bytes, is_expired: bool, is_accepted: bool): + mpp_status = self.received_mpp_htlcs[payment_secret] + self.received_mpp_htlcs[payment_secret] = mpp_status._replace( + is_expired=is_expired, + is_accepted=is_accepted, + ) - def is_mpp_amount_reached(self, payment_secret): - mpp = self.received_mpp_htlcs.get(payment_secret) - if not mpp: + def is_mpp_amount_reached(self, payment_secret: bytes) -> bool: + mpp_status = self.received_mpp_htlcs.get(payment_secret) + if not mpp_status: return False - is_expired, is_accepted, _expected_msat, htlc_set = mpp - total = sum([_htlc.amount_msat for scid, _htlc in htlc_set]) - return total >= _expected_msat + total = sum([_htlc.amount_msat for scid, _htlc in mpp_status.htlc_set]) + return total >= mpp_status.expected_msat - def get_first_timestamp_of_mpp(self, payment_secret): - mpp = self.received_mpp_htlcs.get(payment_secret) - if not mpp: + def get_first_timestamp_of_mpp(self, payment_secret: bytes) -> int: + mpp_status = self.received_mpp_htlcs.get(payment_secret) + if not mpp_status: return int(time.time()) - is_expired, is_accepted, _expected_msat, htlc_set = mpp - return min([_htlc.timestamp for scid, _htlc in htlc_set]) + return min([_htlc.timestamp for scid, _htlc in mpp_status.htlc_set]) - def maybe_cleanup_mpp_status(self, payment_secret, short_channel_id, htlc): - is_expired, is_accepted, _expected_msat, htlc_set = self.received_mpp_htlcs[payment_secret] - if not is_accepted and not is_expired: + def maybe_cleanup_mpp_status( + self, + payment_secret: bytes, + short_channel_id: ShortChannelID, + htlc: UpdateAddHtlc, + ) -> None: + mpp_status = self.received_mpp_htlcs[payment_secret] + if not mpp_status.is_accepted and not mpp_status.is_expired: return key = (short_channel_id, htlc) - htlc_set.remove(key) - if len(htlc_set) > 0: - self.received_mpp_htlcs[payment_secret] = is_expired, is_accepted, _expected_msat, htlc_set - elif payment_secret in self.received_mpp_htlcs: + mpp_status.htlc_set.remove(key) # side-effecting htlc_set + if not mpp_status.htlc_set and payment_secret in self.received_mpp_htlcs: self.received_mpp_htlcs.pop(payment_secret) def get_payment_status(self, payment_hash: bytes) -> int: