diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index c0532800c..8c3c2bbb9 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -1670,8 +1670,9 @@ class Peer(Logger): def maybe_forward_trampoline( self, *, - chan: Channel, - htlc: UpdateAddHtlc, + payment_hash: bytes, + cltv_expiry: int, + outer_onion: ProcessedOnionPacket, trampoline_onion: ProcessedOnionPacket): forwarding_enabled = self.network.config.EXPERIMENTAL_LN_FORWARD_PAYMENTS @@ -1679,9 +1680,7 @@ class Peer(Logger): if not (forwarding_enabled and forwarding_trampoline_enabled): self.logger.info(f"trampoline forwarding is disabled. failing htlc.") raise OnionRoutingFailure(code=OnionFailureCode.PERMANENT_CHANNEL_FAILURE, data=b'') - payload = trampoline_onion.hop_data.payload - payment_hash = htlc.payment_hash payment_data = payload.get('payment_data') if payment_data: # legacy case payment_secret = payment_data['payment_secret'] @@ -1709,8 +1708,10 @@ class Peer(Logger): # these are the fee/cltv paid by the sender # pay_to_node will raise if they are not sufficient - trampoline_cltv_delta = htlc.cltv_expiry - cltv_from_onion - trampoline_fee = htlc.amount_msat - amt_to_forward + trampoline_cltv_delta = cltv_expiry - cltv_from_onion + total_msat = outer_onion.hop_data.payload["payment_data"]["total_msat"] + trampoline_fee = total_msat - amt_to_forward + self.logger.info(f'trampoline cltv and fee: {trampoline_cltv_delta, trampoline_fee}') @log_exceptions async def forward_trampoline_payment(): @@ -1735,6 +1736,14 @@ class Peer(Logger): error_reason = OnionRoutingFailure(code=OnionFailureCode.UNKNOWN_NEXT_PEER, data=b'') self.lnworker.trampoline_forwarding_failures[payment_hash] = error_reason + # remove from list of payments, so that another attempt can be initiated + self.lnworker.trampoline_forwardings.remove(payment_hash) + + # add to list of ongoing payments + self.lnworker.trampoline_forwardings.add(payment_hash) + # clear previous failures + self.lnworker.trampoline_forwarding_failures.pop(payment_hash, None) + # start payment asyncio.ensure_future(forward_trampoline_payment()) def maybe_fulfill_htlc( @@ -2335,12 +2344,13 @@ class Peer(Logger): chan=chan, htlc=htlc, processed_onion=processed_onion) + if trampoline_onion_packet: # trampoline- recipient or forwarding if not forwarding_info: trampoline_onion = self.process_onion_packet( trampoline_onion_packet, - payment_hash=htlc.payment_hash, + payment_hash=payment_hash, onion_packet_bytes=onion_packet_bytes, is_trampoline=True) if trampoline_onion.are_we_final: @@ -2354,16 +2364,24 @@ class Peer(Logger): # trampoline- HTLC we are supposed to forward, but haven't forwarded yet if not self.lnworker.enable_htlc_forwarding: return None, None, None + + if payment_hash in self.lnworker.trampoline_forwardings: + self.logger.info(f"we are already forwarding this.") + # we are already forwarding this payment + return None, True, None + self.maybe_forward_trampoline( - chan=chan, - htlc=htlc, + payment_hash=payment_hash, + cltv_expiry=htlc.cltv_expiry, # TODO: use max or enforce same value across mpp parts + outer_onion=processed_onion, trampoline_onion=trampoline_onion) # return True so that this code gets executed only once return None, True, None else: # trampoline- HTLC we are supposed to forward, and have already forwarded preimage = self.lnworker.get_preimage(payment_hash) - error_reason = self.lnworker.trampoline_forwarding_failures.pop(payment_hash, None) + # get (and not pop) failure because the incoming payment might be multi-part + error_reason = self.lnworker.trampoline_forwarding_failures.get(payment_hash) if error_reason: self.logger.info(f'trampoline forwarding failure: {error_reason.code_name()}') raise error_reason diff --git a/electrum/lnworker.py b/electrum/lnworker.py index b021e05cf..fcb142b72 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -682,6 +682,7 @@ class LNWallet(LNWorker): for payment_hash in self.get_payments(status='inflight').keys(): self.set_invoice_status(payment_hash.hex(), PR_INFLIGHT) + self.trampoline_forwardings = set() self.trampoline_forwarding_failures = {} # todo: should be persisted # map forwarded htlcs (fw_info=(scid_hex, htlc_id)) to originating peer pubkeys self.downstream_htlc_to_upstream_peer_map = {} # type: Dict[Tuple[str, int], bytes] diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index 490be6f9c..86c1c0e42 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -169,6 +169,7 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): self.sent_htlcs = defaultdict(asyncio.Queue) self.sent_htlcs_info = dict() self.sent_buckets = defaultdict(set) + self.trampoline_forwardings = set() self.trampoline_forwarding_failures = {} self.inflight_payments = set() self.preimages = {}