@ -8,7 +8,8 @@ from decimal import Decimal
import random
import time
import operator
from enum import IntEnum
import enum
from enum import IntEnum , Enum
from typing import ( Optional , Sequence , Tuple , List , Set , Dict , TYPE_CHECKING ,
NamedTuple , Union , Mapping , Any , Iterable , AsyncGenerator , DefaultDict , Callable )
import threading
@ -167,9 +168,15 @@ class PaymentInfo(NamedTuple):
status : int
class RecvMPPResolution ( Enum ) :
WAITING = enum . auto ( )
EXPIRED = enum . auto ( )
ACCEPTED = enum . auto ( )
FAILED = enum . auto ( )
class ReceivedMPPStatus ( NamedTuple ) :
is_expired : bool
is_accepted : bool
resolution : RecvMPPResolution
expected_msat : int
htlc_set : Set [ Tuple [ ShortChannelID , UpdateAddHtlc ] ]
@ -673,8 +680,8 @@ class LNWallet(LNWorker):
self . sent_htlcs = defaultdict ( asyncio . Queue ) # type: Dict[bytes, asyncio.Queue[HtlcLog]]
self . sent_htlcs_info = dict ( ) # (RHASH, scid, htlc_id) -> route, payment_secret, amount_msat, bucket_msat, trampoline_fee_level
self . sent_buckets = dict ( ) # payment_secret -> (amount_sent, amount_failed)
self . received_mpp_htlcs = dict ( ) # type: Dict[bytes, ReceivedMPPStatus] # payment_secret -> ReceivedMPPStatus
self . sent_buckets = dict ( ) # payment_key -> (amount_sent, amount_failed)
self . received_mpp_htlcs = dict ( ) # type: Dict[bytes, ReceivedMPPStatus] # payment_key -> ReceivedMPPStatus
self . swap_manager = SwapManager ( wallet = self . wallet , lnworker = self )
# detect inflight payments
@ -1418,13 +1425,14 @@ class LNWallet(LNWorker):
key = ( payment_hash , short_channel_id , htlc . htlc_id )
self . sent_htlcs_info [ key ] = route , payment_secret , amount_msat , total_msat , amount_receiver_msat , trampoline_fee_level , trampoline_route
payment_key = payment_hash + payment_secret
# if we sent MPP to a trampoline, add item to sent_buckets
if self . uses_trampoline ( ) and amount_msat != total_msat :
if payment_secret not in self . sent_buckets :
self . sent_buckets [ payment_secret ] = ( 0 , 0 )
amount_sent , amount_failed = self . sent_buckets [ payment_secret ]
if payment_key not in self . sent_buckets :
self . sent_buckets [ payment_key ] = ( 0 , 0 )
amount_sent , amount_failed = self . sent_buckets [ payment_key ]
amount_sent + = amount_receiver_msat
self . sent_buckets [ payment_secret ] = amount_sent , amount_failed
self . sent_buckets [ payment_key ] = amount_sent , amount_failed
if self . network . path_finder :
# add inflight htlcs to liquidity hints
self . network . path_finder . update_inflight_htlcs ( route , add_htlcs = True )
@ -1867,6 +1875,14 @@ class LNWallet(LNWorker):
def get_payment_secret ( self , payment_hash ) :
return sha256 ( sha256 ( self . payment_secret_key ) + payment_hash )
def _get_payment_key ( self , payment_hash : bytes ) - > bytes :
""" Return payment bucket key.
We bucket htlcs based on payment_hash + payment_secret . payment_secret is included
as it changes over a trampoline path ( in the outer onion ) , and these paths can overlap .
"""
payment_secret = self . get_payment_secret ( payment_hash )
return payment_hash + payment_secret
def create_payment_info ( self , * , amount_msat : Optional [ int ] , write_to_disk = True ) - > bytes :
payment_preimage = os . urandom ( 32 )
payment_hash = sha256 ( payment_preimage )
@ -1923,103 +1939,101 @@ class LNWallet(LNWorker):
self . wallet . save_db ( )
def check_mpp_status (
self , payment_secret : bytes ,
self , * ,
payment_secret : bytes ,
short_channel_id : ShortChannelID ,
htlc : UpdateAddHtlc ,
expected_msat : int ,
) - > Optional [ bool ] :
""" return MPP status: True (accepted), False (expired) or None (waiting)
"""
) - > RecvMPPResolution :
payment_hash = htlc . payment_hash
self . update_mpp_with_received_htlc ( payment_secret , short_channel_id , htlc , expected_msat )
is_expired , is_accepted = self . get_mpp_status ( payment_secret )
if not is_accepted and not is_expired :
payment_key = payment_hash + payment_secret
self . update_mpp_with_received_htlc (
payment_key = payment_key , scid = short_channel_id , htlc = htlc , expected_msat = expected_msat )
mpp_resolution = self . received_mpp_htlcs [ payment_key ] . resolution
if mpp_resolution == RecvMPPResolution . WAITING :
bundle = self . get_payment_bundle ( payment_hash )
if bundle :
payment_secret s = [ self . get_payment_secret ( h ) for h in bundle ]
if payment_secret not in payment_secret s :
payment_key s = [ self . _get_payment_key ( h ) for h in bundle ]
if payment_key not in payment_key s :
# outer trampoline onion secret differs from inner onion
# the latter, not the former, might be part of a bundle
payment_secrets = [ payment_secret ]
payment_keys = [ payment_key ]
else :
payment_secrets = [ payment_secret ]
first_timestamp = min ( [ self . get_first_timestamp_of_mpp ( x ) for x in payment_secret s] )
payment_keys = [ payment_key ]
first_timestamp = min ( [ self . get_first_timestamp_of_mpp ( pkey ) for pkey in payment_key s] )
if self . get_payment_status ( payment_hash ) == PR_PAID :
is_accepted = True
mpp_resolution = RecvMPPResolution . ACCEPTED
elif self . stopping_soon :
is_expired = True # try to time out pending HTLCs before shutting down
elif all ( [ self . is_mpp_amount_reached ( x ) for x in payment_secrets ] ) :
is_accepted = True
# try to time out pending HTLCs before shutting down
mpp_resolution = RecvMPPResolution . EXPIRED
elif all ( [ self . is_mpp_amount_reached ( pkey ) for pkey in payment_keys ] ) :
mpp_resolution = RecvMPPResolution . ACCEPTED
elif time . time ( ) - first_timestamp > self . MPP_EXPIRY :
is_expired = True
mpp_resolution = RecvMPPResolution . EXPIRED
if is_accepted or is_expired :
for x in payment_secret s:
if x in self . received_mpp_htlcs :
self . set_mpp_status ( x , is_expired , is_accepted )
if mpp_resolution != RecvMPPResolution . WAITING :
for pkey in payment_key s:
if pkey in self . received_mpp_htlcs :
self . set_mpp_resolution ( payment_key = pkey , resolution = mpp_resolution )
self . maybe_cleanup_mpp_status ( payment_secret , short_channel_id , htlc )
return True if is_accepted else ( False if is_expired else None )
self . maybe_cleanup_mpp_status ( payment_key , short_channel_id , htlc )
return mpp_resolution
def update_mpp_with_received_htlc (
self ,
payment_secret : bytes ,
short_channel_id : ShortChannelID ,
* ,
payment_key : bytes ,
scid : ShortChannelID ,
htlc : UpdateAddHtlc ,
expected_msat : int ,
) :
# add new htlc to set
mpp_status = self . received_mpp_htlcs . get ( payment_secret )
mpp_status = self . received_mpp_htlcs . get ( payment_key )
if mpp_status is None :
mpp_status = ReceivedMPPStatus (
is_expired = False ,
is_accepted = False ,
resolution = RecvMPPResolution . WAITING ,
expected_msat = expected_msat ,
htlc_set = set ( ) ,
)
assert expected_msat == mpp_status . expected_msat
key = ( short_channel_id , htlc )
if expected_msat != mpp_status . expected_msat :
self . logger . info (
f " marking received mpp as failed. inconsistent total_msats in bucket. { payment_key . hex ( ) =} " )
mpp_status = mpp_status . _replace ( resolution = RecvMPPResolution . FAILED )
key = ( scid , htlc )
if key not in mpp_status . htlc_set :
mpp_status . htlc_set . add ( key ) # side-effecting htlc_set
self . received_mpp_htlcs [ payment_secret ] = mpp_status
def get_mpp_status ( self , payment_secret : bytes ) - > Tuple [ bool , bool ] :
mpp_status = self . received_mpp_htlcs [ payment_secret ]
return mpp_status . is_expired , mpp_status . is_accepted
self . received_mpp_htlcs [ payment_key ] = mpp_status
def set_mpp_status ( self , payment_secret : bytes , is_expired : bool , is_accepted : bool ) :
mpp_status = self . received_mpp_htlcs [ payment_secret ]
self . received_mpp_htlcs [ payment_secret ] = mpp_status . _replace (
is_expired = is_expired ,
is_accepted = is_accepted ,
)
def set_mpp_resolution ( self , * , payment_key : bytes , resolution : RecvMPPResolution ) :
mpp_status = self . received_mpp_htlcs [ payment_key ]
self . received_mpp_htlcs [ payment_key ] = mpp_status . _replace ( resolution = resolution )
def is_mpp_amount_reached ( self , payment_secret : bytes ) - > bool :
mpp_status = self . received_mpp_htlcs . get ( payment_secret )
def is_mpp_amount_reached ( self , payment_key : bytes ) - > bool :
mpp_status = self . received_mpp_htlcs . get ( payment_key )
if not mpp_status :
return False
total = sum ( [ _htlc . amount_msat for scid , _htlc in mpp_status . htlc_set ] )
return total > = mpp_status . expected_msat
def get_first_timestamp_of_mpp ( self , payment_secret : bytes ) - > int :
mpp_status = self . received_mpp_htlcs . get ( payment_secret )
def get_first_timestamp_of_mpp ( self , payment_key : bytes ) - > int :
mpp_status = self . received_mpp_htlcs . get ( payment_key )
if not mpp_status :
return int ( time . time ( ) )
return min ( [ _htlc . timestamp for scid , _htlc in mpp_status . htlc_set ] )
def maybe_cleanup_mpp_status (
self ,
payment_secret : bytes ,
payment_key : bytes ,
short_channel_id : ShortChannelID ,
htlc : UpdateAddHtlc ,
) - > None :
mpp_status = self . received_mpp_htlcs [ payment_secret ]
if not mpp_status . is_accepted and not mpp_status . is_expired :
mpp_status = self . received_mpp_htlcs [ payment_key ]
if mpp_status . resolution == RecvMPPResolution . WAITING :
return
key = ( short_channel_id , htlc )
mpp_status . htlc_set . remove ( key ) # side-effecting htlc_set
if not mpp_status . htlc_set and payment_secret in self . received_mpp_htlcs :
self . received_mpp_htlcs . pop ( payment_secret )
if not mpp_status . htlc_set and payment_key in self . received_mpp_htlcs :
self . received_mpp_htlcs . pop ( payment_key )
def get_payment_status ( self , payment_hash : bytes ) - > int :
info = self . get_payment_info ( payment_hash )
@ -2126,10 +2140,11 @@ class LNWallet(LNWorker):
self . logger . info ( f " htlc_failed { failure_message } " )
# check sent_buckets if we use trampoline
if self . uses_trampoline ( ) and payment_secret in self . sent_buckets :
amount_sent , amount_failed = self . sent_buckets [ payment_secret ]
payment_key = payment_hash + payment_secret
if self . uses_trampoline ( ) and payment_key in self . sent_buckets :
amount_sent , amount_failed = self . sent_buckets [ payment_key ]
amount_failed + = amount_receiver_msat
self . sent_buckets [ payment_secret ] = amount_sent , amount_failed
self . sent_buckets [ payment_key ] = amount_sent , amount_failed
if amount_sent != amount_failed :
self . logger . info ( ' bucket still active... ' )
return