@ -8,7 +8,8 @@ from decimal import Decimal
import random
import random
import time
import time
import operator
import operator
from enum import IntEnum
import enum
from enum import IntEnum , Enum
from typing import ( Optional , Sequence , Tuple , List , Set , Dict , TYPE_CHECKING ,
from typing import ( Optional , Sequence , Tuple , List , Set , Dict , TYPE_CHECKING ,
NamedTuple , Union , Mapping , Any , Iterable , AsyncGenerator , DefaultDict , Callable )
NamedTuple , Union , Mapping , Any , Iterable , AsyncGenerator , DefaultDict , Callable )
import threading
import threading
@ -167,9 +168,15 @@ class PaymentInfo(NamedTuple):
status : int
status : int
class RecvMPPResolution ( Enum ) :
WAITING = enum . auto ( )
EXPIRED = enum . auto ( )
ACCEPTED = enum . auto ( )
FAILED = enum . auto ( )
class ReceivedMPPStatus ( NamedTuple ) :
class ReceivedMPPStatus ( NamedTuple ) :
is_expired : bool
resolution : RecvMPPResolution
is_accepted : bool
expected_msat : int
expected_msat : int
htlc_set : Set [ Tuple [ ShortChannelID , UpdateAddHtlc ] ]
htlc_set : Set [ Tuple [ ShortChannelID , UpdateAddHtlc ] ]
@ -673,8 +680,8 @@ class LNWallet(LNWorker):
self . sent_htlcs = defaultdict ( asyncio . Queue ) # type: Dict[bytes, asyncio.Queue[HtlcLog]]
self . sent_htlcs = defaultdict ( asyncio . Queue ) # type: Dict[bytes, asyncio.Queue[HtlcLog]]
self . sent_htlcs_info = dict ( ) # (RHASH, scid, htlc_id) -> route, payment_secret, amount_msat, bucket_msat, trampoline_fee_level
self . sent_htlcs_info = dict ( ) # (RHASH, scid, htlc_id) -> route, payment_secret, amount_msat, bucket_msat, trampoline_fee_level
self . sent_buckets = dict ( ) # payment_secret -> (amount_sent, amount_failed)
self . sent_buckets = dict ( ) # payment_key -> (amount_sent, amount_failed)
self . received_mpp_htlcs = dict ( ) # type: Dict[bytes, ReceivedMPPStatus] # payment_secret -> ReceivedMPPStatus
self . received_mpp_htlcs = dict ( ) # type: Dict[bytes, ReceivedMPPStatus] # payment_key -> ReceivedMPPStatus
self . swap_manager = SwapManager ( wallet = self . wallet , lnworker = self )
self . swap_manager = SwapManager ( wallet = self . wallet , lnworker = self )
# detect inflight payments
# detect inflight payments
@ -1418,13 +1425,14 @@ class LNWallet(LNWorker):
key = ( payment_hash , short_channel_id , htlc . htlc_id )
key = ( payment_hash , short_channel_id , htlc . htlc_id )
self . sent_htlcs_info [ key ] = route , payment_secret , amount_msat , total_msat , amount_receiver_msat , trampoline_fee_level , trampoline_route
self . sent_htlcs_info [ key ] = route , payment_secret , amount_msat , total_msat , amount_receiver_msat , trampoline_fee_level , trampoline_route
payment_key = payment_hash + payment_secret
# if we sent MPP to a trampoline, add item to sent_buckets
# if we sent MPP to a trampoline, add item to sent_buckets
if self . uses_trampoline ( ) and amount_msat != total_msat :
if self . uses_trampoline ( ) and amount_msat != total_msat :
if payment_secret not in self . sent_buckets :
if payment_key not in self . sent_buckets :
self . sent_buckets [ payment_secret ] = ( 0 , 0 )
self . sent_buckets [ payment_key ] = ( 0 , 0 )
amount_sent , amount_failed = self . sent_buckets [ payment_secret ]
amount_sent , amount_failed = self . sent_buckets [ payment_key ]
amount_sent + = amount_receiver_msat
amount_sent + = amount_receiver_msat
self . sent_buckets [ payment_secret ] = amount_sent , amount_failed
self . sent_buckets [ payment_key ] = amount_sent , amount_failed
if self . network . path_finder :
if self . network . path_finder :
# add inflight htlcs to liquidity hints
# add inflight htlcs to liquidity hints
self . network . path_finder . update_inflight_htlcs ( route , add_htlcs = True )
self . network . path_finder . update_inflight_htlcs ( route , add_htlcs = True )
@ -1867,6 +1875,14 @@ class LNWallet(LNWorker):
def get_payment_secret ( self , payment_hash ) :
def get_payment_secret ( self , payment_hash ) :
return sha256 ( sha256 ( self . payment_secret_key ) + payment_hash )
return sha256 ( sha256 ( self . payment_secret_key ) + payment_hash )
def _get_payment_key ( self , payment_hash : bytes ) - > bytes :
""" Return payment bucket key.
We bucket htlcs based on payment_hash + payment_secret . payment_secret is included
as it changes over a trampoline path ( in the outer onion ) , and these paths can overlap .
"""
payment_secret = self . get_payment_secret ( payment_hash )
return payment_hash + payment_secret
def create_payment_info ( self , * , amount_msat : Optional [ int ] , write_to_disk = True ) - > bytes :
def create_payment_info ( self , * , amount_msat : Optional [ int ] , write_to_disk = True ) - > bytes :
payment_preimage = os . urandom ( 32 )
payment_preimage = os . urandom ( 32 )
payment_hash = sha256 ( payment_preimage )
payment_hash = sha256 ( payment_preimage )
@ -1923,103 +1939,101 @@ class LNWallet(LNWorker):
self . wallet . save_db ( )
self . wallet . save_db ( )
def check_mpp_status (
def check_mpp_status (
self , payment_secret : bytes ,
self , * ,
payment_secret : bytes ,
short_channel_id : ShortChannelID ,
short_channel_id : ShortChannelID ,
htlc : UpdateAddHtlc ,
htlc : UpdateAddHtlc ,
expected_msat : int ,
expected_msat : int ,
) - > Optional [ bool ] :
) - > RecvMPPResolution :
""" return MPP status: True (accepted), False (expired) or None (waiting)
"""
payment_hash = htlc . payment_hash
payment_hash = htlc . payment_hash
self . update_mpp_with_received_htlc ( payment_secret , short_channel_id , htlc , expected_msat )
payment_key = payment_hash + payment_secret
is_expired , is_accepted = self . get_mpp_status ( payment_secret )
self . update_mpp_with_received_htlc (
if not is_accepted and not is_expired :
payment_key = payment_key , scid = short_channel_id , htlc = htlc , expected_msat = expected_msat )
mpp_resolution = self . received_mpp_htlcs [ payment_key ] . resolution
if mpp_resolution == RecvMPPResolution . WAITING :
bundle = self . get_payment_bundle ( payment_hash )
bundle = self . get_payment_bundle ( payment_hash )
if bundle :
if bundle :
payment_secret s = [ self . get_payment_secret ( h ) for h in bundle ]
payment_key s = [ self . _get_payment_key ( h ) for h in bundle ]
if payment_secret not in payment_secret s :
if payment_key not in payment_key s :
# outer trampoline onion secret differs from inner onion
# outer trampoline onion secret differs from inner onion
# the latter, not the former, might be part of a bundle
# the latter, not the former, might be part of a bundle
payment_secrets = [ payment_secret ]
payment_keys = [ payment_key ]
else :
else :
payment_secrets = [ payment_secret ]
payment_keys = [ payment_key ]
first_timestamp = min ( [ self . get_first_timestamp_of_mpp ( x ) for x in payment_secret s] )
first_timestamp = min ( [ self . get_first_timestamp_of_mpp ( pkey ) for pkey in payment_key s] )
if self . get_payment_status ( payment_hash ) == PR_PAID :
if self . get_payment_status ( payment_hash ) == PR_PAID :
is_accepted = True
mpp_resolution = RecvMPPResolution . ACCEPTED
elif self . stopping_soon :
elif self . stopping_soon :
is_expired = True # try to time out pending HTLCs before shutting down
# try to time out pending HTLCs before shutting down
elif all ( [ self . is_mpp_amount_reached ( x ) for x in payment_secrets ] ) :
mpp_resolution = RecvMPPResolution . EXPIRED
is_accepted = True
elif all ( [ self . is_mpp_amount_reached ( pkey ) for pkey in payment_keys ] ) :
mpp_resolution = RecvMPPResolution . ACCEPTED
elif time . time ( ) - first_timestamp > self . MPP_EXPIRY :
elif time . time ( ) - first_timestamp > self . MPP_EXPIRY :
is_expired = True
mpp_resolution = RecvMPPResolution . EXPIRED
if is_accepted or is_expired :
if mpp_resolution != RecvMPPResolution . WAITING :
for x in payment_secret s:
for pkey in payment_key s:
if x in self . received_mpp_htlcs :
if pkey in self . received_mpp_htlcs :
self . set_mpp_status ( x , is_expired , is_accepted )
self . set_mpp_resolution ( payment_key = pkey , resolution = mpp_resolution )
self . maybe_cleanup_mpp_status ( payment_secret , short_channel_id , htlc )
self . maybe_cleanup_mpp_status ( payment_key , short_channel_id , htlc )
return True if is_accepted else ( False if is_expired else None )
return mpp_resolution
def update_mpp_with_received_htlc (
def update_mpp_with_received_htlc (
self ,
self ,
payment_secret : bytes ,
* ,
short_channel_id : ShortChannelID ,
payment_key : bytes ,
scid : ShortChannelID ,
htlc : UpdateAddHtlc ,
htlc : UpdateAddHtlc ,
expected_msat : int ,
expected_msat : int ,
) :
) :
# add new htlc to set
# add new htlc to set
mpp_status = self . received_mpp_htlcs . get ( payment_secret )
mpp_status = self . received_mpp_htlcs . get ( payment_key )
if mpp_status is None :
if mpp_status is None :
mpp_status = ReceivedMPPStatus (
mpp_status = ReceivedMPPStatus (
is_expired = False ,
resolution = RecvMPPResolution . WAITING ,
is_accepted = False ,
expected_msat = expected_msat ,
expected_msat = expected_msat ,
htlc_set = set ( ) ,
htlc_set = set ( ) ,
)
)
assert expected_msat == mpp_status . expected_msat
if expected_msat != mpp_status . expected_msat :
key = ( short_channel_id , htlc )
self . logger . info (
f " marking received mpp as failed. inconsistent total_msats in bucket. { payment_key . hex ( ) =} " )
mpp_status = mpp_status . _replace ( resolution = RecvMPPResolution . FAILED )
key = ( scid , htlc )
if key not in mpp_status . htlc_set :
if key not in mpp_status . htlc_set :
mpp_status . htlc_set . add ( key ) # side-effecting htlc_set
mpp_status . htlc_set . add ( key ) # side-effecting htlc_set
self . received_mpp_htlcs [ payment_secret ] = mpp_status
self . received_mpp_htlcs [ payment_key ] = mpp_status
def get_mpp_status ( self , payment_secret : bytes ) - > Tuple [ bool , bool ] :
mpp_status = self . received_mpp_htlcs [ payment_secret ]
return mpp_status . is_expired , mpp_status . is_accepted
def set_mpp_status ( self , payment_secret : bytes , is_expired : bool , is_accepted : bool ) :
def set_mpp_resolution ( self , * , payment_key : bytes , resolution : RecvMPPResolution ) :
mpp_status = self . received_mpp_htlcs [ payment_secret ]
mpp_status = self . received_mpp_htlcs [ payment_key ]
self . received_mpp_htlcs [ payment_secret ] = mpp_status . _replace (
self . received_mpp_htlcs [ payment_key ] = mpp_status . _replace ( resolution = resolution )
is_expired = is_expired ,
is_accepted = is_accepted ,
)
def is_mpp_amount_reached ( self , payment_secret : bytes ) - > bool :
def is_mpp_amount_reached ( self , payment_key : bytes ) - > bool :
mpp_status = self . received_mpp_htlcs . get ( payment_secret )
mpp_status = self . received_mpp_htlcs . get ( payment_key )
if not mpp_status :
if not mpp_status :
return False
return False
total = sum ( [ _htlc . amount_msat for scid , _htlc in mpp_status . htlc_set ] )
total = sum ( [ _htlc . amount_msat for scid , _htlc in mpp_status . htlc_set ] )
return total > = mpp_status . expected_msat
return total > = mpp_status . expected_msat
def get_first_timestamp_of_mpp ( self , payment_secret : bytes ) - > int :
def get_first_timestamp_of_mpp ( self , payment_key : bytes ) - > int :
mpp_status = self . received_mpp_htlcs . get ( payment_secret )
mpp_status = self . received_mpp_htlcs . get ( payment_key )
if not mpp_status :
if not mpp_status :
return int ( time . time ( ) )
return int ( time . time ( ) )
return min ( [ _htlc . timestamp for scid , _htlc in mpp_status . htlc_set ] )
return min ( [ _htlc . timestamp for scid , _htlc in mpp_status . htlc_set ] )
def maybe_cleanup_mpp_status (
def maybe_cleanup_mpp_status (
self ,
self ,
payment_secret : bytes ,
payment_key : bytes ,
short_channel_id : ShortChannelID ,
short_channel_id : ShortChannelID ,
htlc : UpdateAddHtlc ,
htlc : UpdateAddHtlc ,
) - > None :
) - > None :
mpp_status = self . received_mpp_htlcs [ payment_secret ]
mpp_status = self . received_mpp_htlcs [ payment_key ]
if not mpp_status . is_accepted and not mpp_status . is_expired :
if mpp_status . resolution == RecvMPPResolution . WAITING :
return
return
key = ( short_channel_id , htlc )
key = ( short_channel_id , htlc )
mpp_status . htlc_set . remove ( key ) # side-effecting htlc_set
mpp_status . htlc_set . remove ( key ) # side-effecting htlc_set
if not mpp_status . htlc_set and payment_secret in self . received_mpp_htlcs :
if not mpp_status . htlc_set and payment_key in self . received_mpp_htlcs :
self . received_mpp_htlcs . pop ( payment_secret )
self . received_mpp_htlcs . pop ( payment_key )
def get_payment_status ( self , payment_hash : bytes ) - > int :
def get_payment_status ( self , payment_hash : bytes ) - > int :
info = self . get_payment_info ( payment_hash )
info = self . get_payment_info ( payment_hash )
@ -2126,10 +2140,11 @@ class LNWallet(LNWorker):
self . logger . info ( f " htlc_failed { failure_message } " )
self . logger . info ( f " htlc_failed { failure_message } " )
# check sent_buckets if we use trampoline
# check sent_buckets if we use trampoline
if self . uses_trampoline ( ) and payment_secret in self . sent_buckets :
payment_key = payment_hash + payment_secret
amount_sent , amount_failed = self . sent_buckets [ payment_secret ]
if self . uses_trampoline ( ) and payment_key in self . sent_buckets :
amount_sent , amount_failed = self . sent_buckets [ payment_key ]
amount_failed + = amount_receiver_msat
amount_failed + = amount_receiver_msat
self . sent_buckets [ payment_secret ] = amount_sent , amount_failed
self . sent_buckets [ payment_key ] = amount_sent , amount_failed
if amount_sent != amount_failed :
if amount_sent != amount_failed :
self . logger . info ( ' bucket still active... ' )
self . logger . info ( ' bucket still active... ' )
return
return