@ -307,6 +307,8 @@ PRIMARY KEY(node_id)
class ChannelDB ( SqlDB ) :
NUM_MAX_RECENT_PEERS = 20
PRIVATE_CHAN_UPD_CACHE_TTL_NORMAL = 600
PRIVATE_CHAN_UPD_CACHE_TTL_SHORT = 120
def __init__ ( self , network : ' Network ' ) :
path = self . get_file_path ( network . config )
@ -314,7 +316,10 @@ class ChannelDB(SqlDB):
self . lock = threading . RLock ( )
self . num_nodes = 0
self . num_channels = 0
self . _channel_updates_for_private_channels = { } # type: Dict[Tuple[bytes, bytes], dict]
self . _channel_updates_for_private_channels = { } # type: Dict[Tuple[bytes, bytes], Tuple[dict, int]]
# note: ^ we could maybe move this cache into PaySession instead of being global.
# That would only make sense though if PaySessions were never too short
# (e.g. consider trampoline forwarding).
self . ca_verifier = LNChannelVerifier ( network , self )
# initialized in load_data
@ -667,19 +672,46 @@ class ChannelDB(SqlDB):
self . update_counts ( )
self . logger . info ( f ' Deleting { len ( orphaned_chans ) } orphaned channels ' )
def add_channel_update_for_private_channel ( self , msg_payload : dict , start_node_id : bytes ) - > bool :
def _get_channel_update_for_private_channel (
self ,
start_node_id : bytes ,
short_channel_id : ShortChannelID ,
* ,
now : int = None , # unix ts
) - > Optional [ dict ] :
if now is None :
now = int ( time . time ( ) )
key = ( start_node_id , short_channel_id )
chan_upd_dict , cache_expiration = self . _channel_updates_for_private_channels . get ( key , ( None , 0 ) )
if cache_expiration < now :
chan_upd_dict = None # already expired
# TODO rm expired entries from cache (note: perf vs thread-safety)
return chan_upd_dict
def add_channel_update_for_private_channel (
self ,
msg_payload : dict ,
start_node_id : bytes ,
* ,
cache_ttl : int = None , # seconds
) - > bool :
""" Returns True iff the channel update was successfully added and it was different than
what we had before ( if any ) .
"""
if not verify_sig_for_channel_update ( msg_payload , start_node_id ) :
return False # ignore
now = int ( time . time ( ) )
short_channel_id = ShortChannelID ( msg_payload [ ' short_channel_id ' ] )
msg_payload [ ' start_node ' ] = start_node_id
key = ( start_node_id , short_channel_id )
prev_chanupd = self . _channel_updates_for_private_channels . get ( key )
prev_chanupd = self . _get_channel_update_for_private_channel ( start_node_id , short_channel_id , now = now )
if prev_chanupd == msg_payload :
return False
self . _channel_updates_for_private_channels [ key ] = msg_payload
if cache_ttl is None :
cache_ttl = self . PRIVATE_CHAN_UPD_CACHE_TTL_NORMAL
cache_expiration = now + cache_ttl
key = ( start_node_id , short_channel_id )
with self . lock :
self . _channel_updates_for_private_channels [ key ] = msg_payload , cache_expiration
return True
def remove_channel ( self , short_channel_id : ShortChannelID ) :
@ -791,21 +823,20 @@ class ChannelDB(SqlDB):
def get_policy_for_node (
self ,
short_channel_id : bytes ,
short_channel_id : ShortChannelID ,
node_id : bytes ,
* ,
my_channels : Dict [ ShortChannelID , ' Channel ' ] = None ,
private_route_edges : Dict [ ShortChannelID , ' RouteEdge ' ] = None ,
now : int = None , # unix ts
) - > Optional [ ' Policy ' ] :
channel_info = self . get_channel_info ( short_channel_id )
if channel_info is not None : # publicly announced channel
policy = self . _policies . get ( ( node_id , short_channel_id ) )
if policy :
return policy
else : # private channel
chan_upd_dict = self . _channel_updates_for_private_channels . get ( ( node_id , short_channel_id ) )
if chan_upd_dict :
return Policy . from_msg ( chan_upd_dict )
elif chan_upd_dict := self . _get_channel_update_for_private_channel ( node_id , short_channel_id , now = now ) :
return Policy . from_msg ( chan_upd_dict )
# check if it's one of our own channels
if my_channels :
policy = get_mychannel_policy ( short_channel_id , node_id , my_channels )