From afac158c8004125e0b063d8bec03467ca5489a99 Mon Sep 17 00:00:00 2001 From: SomberNight Date: Tue, 8 Aug 2023 16:28:20 +0000 Subject: [PATCH] lnworker: clean-up sent_htlcs_q and sent_htlcs_info - introduce SentHtlcInfo named tuple - some previously unnamed tuples are now much shorter: create_routes_for_payment no longer returns an 8-tuple! - sent_htlcs_q (renamed from sent_htlcs), is now keyed on payment_hash+payment_secret (needed for proper trampoline forwarding) --- electrum/lnpeer.py | 2 + electrum/lnworker.py | 137 +++++++++++++++++++++------------- electrum/tests/test_lnpeer.py | 85 +++++++++++---------- 3 files changed, 132 insertions(+), 92 deletions(-) diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index 11edfe81a..b5f7843ff 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -1742,6 +1742,8 @@ class Peer(Logger): except OnionRoutingFailure as e: raise except PaymentFailure as e: + self.logger.debug( + f"maybe_forward_trampoline. PaymentFailure for {payment_hash.hex()=}, {payment_secret.hex()=}: {e!r}") # FIXME: adapt the error code raise OnionRoutingFailure(code=OnionFailureCode.UNKNOWN_NEXT_PEER, data=b'') diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 53ec4e3e3..bfc8cb5b0 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -69,7 +69,7 @@ from .lnutil import (Outpoint, LNPeerAddr, NoPathFound, InvalidGossipMsg) from .lnutil import ln_dummy_address, ln_compare_features, IncompatibleLightningFeatures from .transaction import PartialTxOutput, PartialTransaction, PartialTxInput -from .lnonion import OnionFailureCode, OnionRoutingFailure +from .lnonion import OnionFailureCode, OnionRoutingFailure, OnionPacket from .lnmsg import decode_msg from .i18n import _ from .lnrouter import (RouteEdge, LNPaymentRoute, LNPaymentPath, is_route_sane_to_use, @@ -181,6 +181,20 @@ class ReceivedMPPStatus(NamedTuple): htlc_set: Set[Tuple[ShortChannelID, UpdateAddHtlc]] +SentHtlcKey = Tuple[bytes, ShortChannelID, int] # RHASH, scid, htlc_id + + +class SentHtlcInfo(NamedTuple): + route: LNPaymentRoute + payment_secret_orig: bytes + payment_secret_bucket: bytes + amount_msat: int + bucket_msat: int + amount_receiver_msat: int + trampoline_fee_level: Optional[int] + trampoline_route: Optional[LNPaymentRoute] + + class ErrorAddingPeer(Exception): pass @@ -678,8 +692,8 @@ class LNWallet(LNWorker): for channel_id, storage in channel_backups.items(): self._channel_backups[bfh(channel_id)] = ChannelBackup(storage, lnworker=self) - self.sent_htlcs = defaultdict(asyncio.Queue) # type: Dict[bytes, asyncio.Queue[HtlcLog]] - self.sent_htlcs_info = dict() # (RHASH, scid, htlc_id) -> route, payment_secret, amount_msat, bucket_msat, trampoline_fee_level + self.sent_htlcs_q = defaultdict(asyncio.Queue) # type: Dict[bytes, asyncio.Queue[HtlcLog]] + self.sent_htlcs_info = dict() # type: Dict[SentHtlcKey, SentHtlcInfo] self.sent_buckets = dict() # payment_key -> (amount_sent, amount_failed) self.received_mpp_htlcs = dict() # type: Dict[bytes, ReceivedMPPStatus] # payment_key -> ReceivedMPPStatus @@ -1268,7 +1282,8 @@ class LNWallet(LNWorker): if fwd_trampoline_cltv_delta < 576: raise OnionRoutingFailure(code=OnionFailureCode.TRAMPOLINE_EXPIRY_TOO_SOON, data=b'') - self.logs[payment_hash.hex()] = log = [] + payment_key = payment_hash + payment_secret + self.logs[payment_hash.hex()] = log = [] # TODO incl payment_secret in key (re trampoline forwarding) # when encountering trampoline forwarding difficulties in the legacy case, we # sometimes need to fall back to a single trampoline forwarder, at the expense @@ -1300,28 +1315,24 @@ class LNWallet(LNWorker): channels=channels, ) # 2. send htlcs - async for route, amount_msat, total_msat, amount_receiver_msat, cltv_delta, bucket_payment_secret, trampoline_onion, trampoline_route in routes: - amount_inflight += amount_receiver_msat + async for sent_htlc_info, cltv_delta, trampoline_onion in routes: + amount_inflight += sent_htlc_info.amount_receiver_msat if amount_inflight > amount_to_pay: # safety belts raise Exception(f"amount_inflight={amount_inflight} > amount_to_pay={amount_to_pay}") + sent_htlc_info = sent_htlc_info._replace(trampoline_fee_level=self.trampoline_fee_level) await self.pay_to_route( - route=route, - amount_msat=amount_msat, - total_msat=total_msat, - amount_receiver_msat=amount_receiver_msat, + sent_htlc_info=sent_htlc_info, payment_hash=payment_hash, - payment_secret=bucket_payment_secret, min_cltv_expiry=cltv_delta, trampoline_onion=trampoline_onion, - trampoline_fee_level=self.trampoline_fee_level, - trampoline_route=trampoline_route) + ) # invoice_status is triggered in self.set_invoice_status when it actally changes. # It is also triggered here to update progress for a lightning payment in the GUI # (e.g. attempt counter) util.trigger_callback('invoice_status', self.wallet, payment_hash.hex(), PR_INFLIGHT) # 3. await a queue self.logger.info(f"amount inflight {amount_inflight}") - htlc_log = await self.sent_htlcs[payment_hash].get() + htlc_log = await self.sent_htlcs_q[payment_key].get() amount_inflight -= htlc_log.amount_msat if amount_inflight < 0: raise Exception(f"amount_inflight={amount_inflight} < 0") @@ -1394,48 +1405,44 @@ class LNWallet(LNWorker): async def pay_to_route( self, *, - route: LNPaymentRoute, - amount_msat: int, - total_msat: int, - amount_receiver_msat:int, + sent_htlc_info: SentHtlcInfo, payment_hash: bytes, - payment_secret: bytes, min_cltv_expiry: int, trampoline_onion: bytes = None, - trampoline_fee_level: int, - trampoline_route: Optional[List]) -> None: - - # send a single htlc - short_channel_id = route[0].short_channel_id + ) -> None: + """Sends a single HTLC.""" + shi = sent_htlc_info + del sent_htlc_info # just renamed + short_channel_id = shi.route[0].short_channel_id chan = self.get_channel_by_short_id(short_channel_id) assert chan, ShortChannelID(short_channel_id) - peer = self._peers.get(route[0].node_id) + peer = self._peers.get(shi.route[0].node_id) if not peer: raise PaymentFailure('Dropped peer') await peer.initialized htlc = peer.pay( - route=route, + route=shi.route, chan=chan, - amount_msat=amount_msat, - total_msat=total_msat, + amount_msat=shi.amount_msat, + total_msat=shi.bucket_msat, payment_hash=payment_hash, min_final_cltv_expiry=min_cltv_expiry, - payment_secret=payment_secret, + payment_secret=shi.payment_secret_bucket, trampoline_onion=trampoline_onion) key = (payment_hash, short_channel_id, htlc.htlc_id) - self.sent_htlcs_info[key] = route, payment_secret, amount_msat, total_msat, amount_receiver_msat, trampoline_fee_level, trampoline_route - payment_key = payment_hash + payment_secret + self.sent_htlcs_info[key] = shi + payment_key = payment_hash + shi.payment_secret_bucket # if we sent MPP to a trampoline, add item to sent_buckets - if self.uses_trampoline() and amount_msat != total_msat: + if self.uses_trampoline() and shi.amount_msat != shi.bucket_msat: if payment_key not in self.sent_buckets: self.sent_buckets[payment_key] = (0, 0) amount_sent, amount_failed = self.sent_buckets[payment_key] - amount_sent += amount_receiver_msat + amount_sent += shi.amount_receiver_msat self.sent_buckets[payment_key] = amount_sent, amount_failed if self.network.path_finder: # add inflight htlcs to liquidity hints - self.network.path_finder.update_inflight_htlcs(route, add_htlcs=True) + self.network.path_finder.update_inflight_htlcs(shi.route, add_htlcs=True) util.trigger_callback('htlc_added', chan, htlc, SENT) def handle_error_code_from_failed_htlc( @@ -1633,7 +1640,7 @@ class LNWallet(LNWorker): fwd_trampoline_onion=None, full_path: LNPaymentPath = None, channels: Optional[Sequence[Channel]] = None, - ) -> AsyncGenerator[Tuple[LNPaymentRoute, int], None]: + ) -> AsyncGenerator[Tuple[SentHtlcInfo, int, Optional[OnionPacket]], None]: """Creates multiple routes for splitting a payment over the available private channels. @@ -1719,7 +1726,17 @@ class LNWallet(LNWorker): node_features=trampoline_features) ] self.logger.info(f'adding route {part_amount_msat} {delta_fee} {margin}') - routes.append((route, part_amount_msat_with_fees, per_trampoline_amount_with_fees, part_amount_msat, per_trampoline_cltv_delta, per_trampoline_secret, trampoline_onion, trampoline_route)) + shi = SentHtlcInfo( + route=route, + payment_secret_orig=payment_secret, + payment_secret_bucket=per_trampoline_secret, + amount_msat=part_amount_msat_with_fees, + bucket_msat=per_trampoline_amount_with_fees, + amount_receiver_msat=part_amount_msat, + trampoline_fee_level=None, + trampoline_route=trampoline_route, + ) + routes.append((shi, per_trampoline_cltv_delta, trampoline_onion)) if per_trampoline_fees != 0: self.logger.info('not enough margin to pay trampoline fee') raise NoPathFound() @@ -1741,7 +1758,17 @@ class LNWallet(LNWorker): full_path=full_path, ) ) - routes.append((route, part_amount_msat, final_total_msat, part_amount_msat, min_cltv_expiry, payment_secret, fwd_trampoline_onion, None)) + shi = SentHtlcInfo( + route=route, + payment_secret_orig=payment_secret, + payment_secret_bucket=payment_secret, + amount_msat=part_amount_msat, + bucket_msat=final_total_msat, + amount_receiver_msat=part_amount_msat, + trampoline_fee_level=None, + trampoline_route=None, + ) + routes.append((shi, min_cltv_expiry, fwd_trampoline_onion)) except NoPathFound: continue for route in routes: @@ -2096,14 +2123,16 @@ class LNWallet(LNWorker): def htlc_fulfilled(self, chan: Channel, payment_hash: bytes, htlc_id: int): util.trigger_callback('htlc_fulfilled', payment_hash, chan, htlc_id) self._on_maybe_forwarded_htlc_resolved(chan=chan, htlc_id=htlc_id) - q = self.sent_htlcs.get(payment_hash) + q = None + if shi := self.sent_htlcs_info.get((payment_hash, chan.short_channel_id, htlc_id)): + payment_key = payment_hash + shi.payment_secret_orig + q = self.sent_htlcs_q.get(payment_key) if q: - route, payment_secret, amount_msat, bucket_msat, amount_receiver_msat, trampoline_fee_level, trampoline_route = self.sent_htlcs_info[(payment_hash, chan.short_channel_id, htlc_id)] htlc_log = HtlcLog( success=True, - route=route, - amount_msat=amount_receiver_msat, - trampoline_fee_level=trampoline_fee_level) + route=shi.route, + amount_msat=shi.amount_receiver_msat, + trampoline_fee_level=shi.trampoline_fee_level) q.put_nowait(htlc_log) else: key = payment_hash.hex() @@ -2120,12 +2149,16 @@ class LNWallet(LNWorker): util.trigger_callback('htlc_failed', payment_hash, chan, htlc_id) self._on_maybe_forwarded_htlc_resolved(chan=chan, htlc_id=htlc_id) - q = self.sent_htlcs.get(payment_hash) + q = None + if shi := self.sent_htlcs_info.get((payment_hash, chan.short_channel_id, htlc_id)): + payment_okey = payment_hash + shi.payment_secret_orig + q = self.sent_htlcs_q.get(payment_okey) if q: # detect if it is part of a bucket # if yes, wait until the bucket completely failed - key = (payment_hash, chan.short_channel_id, htlc_id) - route, payment_secret, amount_msat, bucket_msat, amount_receiver_msat, trampoline_fee_level, trampoline_route = self.sent_htlcs_info[key] + shi = self.sent_htlcs_info[(payment_hash, chan.short_channel_id, htlc_id)] + amount_receiver_msat = shi.amount_receiver_msat + route = shi.route if error_bytes: # TODO "decode_onion_error" might raise, catch and maybe blacklist/penalise someone? try: @@ -2140,19 +2173,19 @@ class LNWallet(LNWorker): self.logger.info(f"htlc_failed {failure_message}") # check sent_buckets if we use trampoline - payment_key = payment_hash + payment_secret - if self.uses_trampoline() and payment_key in self.sent_buckets: - amount_sent, amount_failed = self.sent_buckets[payment_key] + payment_bkey = payment_hash + shi.payment_secret_bucket + if self.uses_trampoline() and payment_bkey in self.sent_buckets: + amount_sent, amount_failed = self.sent_buckets[payment_bkey] amount_failed += amount_receiver_msat - self.sent_buckets[payment_key] = amount_sent, amount_failed + self.sent_buckets[payment_bkey] = amount_sent, amount_failed if amount_sent != amount_failed: self.logger.info('bucket still active...') return self.logger.info('bucket failed') amount_receiver_msat = amount_sent - if trampoline_route: - route = trampoline_route + if shi.trampoline_route: + route = shi.trampoline_route htlc_log = HtlcLog( success=False, route=route, @@ -2160,7 +2193,7 @@ class LNWallet(LNWorker): error_bytes=error_bytes, failure_msg=failure_message, sender_idx=sender_idx, - trampoline_fee_level=trampoline_fee_level) + trampoline_fee_level=shi.trampoline_fee_level) q.put_nowait(htlc_log) else: self.logger.info(f"received unknown htlc_failed, probably from previous session") diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index 8c7ecbeee..0bd463fcc 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -31,7 +31,7 @@ from electrum.lnutil import PaymentFailure, LnFeatures, HTLCOwner from electrum.lnchannel import ChannelState, PeerState, Channel from electrum.lnrouter import LNPathFinder, PathEdge, LNPathInconsistent from electrum.channel_db import ChannelDB -from electrum.lnworker import LNWallet, NoPathFound +from electrum.lnworker import LNWallet, NoPathFound, SentHtlcInfo from electrum.lnmsg import encode_msg, decode_msg from electrum import lnmsg from electrum.logging import console_stderr_handler, Logger @@ -166,7 +166,7 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): self.enable_htlc_settle = True self.enable_htlc_forwarding = True self.received_mpp_htlcs = dict() - self.sent_htlcs = defaultdict(asyncio.Queue) + self.sent_htlcs_q = defaultdict(asyncio.Queue) self.sent_htlcs_info = dict() self.sent_buckets = defaultdict(set) self.trampoline_forwardings = set() @@ -740,7 +740,7 @@ class TestPeer(ElectrumTestCase): with self.assertRaises(SuccessfulTest): await f() - async def _activate_trampoline(self, w): + async def _activate_trampoline(self, w: MockLNWallet): if w.network.channel_db: w.network.channel_db.stop() await w.network.channel_db.stopped_event.wait() @@ -837,38 +837,44 @@ class TestPeer(ElectrumTestCase): lnaddr2, pay_req2 = self.prepare_invoice(w2) lnaddr1, pay_req1 = self.prepare_invoice(w1) # create the htlc queues now (side-effecting defaultdict) - q1 = w1.sent_htlcs[lnaddr2.paymenthash] - q2 = w2.sent_htlcs[lnaddr1.paymenthash] + q1 = w1.sent_htlcs_q[lnaddr2.paymenthash + lnaddr2.payment_secret] + q2 = w2.sent_htlcs_q[lnaddr1.paymenthash + lnaddr1.payment_secret] # alice sends htlc BUT NOT COMMITMENT_SIGNED p1.maybe_send_commitment = lambda x: None - route1 = (await w1.create_routes_from_invoice(lnaddr2.get_amount_msat(), decoded_invoice=lnaddr2))[0][0] - amount_msat = lnaddr2.get_amount_msat() - await w1.pay_to_route( + route1 = (await w1.create_routes_from_invoice(lnaddr2.get_amount_msat(), decoded_invoice=lnaddr2))[0][0].route + shi1 = SentHtlcInfo( route=route1, - amount_msat=amount_msat, - total_msat=amount_msat, - amount_receiver_msat=amount_msat, + payment_secret_orig=lnaddr2.payment_secret, + payment_secret_bucket=lnaddr2.payment_secret, + amount_msat=lnaddr2.get_amount_msat(), + bucket_msat=lnaddr2.get_amount_msat(), + amount_receiver_msat=lnaddr2.get_amount_msat(), + trampoline_fee_level=None, + trampoline_route=None, + ) + await w1.pay_to_route( + sent_htlc_info=shi1, payment_hash=lnaddr2.paymenthash, min_cltv_expiry=lnaddr2.get_min_final_cltv_expiry(), - payment_secret=lnaddr2.payment_secret, - trampoline_fee_level=0, - trampoline_route=None, ) p1.maybe_send_commitment = _maybe_send_commitment1 # bob sends htlc BUT NOT COMMITMENT_SIGNED p2.maybe_send_commitment = lambda x: None - route2 = (await w2.create_routes_from_invoice(lnaddr1.get_amount_msat(), decoded_invoice=lnaddr1))[0][0] - amount_msat = lnaddr1.get_amount_msat() - await w2.pay_to_route( + route2 = (await w2.create_routes_from_invoice(lnaddr1.get_amount_msat(), decoded_invoice=lnaddr1))[0][0].route + shi2 = SentHtlcInfo( route=route2, - amount_msat=amount_msat, - total_msat=amount_msat, - amount_receiver_msat=amount_msat, + payment_secret_orig=lnaddr1.payment_secret, + payment_secret_bucket=lnaddr1.payment_secret, + amount_msat=lnaddr1.get_amount_msat(), + bucket_msat=lnaddr1.get_amount_msat(), + amount_receiver_msat=lnaddr1.get_amount_msat(), + trampoline_fee_level=None, + trampoline_route=None, + ) + await w2.pay_to_route( + sent_htlc_info=shi2, payment_hash=lnaddr1.paymenthash, min_cltv_expiry=lnaddr1.get_min_final_cltv_expiry(), - payment_secret=lnaddr1.payment_secret, - trampoline_fee_level=0, - trampoline_route=None, ) p2.maybe_send_commitment = _maybe_send_commitment2 # sleep a bit so that they both receive msgs sent so far @@ -878,9 +884,9 @@ class TestPeer(ElectrumTestCase): p2.maybe_send_commitment(bob_channel) htlc_log1 = await q1.get() - assert htlc_log1.success + self.assertTrue(htlc_log1.success) htlc_log2 = await q2.get() - assert htlc_log2.success + self.assertTrue(htlc_log2.success) raise PaymentDone() async def f(): @@ -1184,10 +1190,7 @@ class TestPeer(ElectrumTestCase): if not bob_forwarding: graph.workers['bob'].enable_htlc_forwarding = False if alice_uses_trampoline: - if graph.workers['alice'].network.channel_db: - graph.workers['alice'].network.channel_db.stop() - await graph.workers['alice'].network.channel_db.stopped_event.wait() - graph.workers['alice'].network.channel_db = None + await self._activate_trampoline(graph.workers['alice']) else: assert graph.workers['alice'].network.channel_db is not None lnaddr, pay_req = self.prepare_invoice(graph.workers['dave'], include_routing_hints=True, amount_msat=amount_to_pay) @@ -1433,7 +1436,7 @@ class TestPeer(ElectrumTestCase): await util.wait_for2(p1.initialized, 1) await util.wait_for2(p2.initialized, 1) # alice sends htlc - route, amount_msat = (await w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr))[0][0:2] + route = (await w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr))[0][0].route p1.pay(route=route, chan=alice_channel, amount_msat=lnaddr.get_amount_msat(), @@ -1556,7 +1559,8 @@ class TestPeer(ElectrumTestCase): lnaddr, pay_req = self.prepare_invoice(w2) lnaddr = w1._check_invoice(pay_req) - route, amount_msat = (await w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr))[0][0:2] + shi = (await w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr))[0][0] + route, amount_msat = shi.route, shi.amount_msat assert amount_msat == lnaddr.get_amount_msat() await w1.force_close_channel(alice_channel.channel_id) @@ -1570,20 +1574,21 @@ class TestPeer(ElectrumTestCase): # AssertionError is ok since we shouldn't use old routes, and the # route finding should fail when channel is closed async def f(): - min_cltv_expiry = lnaddr.get_min_final_cltv_expiry() - payment_hash = lnaddr.paymenthash - payment_secret = lnaddr.payment_secret - pay = w1.pay_to_route( + shi = SentHtlcInfo( route=route, + payment_secret_orig=lnaddr.payment_secret, + payment_secret_bucket=lnaddr.payment_secret, amount_msat=amount_msat, - total_msat=amount_msat, + bucket_msat=amount_msat, amount_receiver_msat=amount_msat, - payment_hash=payment_hash, - payment_secret=payment_secret, - min_cltv_expiry=min_cltv_expiry, - trampoline_fee_level=0, + trampoline_fee_level=None, trampoline_route=None, ) + pay = w1.pay_to_route( + sent_htlc_info=shi, + payment_hash=lnaddr.paymenthash, + min_cltv_expiry=lnaddr.get_min_final_cltv_expiry(), + ) await asyncio.gather(pay, p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch()) with self.assertRaises(PaymentFailure): await f()