diff --git a/electrum/lnworker.py b/electrum/lnworker.py index c2d8cc970..9f3eae07a 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -17,7 +17,7 @@ import socket import aiohttp import json from datetime import datetime, timezone -from functools import partial +from functools import partial, cached_property from collections import defaultdict import concurrent from concurrent import futures @@ -655,6 +655,105 @@ class LNGossip(LNWorker): self.logger.debug(f'process_gossip: {len(categorized_chan_upds.good)}/{len(chan_upds)}') +class PaySession(Logger): + def __init__( + self, + *, + payment_hash: bytes, + payment_secret: bytes, + initial_trampoline_fee_level: int, + invoice_features: int, + r_tags, + min_cltv_expiry: int, + amount_to_pay: int, # total payment amount final receiver will get + invoice_pubkey: bytes, + ): + assert payment_hash + assert payment_secret + self.payment_hash = payment_hash + self.payment_secret = payment_secret + self.payment_key = payment_hash + payment_secret + Logger.__init__(self) + + self.invoice_features = LnFeatures(invoice_features) + self.r_tags = r_tags + self.min_cltv_expiry = min_cltv_expiry + self.amount_to_pay = amount_to_pay + self.invoice_pubkey = invoice_pubkey + + self.sent_htlcs_q = asyncio.Queue() # type: asyncio.Queue[HtlcLog] + self.start_time = time.time() + + self.trampoline_fee_level = initial_trampoline_fee_level + self.failed_trampoline_routes = [] + self.use_two_trampolines = True + + self._amount_inflight = 0 # what we sent in htlcs (that receiver gets, without fees) + self._nhtlcs_inflight = 0 + + def diagnostic_name(self): + pkey = sha256(self.payment_key) + return f"{self.payment_hash[:4].hex()}-{pkey[:2].hex()}" + + def maybe_raise_trampoline_fee(self, htlc_log: HtlcLog): + if htlc_log.trampoline_fee_level == self.trampoline_fee_level: + self.trampoline_fee_level += 1 + self.failed_trampoline_routes = [] + self.logger.info(f'raising trampoline fee level {self.trampoline_fee_level}') + else: + self.logger.info(f'NOT raising trampoline fee level, already at {self.trampoline_fee_level}') + + def handle_failed_trampoline_htlc(self, *, htlc_log: HtlcLog, failure_msg: OnionRoutingFailure): + # FIXME The trampoline nodes in the path are chosen randomly. + # Some of the errors might depend on how we have chosen them. + # Having more attempts is currently useful in part because of the randomness, + # instead we should give feedback to create_routes_for_payment. + # Sometimes the trampoline node fails to send a payment and returns + # TEMPORARY_CHANNEL_FAILURE, while it succeeds with a higher trampoline fee. + if failure_msg.code in ( + OnionFailureCode.TRAMPOLINE_FEE_INSUFFICIENT, + OnionFailureCode.TRAMPOLINE_EXPIRY_TOO_SOON, + OnionFailureCode.TEMPORARY_CHANNEL_FAILURE): + # TODO: parse the node policy here (not returned by eclair yet) + # TODO: erring node is always the first trampoline even if second + # trampoline demands more fees, we can't influence this + self.maybe_raise_trampoline_fee(htlc_log) + elif self.use_two_trampolines: + self.use_two_trampolines = False + elif failure_msg.code in ( + OnionFailureCode.UNKNOWN_NEXT_PEER, + OnionFailureCode.TEMPORARY_NODE_FAILURE): + trampoline_route = htlc_log.route + r = [hop.end_node.hex() for hop in trampoline_route] + self.logger.info(f'failed trampoline route: {r}') + if r not in self.failed_trampoline_routes: + self.failed_trampoline_routes.append(r) + else: + pass # maybe the route was reused between different MPP parts + else: + raise PaymentFailure(failure_msg.code_name()) + + async def wait_for_one_htlc_to_resolve(self) -> HtlcLog: + self.logger.info(f"waiting... amount_inflight={self._amount_inflight}. nhtlcs_inflight={self._nhtlcs_inflight}") + htlc_log = await self.sent_htlcs_q.get() + self._amount_inflight -= htlc_log.amount_msat + self._nhtlcs_inflight -= 1 + if self._amount_inflight < 0 or self._nhtlcs_inflight < 0: + raise Exception(f"amount_inflight={self._amount_inflight}, nhtlcs_inflight={self._nhtlcs_inflight}. both should be >= 0 !") + return htlc_log + + def add_new_htlc(self, sent_htlc_info: SentHtlcInfo) -> SentHtlcInfo: + self._nhtlcs_inflight += 1 + self._amount_inflight += sent_htlc_info.amount_receiver_msat + if self._amount_inflight > self.amount_to_pay: # safety belts + raise Exception(f"amount_inflight={self._amount_inflight} > amount_to_pay={self.amount_to_pay}") + sent_htlc_info = sent_htlc_info._replace(trampoline_fee_level=self.trampoline_fee_level) + return sent_htlc_info + + def get_outstanding_amount_to_send(self) -> int: + return self.amount_to_pay - self._amount_inflight + + class LNWallet(LNWorker): lnwatcher: Optional['LNWalletWatcher'] @@ -694,9 +793,9 @@ class LNWallet(LNWorker): for channel_id, storage in channel_backups.items(): self._channel_backups[bfh(channel_id)] = ChannelBackup(storage, lnworker=self) - self.sent_htlcs_q = defaultdict(asyncio.Queue) # type: Dict[bytes, asyncio.Queue[HtlcLog]] + self._paysessions = dict() # type: Dict[bytes, PaySession] self.sent_htlcs_info = dict() # type: Dict[SentHtlcKey, SentHtlcInfo] - self.sent_buckets = dict() # payment_key -> (amount_sent, amount_failed) + self.sent_buckets = dict() # payment_key -> (amount_sent, amount_failed) # TODO move into PaySession self.received_mpp_htlcs = dict() # type: Dict[bytes, ReceivedMPPStatus] # payment_key -> ReceivedMPPStatus # detect inflight payments @@ -1274,9 +1373,9 @@ class LNWallet(LNWorker): invoice_features: int, attempts: int = None, full_path: LNPaymentPath = None, - fwd_trampoline_onion=None, - fwd_trampoline_fee=None, - fwd_trampoline_cltv_delta=None, + fwd_trampoline_onion: OnionPacket = None, + fwd_trampoline_fee: int = None, + fwd_trampoline_cltv_delta: int = None, channels: Optional[Sequence[Channel]] = None, ) -> None: @@ -1288,46 +1387,37 @@ class LNWallet(LNWorker): raise OnionRoutingFailure(code=OnionFailureCode.TRAMPOLINE_EXPIRY_TOO_SOON, data=b'') payment_key = payment_hash + payment_secret + #assert payment_key not in self._paysessions # FIXME + self._paysessions[payment_key] = paysession = PaySession( + payment_hash=payment_hash, + payment_secret=payment_secret, + initial_trampoline_fee_level=self.INITIAL_TRAMPOLINE_FEE_LEVEL, + invoice_features=invoice_features, + r_tags=r_tags, + min_cltv_expiry=min_cltv_expiry, + amount_to_pay=amount_to_pay, + invoice_pubkey=node_pubkey, + ) self.logs[payment_hash.hex()] = log = [] # TODO incl payment_secret in key (re trampoline forwarding) # when encountering trampoline forwarding difficulties in the legacy case, we # sometimes need to fall back to a single trampoline forwarder, at the expense # of privacy - use_two_trampolines = True - trampoline_fee_level = self.INITIAL_TRAMPOLINE_FEE_LEVEL - failed_trampoline_routes = [] - start_time = time.time() - amount_inflight = 0 # what we sent in htlcs (that receiver gets, without fees) - nhtlcs_inflight = 0 while True: - amount_to_send = amount_to_pay - amount_inflight - if amount_to_send > 0: + if (amount_to_send := paysession.get_outstanding_amount_to_send()) > 0: # 1. create a set of routes for remaining amount. # note: path-finding runs in a separate thread so that we don't block the asyncio loop # graph updates might occur during the computation routes = self.create_routes_for_payment( + paysession=paysession, amount_msat=amount_to_send, - final_total_msat=amount_to_pay, - invoice_pubkey=node_pubkey, - min_cltv_expiry=min_cltv_expiry, - r_tags=r_tags, - invoice_features=invoice_features, full_path=full_path, - payment_hash=payment_hash, - payment_secret=payment_secret, - trampoline_fee_level=trampoline_fee_level, - failed_trampoline_routes=failed_trampoline_routes, - use_two_trampolines=use_two_trampolines, fwd_trampoline_onion=fwd_trampoline_onion, channels=channels, ) # 2. send htlcs async for sent_htlc_info, cltv_delta, trampoline_onion in routes: - nhtlcs_inflight += 1 - amount_inflight += sent_htlc_info.amount_receiver_msat - if amount_inflight > amount_to_pay: # safety belts - raise Exception(f"amount_inflight={amount_inflight} > amount_to_pay={amount_to_pay}") - sent_htlc_info = sent_htlc_info._replace(trampoline_fee_level=trampoline_fee_level) + sent_htlc_info = paysession.add_new_htlc(sent_htlc_info) await self.pay_to_route( sent_htlc_info=sent_htlc_info, payment_hash=payment_hash, @@ -1339,12 +1429,7 @@ class LNWallet(LNWorker): # (e.g. attempt counter) util.trigger_callback('invoice_status', self.wallet, payment_hash.hex(), PR_INFLIGHT) # 3. await a queue - self.logger.info(f"paysession for RHASH {payment_hash.hex()} waiting... {amount_inflight=}. {nhtlcs_inflight=}") - htlc_log = await self.sent_htlcs_q[payment_key].get() # TODO maybe wait a bit, more failures might come - amount_inflight -= htlc_log.amount_msat - nhtlcs_inflight -= 1 - if amount_inflight < 0 or nhtlcs_inflight < 0: - raise Exception(f"{amount_inflight=}, {nhtlcs_inflight=}. both should be >= 0 !") + htlc_log = await paysession.wait_for_one_htlc_to_resolve() # TODO maybe wait a bit, more failures might come log.append(htlc_log) if htlc_log.success: if self.network.path_finder: @@ -1357,7 +1442,7 @@ class LNWallet(LNWorker): self.network.path_finder.update_inflight_htlcs(htlc_log.route, add_htlcs=False) return # htlc failed - if (attempts is not None and len(log) >= attempts) or (attempts is None and time.time() - start_time > self.PAYMENT_TIMEOUT): + if (attempts is not None and len(log) >= attempts) or (attempts is None and time.time() - paysession.start_time > self.PAYMENT_TIMEOUT): raise PaymentFailure('Giving up after %d attempts'%len(log)) # if we get a tmp channel failure, it might work to split the amount and try more routes # if we get a channel update, we might retry the same route and amount @@ -1373,45 +1458,8 @@ class LNWallet(LNWorker): raise PaymentFailure(failure_msg.code_name()) # trampoline if self.uses_trampoline(): - def maybe_raise_trampoline_fee(htlc_log): - nonlocal trampoline_fee_level - nonlocal failed_trampoline_routes - if htlc_log.trampoline_fee_level == trampoline_fee_level: - trampoline_fee_level += 1 - failed_trampoline_routes = [] - self.logger.info(f'raising trampoline fee level {trampoline_fee_level}') - else: - self.logger.info(f'NOT raising trampoline fee level, already at {trampoline_fee_level}') - # FIXME The trampoline nodes in the path are chosen randomly. - # Some of the errors might depend on how we have chosen them. - # Having more attempts is currently useful in part because of the randomness, - # instead we should give feedback to create_routes_for_payment. - # Sometimes the trampoline node fails to send a payment and returns - # TEMPORARY_CHANNEL_FAILURE, while it succeeds with a higher trampoline fee. - if code in ( - OnionFailureCode.TRAMPOLINE_FEE_INSUFFICIENT, - OnionFailureCode.TRAMPOLINE_EXPIRY_TOO_SOON, - OnionFailureCode.TEMPORARY_CHANNEL_FAILURE): - # TODO: parse the node policy here (not returned by eclair yet) - # TODO: erring node is always the first trampoline even if second - # trampoline demands more fees, we can't influence this - maybe_raise_trampoline_fee(htlc_log) - continue - elif use_two_trampolines: - use_two_trampolines = False - elif code in ( - OnionFailureCode.UNKNOWN_NEXT_PEER, - OnionFailureCode.TEMPORARY_NODE_FAILURE): - trampoline_route = htlc_log.route - r = [hop.end_node.hex() for hop in trampoline_route] - self.logger.info(f'failed trampoline route: {r}') - if r not in failed_trampoline_routes: - failed_trampoline_routes.append(r) - else: - pass # maybe the route was reused between different MPP parts - continue - else: - raise PaymentFailure(failure_msg.code_name()) + paysession.handle_failed_trampoline_htlc( + htlc_log=htlc_log, failure_msg=failure_msg) else: self.handle_error_code_from_failed_htlc( route=route, sender_idx=sender_idx, failure_msg=failure_msg, amount=htlc_log.amount_msat) @@ -1654,18 +1702,9 @@ class LNWallet(LNWorker): async def create_routes_for_payment( self, *, + paysession: PaySession, amount_msat: int, # part of payment amount we want routes for now - final_total_msat: int, # total payment amount final receiver will get - invoice_pubkey, - min_cltv_expiry, - r_tags, - invoice_features: int, - payment_hash: bytes, - payment_secret: bytes, - trampoline_fee_level: int, - failed_trampoline_routes: Iterable[Sequence[str]], - use_two_trampolines: bool, - fwd_trampoline_onion=None, + fwd_trampoline_onion: OnionPacket = None, full_path: LNPaymentPath = None, channels: Optional[Sequence[Channel]] = None, ) -> AsyncGenerator[Tuple[SentHtlcInfo, int, Optional[OnionPacket]], None]: @@ -1675,7 +1714,6 @@ class LNWallet(LNWorker): We first try to conduct the payment over a single channel. If that fails and mpp is supported by the receiver, we will split the payment.""" - invoice_features = LnFeatures(invoice_features) trampoline_features = LnFeatures.VAR_ONION_OPT local_height = self.network.get_local_height() if channels: @@ -1688,15 +1726,15 @@ class LNWallet(LNWorker): random.shuffle(my_active_channels) split_configurations = self.suggest_splits( amount_msat=amount_msat, - final_total_msat=final_total_msat, + final_total_msat=paysession.amount_to_pay, my_active_channels=my_active_channels, - invoice_features=invoice_features, - r_tags=r_tags, + invoice_features=paysession.invoice_features, + r_tags=paysession.r_tags, ) for sc in split_configurations: is_multichan_mpp = len(sc.config.items()) > 1 is_mpp = sum(len(x) for x in list(sc.config.values())) > 1 - if is_mpp and not invoice_features.supports(LnFeatures.BASIC_MPP_OPT): + if is_mpp and not paysession.invoice_features.supports(LnFeatures.BASIC_MPP_OPT): continue if not is_mpp and self.config.TEST_FORCE_MPP: continue @@ -1715,33 +1753,33 @@ class LNWallet(LNWorker): # for each trampoline forwarder, construct mpp trampoline for trampoline_node_id, trampoline_parts in per_trampoline_channel_amounts.items(): per_trampoline_amount = sum([x[1] for x in trampoline_parts]) - if trampoline_node_id == invoice_pubkey: + if trampoline_node_id == paysession.invoice_pubkey: trampoline_route = None trampoline_onion = None - per_trampoline_secret = payment_secret + per_trampoline_secret = paysession.payment_secret per_trampoline_amount_with_fees = amount_msat - per_trampoline_cltv_delta = min_cltv_expiry + per_trampoline_cltv_delta = paysession.min_cltv_expiry per_trampoline_fees = 0 else: trampoline_route, trampoline_onion, per_trampoline_amount_with_fees, per_trampoline_cltv_delta = create_trampoline_route_and_onion( amount_msat=per_trampoline_amount, - total_msat=final_total_msat, - min_cltv_expiry=min_cltv_expiry, + total_msat=paysession.amount_to_pay, + min_cltv_expiry=paysession.min_cltv_expiry, my_pubkey=self.node_keypair.pubkey, - invoice_pubkey=invoice_pubkey, - invoice_features=invoice_features, + invoice_pubkey=paysession.invoice_pubkey, + invoice_features=paysession.invoice_features, node_id=trampoline_node_id, - r_tags=r_tags, - payment_hash=payment_hash, - payment_secret=payment_secret, + r_tags=paysession.r_tags, + payment_hash=paysession.payment_hash, + payment_secret=paysession.payment_secret, local_height=local_height, - trampoline_fee_level=trampoline_fee_level, - use_two_trampolines=use_two_trampolines, - failed_routes=failed_trampoline_routes) + trampoline_fee_level=paysession.trampoline_fee_level, + use_two_trampolines=paysession.use_two_trampolines, + failed_routes=paysession.failed_trampoline_routes) # node_features is only used to determine is_tlv per_trampoline_secret = os.urandom(32) per_trampoline_fees = per_trampoline_amount_with_fees - per_trampoline_amount - self.logger.info(f'created route with trampoline fee level={trampoline_fee_level}') + self.logger.info(f'created route with trampoline fee level={paysession.trampoline_fee_level}') self.logger.info(f'trampoline hops: {[hop.end_node.hex() for hop in trampoline_route]}') self.logger.info(f'per trampoline fees: {per_trampoline_fees}') for chan_id, part_amount_msat in trampoline_parts: @@ -1764,7 +1802,7 @@ class LNWallet(LNWorker): self.logger.info(f'adding route {part_amount_msat} {delta_fee} {margin}') shi = SentHtlcInfo( route=route, - payment_secret_orig=payment_secret, + payment_secret_orig=paysession.payment_secret, payment_secret_bucket=per_trampoline_secret, amount_msat=part_amount_msat_with_fees, bucket_msat=per_trampoline_amount_with_fees, @@ -1786,25 +1824,25 @@ class LNWallet(LNWorker): partial( self.create_route_for_payment, amount_msat=part_amount_msat, - invoice_pubkey=invoice_pubkey, - min_cltv_expiry=min_cltv_expiry, - r_tags=r_tags, - invoice_features=invoice_features, + invoice_pubkey=paysession.invoice_pubkey, + min_cltv_expiry=paysession.min_cltv_expiry, + r_tags=paysession.r_tags, + invoice_features=paysession.invoice_features, my_sending_channels=[channel] if is_multichan_mpp else my_active_channels, full_path=full_path, ) ) shi = SentHtlcInfo( route=route, - payment_secret_orig=payment_secret, - payment_secret_bucket=payment_secret, + payment_secret_orig=paysession.payment_secret, + payment_secret_bucket=paysession.payment_secret, amount_msat=part_amount_msat, - bucket_msat=final_total_msat, + bucket_msat=paysession.amount_to_pay, amount_receiver_msat=part_amount_msat, trampoline_fee_level=None, trampoline_route=None, ) - routes.append((shi, min_cltv_expiry, fwd_trampoline_onion)) + routes.append((shi, paysession.min_cltv_expiry, fwd_trampoline_onion)) except NoPathFound: continue for route in routes: @@ -2159,7 +2197,9 @@ class LNWallet(LNWorker): q = None if shi := self.sent_htlcs_info.get((payment_hash, chan.short_channel_id, htlc_id)): payment_key = payment_hash + shi.payment_secret_orig - q = self.sent_htlcs_q.get(payment_key) + paysession = self._paysessions.get(payment_key) + if paysession: + q = paysession.sent_htlcs_q if q: htlc_log = HtlcLog( success=True, @@ -2185,7 +2225,9 @@ class LNWallet(LNWorker): q = None if shi := self.sent_htlcs_info.get((payment_hash, chan.short_channel_id, htlc_id)): payment_okey = payment_hash + shi.payment_secret_orig - q = self.sent_htlcs_q.get(payment_okey) + paysession = self._paysessions.get(payment_okey) + if paysession: + q = paysession.sent_htlcs_q if q: # detect if it is part of a bucket # if yes, wait until the bucket completely failed diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index f1f69d583..ee6f970a7 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -31,7 +31,7 @@ from electrum.lnutil import PaymentFailure, LnFeatures, HTLCOwner from electrum.lnchannel import ChannelState, PeerState, Channel from electrum.lnrouter import LNPathFinder, PathEdge, LNPathInconsistent from electrum.channel_db import ChannelDB -from electrum.lnworker import LNWallet, NoPathFound, SentHtlcInfo +from electrum.lnworker import LNWallet, NoPathFound, SentHtlcInfo, PaySession from electrum.lnmsg import encode_msg, decode_msg from electrum import lnmsg from electrum.logging import console_stderr_handler, Logger @@ -168,7 +168,7 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): self.enable_htlc_settle = True self.enable_htlc_forwarding = True self.received_mpp_htlcs = dict() - self.sent_htlcs_q = defaultdict(asyncio.Queue) + self._paysessions = dict() self.sent_htlcs_info = dict() self.sent_buckets = defaultdict(set) self.final_onion_forwardings = set() @@ -232,18 +232,22 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): await self.channel_db.stopped_event.wait() async def create_routes_from_invoice(self, amount_msat: int, decoded_invoice: LnAddr, *, full_path=None): - return [r async for r in self.create_routes_for_payment( - amount_msat=amount_msat, - final_total_msat=amount_msat, - invoice_pubkey=decoded_invoice.pubkey.serialize(), - min_cltv_expiry=decoded_invoice.get_min_final_cltv_expiry(), - r_tags=decoded_invoice.get_routing_info('r'), - invoice_features=decoded_invoice.get_features(), - trampoline_fee_level=0, - failed_trampoline_routes=[], - use_two_trampolines=False, + paysession = PaySession( payment_hash=decoded_invoice.paymenthash, payment_secret=decoded_invoice.payment_secret, + initial_trampoline_fee_level=0, + invoice_features=decoded_invoice.get_features(), + r_tags=decoded_invoice.get_routing_info('r'), + min_cltv_expiry=decoded_invoice.get_min_final_cltv_expiry(), + amount_to_pay=amount_msat, + invoice_pubkey=decoded_invoice.pubkey.serialize(), + ) + paysession.use_two_trampolines = False + payment_key = decoded_invoice.paymenthash + decoded_invoice.payment_secret + self._paysessions[payment_key] = paysession + return [r async for r in self.create_routes_for_payment( + amount_msat=amount_msat, + paysession=paysession, full_path=full_path)] get_payments = LNWallet.get_payments @@ -854,9 +858,6 @@ class TestPeer(ElectrumTestCase): _maybe_send_commitment2 = p2.maybe_send_commitment lnaddr2, pay_req2 = self.prepare_invoice(w2) lnaddr1, pay_req1 = self.prepare_invoice(w1) - # create the htlc queues now (side-effecting defaultdict) - q1 = w1.sent_htlcs_q[lnaddr2.paymenthash + lnaddr2.payment_secret] - q2 = w2.sent_htlcs_q[lnaddr1.paymenthash + lnaddr1.payment_secret] # alice sends htlc BUT NOT COMMITMENT_SIGNED p1.maybe_send_commitment = lambda x: None route1 = (await w1.create_routes_from_invoice(lnaddr2.get_amount_msat(), decoded_invoice=lnaddr2))[0][0].route @@ -901,9 +902,9 @@ class TestPeer(ElectrumTestCase): p1.maybe_send_commitment(alice_channel) p2.maybe_send_commitment(bob_channel) - htlc_log1 = await q1.get() + htlc_log1 = await w1._paysessions[lnaddr2.paymenthash + lnaddr2.payment_secret].sent_htlcs_q.get() self.assertTrue(htlc_log1.success) - htlc_log2 = await q2.get() + htlc_log2 = await w2._paysessions[lnaddr1.paymenthash + lnaddr1.payment_secret].sent_htlcs_q.get() self.assertTrue(htlc_log2.success) raise PaymentDone()