Browse Source

lnworker: introduce PaySession cls, refactor pay_to_node

master
SomberNight 2 years ago
parent
commit
00e88c4e50
No known key found for this signature in database
GPG Key ID: B33B5F232C6271E9
  1. 274
      electrum/lnworker.py
  2. 35
      electrum/tests/test_lnpeer.py

274
electrum/lnworker.py

@ -17,7 +17,7 @@ import socket
import aiohttp import aiohttp
import json import json
from datetime import datetime, timezone from datetime import datetime, timezone
from functools import partial from functools import partial, cached_property
from collections import defaultdict from collections import defaultdict
import concurrent import concurrent
from concurrent import futures from concurrent import futures
@ -655,6 +655,105 @@ class LNGossip(LNWorker):
self.logger.debug(f'process_gossip: {len(categorized_chan_upds.good)}/{len(chan_upds)}') self.logger.debug(f'process_gossip: {len(categorized_chan_upds.good)}/{len(chan_upds)}')
class PaySession(Logger):
def __init__(
self,
*,
payment_hash: bytes,
payment_secret: bytes,
initial_trampoline_fee_level: int,
invoice_features: int,
r_tags,
min_cltv_expiry: int,
amount_to_pay: int, # total payment amount final receiver will get
invoice_pubkey: bytes,
):
assert payment_hash
assert payment_secret
self.payment_hash = payment_hash
self.payment_secret = payment_secret
self.payment_key = payment_hash + payment_secret
Logger.__init__(self)
self.invoice_features = LnFeatures(invoice_features)
self.r_tags = r_tags
self.min_cltv_expiry = min_cltv_expiry
self.amount_to_pay = amount_to_pay
self.invoice_pubkey = invoice_pubkey
self.sent_htlcs_q = asyncio.Queue() # type: asyncio.Queue[HtlcLog]
self.start_time = time.time()
self.trampoline_fee_level = initial_trampoline_fee_level
self.failed_trampoline_routes = []
self.use_two_trampolines = True
self._amount_inflight = 0 # what we sent in htlcs (that receiver gets, without fees)
self._nhtlcs_inflight = 0
def diagnostic_name(self):
pkey = sha256(self.payment_key)
return f"{self.payment_hash[:4].hex()}-{pkey[:2].hex()}"
def maybe_raise_trampoline_fee(self, htlc_log: HtlcLog):
if htlc_log.trampoline_fee_level == self.trampoline_fee_level:
self.trampoline_fee_level += 1
self.failed_trampoline_routes = []
self.logger.info(f'raising trampoline fee level {self.trampoline_fee_level}')
else:
self.logger.info(f'NOT raising trampoline fee level, already at {self.trampoline_fee_level}')
def handle_failed_trampoline_htlc(self, *, htlc_log: HtlcLog, failure_msg: OnionRoutingFailure):
# FIXME The trampoline nodes in the path are chosen randomly.
# Some of the errors might depend on how we have chosen them.
# Having more attempts is currently useful in part because of the randomness,
# instead we should give feedback to create_routes_for_payment.
# Sometimes the trampoline node fails to send a payment and returns
# TEMPORARY_CHANNEL_FAILURE, while it succeeds with a higher trampoline fee.
if failure_msg.code in (
OnionFailureCode.TRAMPOLINE_FEE_INSUFFICIENT,
OnionFailureCode.TRAMPOLINE_EXPIRY_TOO_SOON,
OnionFailureCode.TEMPORARY_CHANNEL_FAILURE):
# TODO: parse the node policy here (not returned by eclair yet)
# TODO: erring node is always the first trampoline even if second
# trampoline demands more fees, we can't influence this
self.maybe_raise_trampoline_fee(htlc_log)
elif self.use_two_trampolines:
self.use_two_trampolines = False
elif failure_msg.code in (
OnionFailureCode.UNKNOWN_NEXT_PEER,
OnionFailureCode.TEMPORARY_NODE_FAILURE):
trampoline_route = htlc_log.route
r = [hop.end_node.hex() for hop in trampoline_route]
self.logger.info(f'failed trampoline route: {r}')
if r not in self.failed_trampoline_routes:
self.failed_trampoline_routes.append(r)
else:
pass # maybe the route was reused between different MPP parts
else:
raise PaymentFailure(failure_msg.code_name())
async def wait_for_one_htlc_to_resolve(self) -> HtlcLog:
self.logger.info(f"waiting... amount_inflight={self._amount_inflight}. nhtlcs_inflight={self._nhtlcs_inflight}")
htlc_log = await self.sent_htlcs_q.get()
self._amount_inflight -= htlc_log.amount_msat
self._nhtlcs_inflight -= 1
if self._amount_inflight < 0 or self._nhtlcs_inflight < 0:
raise Exception(f"amount_inflight={self._amount_inflight}, nhtlcs_inflight={self._nhtlcs_inflight}. both should be >= 0 !")
return htlc_log
def add_new_htlc(self, sent_htlc_info: SentHtlcInfo) -> SentHtlcInfo:
self._nhtlcs_inflight += 1
self._amount_inflight += sent_htlc_info.amount_receiver_msat
if self._amount_inflight > self.amount_to_pay: # safety belts
raise Exception(f"amount_inflight={self._amount_inflight} > amount_to_pay={self.amount_to_pay}")
sent_htlc_info = sent_htlc_info._replace(trampoline_fee_level=self.trampoline_fee_level)
return sent_htlc_info
def get_outstanding_amount_to_send(self) -> int:
return self.amount_to_pay - self._amount_inflight
class LNWallet(LNWorker): class LNWallet(LNWorker):
lnwatcher: Optional['LNWalletWatcher'] lnwatcher: Optional['LNWalletWatcher']
@ -694,9 +793,9 @@ class LNWallet(LNWorker):
for channel_id, storage in channel_backups.items(): for channel_id, storage in channel_backups.items():
self._channel_backups[bfh(channel_id)] = ChannelBackup(storage, lnworker=self) self._channel_backups[bfh(channel_id)] = ChannelBackup(storage, lnworker=self)
self.sent_htlcs_q = defaultdict(asyncio.Queue) # type: Dict[bytes, asyncio.Queue[HtlcLog]] self._paysessions = dict() # type: Dict[bytes, PaySession]
self.sent_htlcs_info = dict() # type: Dict[SentHtlcKey, SentHtlcInfo] self.sent_htlcs_info = dict() # type: Dict[SentHtlcKey, SentHtlcInfo]
self.sent_buckets = dict() # payment_key -> (amount_sent, amount_failed) self.sent_buckets = dict() # payment_key -> (amount_sent, amount_failed) # TODO move into PaySession
self.received_mpp_htlcs = dict() # type: Dict[bytes, ReceivedMPPStatus] # payment_key -> ReceivedMPPStatus self.received_mpp_htlcs = dict() # type: Dict[bytes, ReceivedMPPStatus] # payment_key -> ReceivedMPPStatus
# detect inflight payments # detect inflight payments
@ -1274,9 +1373,9 @@ class LNWallet(LNWorker):
invoice_features: int, invoice_features: int,
attempts: int = None, attempts: int = None,
full_path: LNPaymentPath = None, full_path: LNPaymentPath = None,
fwd_trampoline_onion=None, fwd_trampoline_onion: OnionPacket = None,
fwd_trampoline_fee=None, fwd_trampoline_fee: int = None,
fwd_trampoline_cltv_delta=None, fwd_trampoline_cltv_delta: int = None,
channels: Optional[Sequence[Channel]] = None, channels: Optional[Sequence[Channel]] = None,
) -> None: ) -> None:
@ -1288,46 +1387,37 @@ class LNWallet(LNWorker):
raise OnionRoutingFailure(code=OnionFailureCode.TRAMPOLINE_EXPIRY_TOO_SOON, data=b'') raise OnionRoutingFailure(code=OnionFailureCode.TRAMPOLINE_EXPIRY_TOO_SOON, data=b'')
payment_key = payment_hash + payment_secret payment_key = payment_hash + payment_secret
#assert payment_key not in self._paysessions # FIXME
self._paysessions[payment_key] = paysession = PaySession(
payment_hash=payment_hash,
payment_secret=payment_secret,
initial_trampoline_fee_level=self.INITIAL_TRAMPOLINE_FEE_LEVEL,
invoice_features=invoice_features,
r_tags=r_tags,
min_cltv_expiry=min_cltv_expiry,
amount_to_pay=amount_to_pay,
invoice_pubkey=node_pubkey,
)
self.logs[payment_hash.hex()] = log = [] # TODO incl payment_secret in key (re trampoline forwarding) self.logs[payment_hash.hex()] = log = [] # TODO incl payment_secret in key (re trampoline forwarding)
# when encountering trampoline forwarding difficulties in the legacy case, we # when encountering trampoline forwarding difficulties in the legacy case, we
# sometimes need to fall back to a single trampoline forwarder, at the expense # sometimes need to fall back to a single trampoline forwarder, at the expense
# of privacy # of privacy
use_two_trampolines = True
trampoline_fee_level = self.INITIAL_TRAMPOLINE_FEE_LEVEL
failed_trampoline_routes = []
start_time = time.time()
amount_inflight = 0 # what we sent in htlcs (that receiver gets, without fees)
nhtlcs_inflight = 0
while True: while True:
amount_to_send = amount_to_pay - amount_inflight if (amount_to_send := paysession.get_outstanding_amount_to_send()) > 0:
if amount_to_send > 0:
# 1. create a set of routes for remaining amount. # 1. create a set of routes for remaining amount.
# note: path-finding runs in a separate thread so that we don't block the asyncio loop # note: path-finding runs in a separate thread so that we don't block the asyncio loop
# graph updates might occur during the computation # graph updates might occur during the computation
routes = self.create_routes_for_payment( routes = self.create_routes_for_payment(
paysession=paysession,
amount_msat=amount_to_send, amount_msat=amount_to_send,
final_total_msat=amount_to_pay,
invoice_pubkey=node_pubkey,
min_cltv_expiry=min_cltv_expiry,
r_tags=r_tags,
invoice_features=invoice_features,
full_path=full_path, full_path=full_path,
payment_hash=payment_hash,
payment_secret=payment_secret,
trampoline_fee_level=trampoline_fee_level,
failed_trampoline_routes=failed_trampoline_routes,
use_two_trampolines=use_two_trampolines,
fwd_trampoline_onion=fwd_trampoline_onion, fwd_trampoline_onion=fwd_trampoline_onion,
channels=channels, channels=channels,
) )
# 2. send htlcs # 2. send htlcs
async for sent_htlc_info, cltv_delta, trampoline_onion in routes: async for sent_htlc_info, cltv_delta, trampoline_onion in routes:
nhtlcs_inflight += 1 sent_htlc_info = paysession.add_new_htlc(sent_htlc_info)
amount_inflight += sent_htlc_info.amount_receiver_msat
if amount_inflight > amount_to_pay: # safety belts
raise Exception(f"amount_inflight={amount_inflight} > amount_to_pay={amount_to_pay}")
sent_htlc_info = sent_htlc_info._replace(trampoline_fee_level=trampoline_fee_level)
await self.pay_to_route( await self.pay_to_route(
sent_htlc_info=sent_htlc_info, sent_htlc_info=sent_htlc_info,
payment_hash=payment_hash, payment_hash=payment_hash,
@ -1339,12 +1429,7 @@ class LNWallet(LNWorker):
# (e.g. attempt counter) # (e.g. attempt counter)
util.trigger_callback('invoice_status', self.wallet, payment_hash.hex(), PR_INFLIGHT) util.trigger_callback('invoice_status', self.wallet, payment_hash.hex(), PR_INFLIGHT)
# 3. await a queue # 3. await a queue
self.logger.info(f"paysession for RHASH {payment_hash.hex()} waiting... {amount_inflight=}. {nhtlcs_inflight=}") htlc_log = await paysession.wait_for_one_htlc_to_resolve() # TODO maybe wait a bit, more failures might come
htlc_log = await self.sent_htlcs_q[payment_key].get() # TODO maybe wait a bit, more failures might come
amount_inflight -= htlc_log.amount_msat
nhtlcs_inflight -= 1
if amount_inflight < 0 or nhtlcs_inflight < 0:
raise Exception(f"{amount_inflight=}, {nhtlcs_inflight=}. both should be >= 0 !")
log.append(htlc_log) log.append(htlc_log)
if htlc_log.success: if htlc_log.success:
if self.network.path_finder: if self.network.path_finder:
@ -1357,7 +1442,7 @@ class LNWallet(LNWorker):
self.network.path_finder.update_inflight_htlcs(htlc_log.route, add_htlcs=False) self.network.path_finder.update_inflight_htlcs(htlc_log.route, add_htlcs=False)
return return
# htlc failed # htlc failed
if (attempts is not None and len(log) >= attempts) or (attempts is None and time.time() - start_time > self.PAYMENT_TIMEOUT): if (attempts is not None and len(log) >= attempts) or (attempts is None and time.time() - paysession.start_time > self.PAYMENT_TIMEOUT):
raise PaymentFailure('Giving up after %d attempts'%len(log)) raise PaymentFailure('Giving up after %d attempts'%len(log))
# if we get a tmp channel failure, it might work to split the amount and try more routes # if we get a tmp channel failure, it might work to split the amount and try more routes
# if we get a channel update, we might retry the same route and amount # if we get a channel update, we might retry the same route and amount
@ -1373,45 +1458,8 @@ class LNWallet(LNWorker):
raise PaymentFailure(failure_msg.code_name()) raise PaymentFailure(failure_msg.code_name())
# trampoline # trampoline
if self.uses_trampoline(): if self.uses_trampoline():
def maybe_raise_trampoline_fee(htlc_log): paysession.handle_failed_trampoline_htlc(
nonlocal trampoline_fee_level htlc_log=htlc_log, failure_msg=failure_msg)
nonlocal failed_trampoline_routes
if htlc_log.trampoline_fee_level == trampoline_fee_level:
trampoline_fee_level += 1
failed_trampoline_routes = []
self.logger.info(f'raising trampoline fee level {trampoline_fee_level}')
else:
self.logger.info(f'NOT raising trampoline fee level, already at {trampoline_fee_level}')
# FIXME The trampoline nodes in the path are chosen randomly.
# Some of the errors might depend on how we have chosen them.
# Having more attempts is currently useful in part because of the randomness,
# instead we should give feedback to create_routes_for_payment.
# Sometimes the trampoline node fails to send a payment and returns
# TEMPORARY_CHANNEL_FAILURE, while it succeeds with a higher trampoline fee.
if code in (
OnionFailureCode.TRAMPOLINE_FEE_INSUFFICIENT,
OnionFailureCode.TRAMPOLINE_EXPIRY_TOO_SOON,
OnionFailureCode.TEMPORARY_CHANNEL_FAILURE):
# TODO: parse the node policy here (not returned by eclair yet)
# TODO: erring node is always the first trampoline even if second
# trampoline demands more fees, we can't influence this
maybe_raise_trampoline_fee(htlc_log)
continue
elif use_two_trampolines:
use_two_trampolines = False
elif code in (
OnionFailureCode.UNKNOWN_NEXT_PEER,
OnionFailureCode.TEMPORARY_NODE_FAILURE):
trampoline_route = htlc_log.route
r = [hop.end_node.hex() for hop in trampoline_route]
self.logger.info(f'failed trampoline route: {r}')
if r not in failed_trampoline_routes:
failed_trampoline_routes.append(r)
else:
pass # maybe the route was reused between different MPP parts
continue
else:
raise PaymentFailure(failure_msg.code_name())
else: else:
self.handle_error_code_from_failed_htlc( self.handle_error_code_from_failed_htlc(
route=route, sender_idx=sender_idx, failure_msg=failure_msg, amount=htlc_log.amount_msat) route=route, sender_idx=sender_idx, failure_msg=failure_msg, amount=htlc_log.amount_msat)
@ -1654,18 +1702,9 @@ class LNWallet(LNWorker):
async def create_routes_for_payment( async def create_routes_for_payment(
self, *, self, *,
paysession: PaySession,
amount_msat: int, # part of payment amount we want routes for now amount_msat: int, # part of payment amount we want routes for now
final_total_msat: int, # total payment amount final receiver will get fwd_trampoline_onion: OnionPacket = None,
invoice_pubkey,
min_cltv_expiry,
r_tags,
invoice_features: int,
payment_hash: bytes,
payment_secret: bytes,
trampoline_fee_level: int,
failed_trampoline_routes: Iterable[Sequence[str]],
use_two_trampolines: bool,
fwd_trampoline_onion=None,
full_path: LNPaymentPath = None, full_path: LNPaymentPath = None,
channels: Optional[Sequence[Channel]] = None, channels: Optional[Sequence[Channel]] = None,
) -> AsyncGenerator[Tuple[SentHtlcInfo, int, Optional[OnionPacket]], None]: ) -> AsyncGenerator[Tuple[SentHtlcInfo, int, Optional[OnionPacket]], None]:
@ -1675,7 +1714,6 @@ class LNWallet(LNWorker):
We first try to conduct the payment over a single channel. If that fails We first try to conduct the payment over a single channel. If that fails
and mpp is supported by the receiver, we will split the payment.""" and mpp is supported by the receiver, we will split the payment."""
invoice_features = LnFeatures(invoice_features)
trampoline_features = LnFeatures.VAR_ONION_OPT trampoline_features = LnFeatures.VAR_ONION_OPT
local_height = self.network.get_local_height() local_height = self.network.get_local_height()
if channels: if channels:
@ -1688,15 +1726,15 @@ class LNWallet(LNWorker):
random.shuffle(my_active_channels) random.shuffle(my_active_channels)
split_configurations = self.suggest_splits( split_configurations = self.suggest_splits(
amount_msat=amount_msat, amount_msat=amount_msat,
final_total_msat=final_total_msat, final_total_msat=paysession.amount_to_pay,
my_active_channels=my_active_channels, my_active_channels=my_active_channels,
invoice_features=invoice_features, invoice_features=paysession.invoice_features,
r_tags=r_tags, r_tags=paysession.r_tags,
) )
for sc in split_configurations: for sc in split_configurations:
is_multichan_mpp = len(sc.config.items()) > 1 is_multichan_mpp = len(sc.config.items()) > 1
is_mpp = sum(len(x) for x in list(sc.config.values())) > 1 is_mpp = sum(len(x) for x in list(sc.config.values())) > 1
if is_mpp and not invoice_features.supports(LnFeatures.BASIC_MPP_OPT): if is_mpp and not paysession.invoice_features.supports(LnFeatures.BASIC_MPP_OPT):
continue continue
if not is_mpp and self.config.TEST_FORCE_MPP: if not is_mpp and self.config.TEST_FORCE_MPP:
continue continue
@ -1715,33 +1753,33 @@ class LNWallet(LNWorker):
# for each trampoline forwarder, construct mpp trampoline # for each trampoline forwarder, construct mpp trampoline
for trampoline_node_id, trampoline_parts in per_trampoline_channel_amounts.items(): for trampoline_node_id, trampoline_parts in per_trampoline_channel_amounts.items():
per_trampoline_amount = sum([x[1] for x in trampoline_parts]) per_trampoline_amount = sum([x[1] for x in trampoline_parts])
if trampoline_node_id == invoice_pubkey: if trampoline_node_id == paysession.invoice_pubkey:
trampoline_route = None trampoline_route = None
trampoline_onion = None trampoline_onion = None
per_trampoline_secret = payment_secret per_trampoline_secret = paysession.payment_secret
per_trampoline_amount_with_fees = amount_msat per_trampoline_amount_with_fees = amount_msat
per_trampoline_cltv_delta = min_cltv_expiry per_trampoline_cltv_delta = paysession.min_cltv_expiry
per_trampoline_fees = 0 per_trampoline_fees = 0
else: else:
trampoline_route, trampoline_onion, per_trampoline_amount_with_fees, per_trampoline_cltv_delta = create_trampoline_route_and_onion( trampoline_route, trampoline_onion, per_trampoline_amount_with_fees, per_trampoline_cltv_delta = create_trampoline_route_and_onion(
amount_msat=per_trampoline_amount, amount_msat=per_trampoline_amount,
total_msat=final_total_msat, total_msat=paysession.amount_to_pay,
min_cltv_expiry=min_cltv_expiry, min_cltv_expiry=paysession.min_cltv_expiry,
my_pubkey=self.node_keypair.pubkey, my_pubkey=self.node_keypair.pubkey,
invoice_pubkey=invoice_pubkey, invoice_pubkey=paysession.invoice_pubkey,
invoice_features=invoice_features, invoice_features=paysession.invoice_features,
node_id=trampoline_node_id, node_id=trampoline_node_id,
r_tags=r_tags, r_tags=paysession.r_tags,
payment_hash=payment_hash, payment_hash=paysession.payment_hash,
payment_secret=payment_secret, payment_secret=paysession.payment_secret,
local_height=local_height, local_height=local_height,
trampoline_fee_level=trampoline_fee_level, trampoline_fee_level=paysession.trampoline_fee_level,
use_two_trampolines=use_two_trampolines, use_two_trampolines=paysession.use_two_trampolines,
failed_routes=failed_trampoline_routes) failed_routes=paysession.failed_trampoline_routes)
# node_features is only used to determine is_tlv # node_features is only used to determine is_tlv
per_trampoline_secret = os.urandom(32) per_trampoline_secret = os.urandom(32)
per_trampoline_fees = per_trampoline_amount_with_fees - per_trampoline_amount per_trampoline_fees = per_trampoline_amount_with_fees - per_trampoline_amount
self.logger.info(f'created route with trampoline fee level={trampoline_fee_level}') self.logger.info(f'created route with trampoline fee level={paysession.trampoline_fee_level}')
self.logger.info(f'trampoline hops: {[hop.end_node.hex() for hop in trampoline_route]}') self.logger.info(f'trampoline hops: {[hop.end_node.hex() for hop in trampoline_route]}')
self.logger.info(f'per trampoline fees: {per_trampoline_fees}') self.logger.info(f'per trampoline fees: {per_trampoline_fees}')
for chan_id, part_amount_msat in trampoline_parts: for chan_id, part_amount_msat in trampoline_parts:
@ -1764,7 +1802,7 @@ class LNWallet(LNWorker):
self.logger.info(f'adding route {part_amount_msat} {delta_fee} {margin}') self.logger.info(f'adding route {part_amount_msat} {delta_fee} {margin}')
shi = SentHtlcInfo( shi = SentHtlcInfo(
route=route, route=route,
payment_secret_orig=payment_secret, payment_secret_orig=paysession.payment_secret,
payment_secret_bucket=per_trampoline_secret, payment_secret_bucket=per_trampoline_secret,
amount_msat=part_amount_msat_with_fees, amount_msat=part_amount_msat_with_fees,
bucket_msat=per_trampoline_amount_with_fees, bucket_msat=per_trampoline_amount_with_fees,
@ -1786,25 +1824,25 @@ class LNWallet(LNWorker):
partial( partial(
self.create_route_for_payment, self.create_route_for_payment,
amount_msat=part_amount_msat, amount_msat=part_amount_msat,
invoice_pubkey=invoice_pubkey, invoice_pubkey=paysession.invoice_pubkey,
min_cltv_expiry=min_cltv_expiry, min_cltv_expiry=paysession.min_cltv_expiry,
r_tags=r_tags, r_tags=paysession.r_tags,
invoice_features=invoice_features, invoice_features=paysession.invoice_features,
my_sending_channels=[channel] if is_multichan_mpp else my_active_channels, my_sending_channels=[channel] if is_multichan_mpp else my_active_channels,
full_path=full_path, full_path=full_path,
) )
) )
shi = SentHtlcInfo( shi = SentHtlcInfo(
route=route, route=route,
payment_secret_orig=payment_secret, payment_secret_orig=paysession.payment_secret,
payment_secret_bucket=payment_secret, payment_secret_bucket=paysession.payment_secret,
amount_msat=part_amount_msat, amount_msat=part_amount_msat,
bucket_msat=final_total_msat, bucket_msat=paysession.amount_to_pay,
amount_receiver_msat=part_amount_msat, amount_receiver_msat=part_amount_msat,
trampoline_fee_level=None, trampoline_fee_level=None,
trampoline_route=None, trampoline_route=None,
) )
routes.append((shi, min_cltv_expiry, fwd_trampoline_onion)) routes.append((shi, paysession.min_cltv_expiry, fwd_trampoline_onion))
except NoPathFound: except NoPathFound:
continue continue
for route in routes: for route in routes:
@ -2159,7 +2197,9 @@ class LNWallet(LNWorker):
q = None q = None
if shi := self.sent_htlcs_info.get((payment_hash, chan.short_channel_id, htlc_id)): if shi := self.sent_htlcs_info.get((payment_hash, chan.short_channel_id, htlc_id)):
payment_key = payment_hash + shi.payment_secret_orig payment_key = payment_hash + shi.payment_secret_orig
q = self.sent_htlcs_q.get(payment_key) paysession = self._paysessions.get(payment_key)
if paysession:
q = paysession.sent_htlcs_q
if q: if q:
htlc_log = HtlcLog( htlc_log = HtlcLog(
success=True, success=True,
@ -2185,7 +2225,9 @@ class LNWallet(LNWorker):
q = None q = None
if shi := self.sent_htlcs_info.get((payment_hash, chan.short_channel_id, htlc_id)): if shi := self.sent_htlcs_info.get((payment_hash, chan.short_channel_id, htlc_id)):
payment_okey = payment_hash + shi.payment_secret_orig payment_okey = payment_hash + shi.payment_secret_orig
q = self.sent_htlcs_q.get(payment_okey) paysession = self._paysessions.get(payment_okey)
if paysession:
q = paysession.sent_htlcs_q
if q: if q:
# detect if it is part of a bucket # detect if it is part of a bucket
# if yes, wait until the bucket completely failed # if yes, wait until the bucket completely failed

35
electrum/tests/test_lnpeer.py

@ -31,7 +31,7 @@ from electrum.lnutil import PaymentFailure, LnFeatures, HTLCOwner
from electrum.lnchannel import ChannelState, PeerState, Channel from electrum.lnchannel import ChannelState, PeerState, Channel
from electrum.lnrouter import LNPathFinder, PathEdge, LNPathInconsistent from electrum.lnrouter import LNPathFinder, PathEdge, LNPathInconsistent
from electrum.channel_db import ChannelDB from electrum.channel_db import ChannelDB
from electrum.lnworker import LNWallet, NoPathFound, SentHtlcInfo from electrum.lnworker import LNWallet, NoPathFound, SentHtlcInfo, PaySession
from electrum.lnmsg import encode_msg, decode_msg from electrum.lnmsg import encode_msg, decode_msg
from electrum import lnmsg from electrum import lnmsg
from electrum.logging import console_stderr_handler, Logger from electrum.logging import console_stderr_handler, Logger
@ -168,7 +168,7 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
self.enable_htlc_settle = True self.enable_htlc_settle = True
self.enable_htlc_forwarding = True self.enable_htlc_forwarding = True
self.received_mpp_htlcs = dict() self.received_mpp_htlcs = dict()
self.sent_htlcs_q = defaultdict(asyncio.Queue) self._paysessions = dict()
self.sent_htlcs_info = dict() self.sent_htlcs_info = dict()
self.sent_buckets = defaultdict(set) self.sent_buckets = defaultdict(set)
self.final_onion_forwardings = set() self.final_onion_forwardings = set()
@ -232,18 +232,22 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
await self.channel_db.stopped_event.wait() await self.channel_db.stopped_event.wait()
async def create_routes_from_invoice(self, amount_msat: int, decoded_invoice: LnAddr, *, full_path=None): async def create_routes_from_invoice(self, amount_msat: int, decoded_invoice: LnAddr, *, full_path=None):
return [r async for r in self.create_routes_for_payment( paysession = PaySession(
amount_msat=amount_msat,
final_total_msat=amount_msat,
invoice_pubkey=decoded_invoice.pubkey.serialize(),
min_cltv_expiry=decoded_invoice.get_min_final_cltv_expiry(),
r_tags=decoded_invoice.get_routing_info('r'),
invoice_features=decoded_invoice.get_features(),
trampoline_fee_level=0,
failed_trampoline_routes=[],
use_two_trampolines=False,
payment_hash=decoded_invoice.paymenthash, payment_hash=decoded_invoice.paymenthash,
payment_secret=decoded_invoice.payment_secret, payment_secret=decoded_invoice.payment_secret,
initial_trampoline_fee_level=0,
invoice_features=decoded_invoice.get_features(),
r_tags=decoded_invoice.get_routing_info('r'),
min_cltv_expiry=decoded_invoice.get_min_final_cltv_expiry(),
amount_to_pay=amount_msat,
invoice_pubkey=decoded_invoice.pubkey.serialize(),
)
paysession.use_two_trampolines = False
payment_key = decoded_invoice.paymenthash + decoded_invoice.payment_secret
self._paysessions[payment_key] = paysession
return [r async for r in self.create_routes_for_payment(
amount_msat=amount_msat,
paysession=paysession,
full_path=full_path)] full_path=full_path)]
get_payments = LNWallet.get_payments get_payments = LNWallet.get_payments
@ -854,9 +858,6 @@ class TestPeer(ElectrumTestCase):
_maybe_send_commitment2 = p2.maybe_send_commitment _maybe_send_commitment2 = p2.maybe_send_commitment
lnaddr2, pay_req2 = self.prepare_invoice(w2) lnaddr2, pay_req2 = self.prepare_invoice(w2)
lnaddr1, pay_req1 = self.prepare_invoice(w1) lnaddr1, pay_req1 = self.prepare_invoice(w1)
# create the htlc queues now (side-effecting defaultdict)
q1 = w1.sent_htlcs_q[lnaddr2.paymenthash + lnaddr2.payment_secret]
q2 = w2.sent_htlcs_q[lnaddr1.paymenthash + lnaddr1.payment_secret]
# alice sends htlc BUT NOT COMMITMENT_SIGNED # alice sends htlc BUT NOT COMMITMENT_SIGNED
p1.maybe_send_commitment = lambda x: None p1.maybe_send_commitment = lambda x: None
route1 = (await w1.create_routes_from_invoice(lnaddr2.get_amount_msat(), decoded_invoice=lnaddr2))[0][0].route route1 = (await w1.create_routes_from_invoice(lnaddr2.get_amount_msat(), decoded_invoice=lnaddr2))[0][0].route
@ -901,9 +902,9 @@ class TestPeer(ElectrumTestCase):
p1.maybe_send_commitment(alice_channel) p1.maybe_send_commitment(alice_channel)
p2.maybe_send_commitment(bob_channel) p2.maybe_send_commitment(bob_channel)
htlc_log1 = await q1.get() htlc_log1 = await w1._paysessions[lnaddr2.paymenthash + lnaddr2.payment_secret].sent_htlcs_q.get()
self.assertTrue(htlc_log1.success) self.assertTrue(htlc_log1.success)
htlc_log2 = await q2.get() htlc_log2 = await w2._paysessions[lnaddr1.paymenthash + lnaddr1.payment_secret].sent_htlcs_q.get()
self.assertTrue(htlc_log2.success) self.assertTrue(htlc_log2.success)
raise PaymentDone() raise PaymentDone()

Loading…
Cancel
Save