@ -87,6 +87,7 @@ from .submarine_swaps import HttpSwapManager
from . channel_db import ChannelInfo , Policy
from . mpp_split import suggest_splits , SplitConfigRating
from . trampoline import create_trampoline_route_and_onion , is_legacy_relay
from . json_db import stored_in
if TYPE_CHECKING :
from . network import Network
@ -169,11 +170,13 @@ class PaymentInfo(NamedTuple):
status : int
class RecvMPPResolution ( Enum ) :
WAITING = enum . auto ( )
EXPIRED = enum . auto ( )
ACCEPTED = enum . auto ( )
FAILED = enum . auto ( )
# Note: these states are persisted in the wallet file.
# Do not modify them without performing a wallet db upgrade
class RecvMPPResolution ( IntEnum ) :
WAITING = 0
EXPIRED = 1
ACCEPTED = 2
FAILED = 3
class ReceivedMPPStatus ( NamedTuple ) :
@ -181,6 +184,13 @@ class ReceivedMPPStatus(NamedTuple):
expected_msat : int
htlc_set : Set [ Tuple [ ShortChannelID , UpdateAddHtlc ] ]
@stored_in ( ' received_mpp_htlcs ' , tuple )
def from_tuple ( resolution , expected_msat , htlc_list ) - > ' ReceivedMPPStatus ' :
htlc_set = set ( [ ( ShortChannelID ( bytes . fromhex ( scid ) ) , UpdateAddHtlc . from_tuple ( * x ) ) for ( scid , x ) in htlc_list ] )
return ReceivedMPPStatus (
resolution = RecvMPPResolution ( resolution ) ,
expected_msat = expected_msat ,
htlc_set = htlc_set )
SentHtlcKey = Tuple [ bytes , ShortChannelID , int ] # RHASH, scid, htlc_id
@ -851,7 +861,7 @@ class LNWallet(LNWorker):
self . _paysessions = dict ( ) # type: Dict[bytes, PaySession]
self . sent_htlcs_info = dict ( ) # type: Dict[SentHtlcKey, SentHtlcInfo]
self . received_mpp_htlcs = dict ( ) # type: Dict[bytes , ReceivedMPPStatus] # payment_key -> ReceivedMPPStatus
self . received_mpp_htlcs = self . db . get_dict ( ' received_mpp_htlcs ' ) # type: Dict[str , ReceivedMPPStatus] # payment_key -> ReceivedMPPStatus
# detect inflight payments
self . inflight_payments = set ( ) # (not persisted) keys of invoices that are in PR_INFLIGHT state
@ -2192,7 +2202,7 @@ class LNWallet(LNWorker):
payment_keys = [ self . _get_payment_key ( x ) for x in hash_list ]
self . payment_bundles . append ( payment_keys )
def get_payment_bundle ( self , payment_key ) :
def get_payment_bundle ( self , payment_key : bytes ) - > Sequence [ bytes ] :
for key_list in self . payment_bundles :
if payment_key in key_list :
return key_list
@ -2259,7 +2269,7 @@ class LNWallet(LNWorker):
payment_key = payment_hash + payment_secret
self . update_mpp_with_received_htlc (
payment_key = payment_key , scid = short_channel_id , htlc = htlc , expected_msat = expected_msat )
mpp_resolution = self . received_mpp_htlcs [ payment_key ] . resolution
mpp_resolution = self . received_mpp_htlcs [ payment_key . hex ( ) ] . resolution
# if still waiting, calc resolution now:
if mpp_resolution == RecvMPPResolution . WAITING :
bundle = self . get_payment_bundle ( payment_key )
@ -2280,7 +2290,7 @@ class LNWallet(LNWorker):
# save resolution, if any.
if mpp_resolution != RecvMPPResolution . WAITING :
for pkey in payment_keys :
if pkey in self . received_mpp_htlcs :
if pkey . hex ( ) in self . received_mpp_htlcs :
self . set_mpp_resolution ( payment_key = pkey , resolution = mpp_resolution )
return mpp_resolution
@ -2294,7 +2304,7 @@ class LNWallet(LNWorker):
expected_msat : int ,
) :
# add new htlc to set
mpp_status = self . received_mpp_htlcs . get ( payment_key )
mpp_status = self . received_mpp_htlcs . get ( payment_key . hex ( ) )
if mpp_status is None :
mpp_status = ReceivedMPPStatus (
resolution = RecvMPPResolution . WAITING ,
@ -2308,47 +2318,46 @@ class LNWallet(LNWorker):
key = ( scid , htlc )
if key not in mpp_status . htlc_set :
mpp_status . htlc_set . add ( key ) # side-effecting htlc_set
self . received_mpp_htlcs [ payment_key ] = mpp_status
self . received_mpp_htlcs [ payment_key . hex ( ) ] = mpp_status
def set_mpp_resolution ( self , * , payment_key : bytes , resolution : RecvMPPResolution ) :
mpp_status = self . received_mpp_htlcs [ payment_key ]
self . received_mpp_htlcs [ payment_key ] = mpp_status . _replace ( resolution = resolution )
mpp_status = self . received_mpp_htlcs [ payment_key . hex ( ) ]
self . logger . info ( f ' set_mpp_resolution { resolution . name } { len ( mpp_status . htlc_set ) } { payment_key . hex ( ) } ' )
self . received_mpp_htlcs [ payment_key . hex ( ) ] = mpp_status . _replace ( resolution = resolution )
def is_mpp_amount_reached ( self , payment_key : bytes ) - > bool :
mpp_status = self . received_mpp_htlcs . get ( payment_key )
mpp_status = self . received_mpp_htlcs . get ( payment_key . hex ( ) )
if not mpp_status :
return False
total = sum ( [ _htlc . amount_msat for scid , _htlc in mpp_status . htlc_set ] )
return total > = mpp_status . expected_msat
def get_first_timestamp_of_mpp ( self , payment_key : bytes ) - > int :
mpp_status = self . received_mpp_htlcs . get ( payment_key )
mpp_status = self . received_mpp_htlcs . get ( payment_key . hex ( ) )
if not mpp_status :
return int ( time . time ( ) )
return min ( [ _htlc . timestamp for scid , _htlc in mpp_status . htlc_set ] )
def maybe_cleanup_forwarding (
def maybe_cleanup_mpp (
self ,
payment_key_hex : str ,
short_channel_id : ShortChannelID ,
htlc : UpdateAddHtlc ,
) - > None :
is_htlc_key = ' : ' in payment_key_hex
if not is_htlc_key :
payment_key = bytes . fromhex ( payment_key_hex )
mpp_status = self . received_mpp_htlcs . get ( payment_key )
if not mpp_status or mpp_status . resolution == RecvMPPResolution . WAITING :
# After restart, self.received_mpp_htlcs needs to be reconstructed
self . logger . info ( f ' maybe_cleanup_forwarding: mpp_status not ready ' )
return
htlc_key = ( short_channel_id , htlc )
) - > Sequence [ str ] :
htlc_key = ( short_channel_id , htlc )
cleanup_keys = [ ]
for payment_key_hex , mpp_status in list ( self . received_mpp_htlcs . items ( ) ) :
if htlc_key not in mpp_status . htlc_set :
continue
assert mpp_status . resolution != RecvMPPResolution . WAITING
self . logger . info ( f ' maybe_cleanup_mpp: removing htlc of MPP { payment_key_hex } ' )
mpp_status . htlc_set . remove ( htlc_key ) # side-effecting htlc_set
if mpp_status . htlc_set :
return
self . logger . info ( ' cleaning up mpp ' )
self . received_mpp_htlcs . pop ( payment_key )
if len ( mpp_status . htlc_set ) == 0 :
self . logger . info ( f ' maybe_cleanup_mpp: removing mpp { payment_key_hex } ' )
self . received_mpp_htlcs . pop ( payment_key_hex )
cleanup_keys . append ( payment_key_hex )
return cleanup_keys
def maybe_cleanup_forwarding ( self , payment_key_hex : str ) - > None :
self . active_forwardings . pop ( payment_key_hex , None )
self . forwarding_failures . pop ( payment_key_hex , None )