Browse Source

Trampoline MPP consolidation:

- fix parameters passed to maybe_forward_trampoline
 - use lnworker.trampoline_forwardings as a semaphore for ongoing
   trampoline payments
 - if a trampoline payment fails, fail all received HTLCs
master
ThomasV 2 years ago
parent
commit
e124ff7ee7
  1. 38
      electrum/lnpeer.py
  2. 1
      electrum/lnworker.py
  3. 1
      electrum/tests/test_lnpeer.py

38
electrum/lnpeer.py

@ -1670,8 +1670,9 @@ class Peer(Logger):
def maybe_forward_trampoline(
self, *,
chan: Channel,
htlc: UpdateAddHtlc,
payment_hash: bytes,
cltv_expiry: int,
outer_onion: ProcessedOnionPacket,
trampoline_onion: ProcessedOnionPacket):
forwarding_enabled = self.network.config.EXPERIMENTAL_LN_FORWARD_PAYMENTS
@ -1679,9 +1680,7 @@ class Peer(Logger):
if not (forwarding_enabled and forwarding_trampoline_enabled):
self.logger.info(f"trampoline forwarding is disabled. failing htlc.")
raise OnionRoutingFailure(code=OnionFailureCode.PERMANENT_CHANNEL_FAILURE, data=b'')
payload = trampoline_onion.hop_data.payload
payment_hash = htlc.payment_hash
payment_data = payload.get('payment_data')
if payment_data: # legacy case
payment_secret = payment_data['payment_secret']
@ -1709,8 +1708,10 @@ class Peer(Logger):
# these are the fee/cltv paid by the sender
# pay_to_node will raise if they are not sufficient
trampoline_cltv_delta = htlc.cltv_expiry - cltv_from_onion
trampoline_fee = htlc.amount_msat - amt_to_forward
trampoline_cltv_delta = cltv_expiry - cltv_from_onion
total_msat = outer_onion.hop_data.payload["payment_data"]["total_msat"]
trampoline_fee = total_msat - amt_to_forward
self.logger.info(f'trampoline cltv and fee: {trampoline_cltv_delta, trampoline_fee}')
@log_exceptions
async def forward_trampoline_payment():
@ -1735,6 +1736,14 @@ class Peer(Logger):
error_reason = OnionRoutingFailure(code=OnionFailureCode.UNKNOWN_NEXT_PEER, data=b'')
self.lnworker.trampoline_forwarding_failures[payment_hash] = error_reason
# remove from list of payments, so that another attempt can be initiated
self.lnworker.trampoline_forwardings.remove(payment_hash)
# add to list of ongoing payments
self.lnworker.trampoline_forwardings.add(payment_hash)
# clear previous failures
self.lnworker.trampoline_forwarding_failures.pop(payment_hash, None)
# start payment
asyncio.ensure_future(forward_trampoline_payment())
def maybe_fulfill_htlc(
@ -2335,12 +2344,13 @@ class Peer(Logger):
chan=chan,
htlc=htlc,
processed_onion=processed_onion)
if trampoline_onion_packet:
# trampoline- recipient or forwarding
if not forwarding_info:
trampoline_onion = self.process_onion_packet(
trampoline_onion_packet,
payment_hash=htlc.payment_hash,
payment_hash=payment_hash,
onion_packet_bytes=onion_packet_bytes,
is_trampoline=True)
if trampoline_onion.are_we_final:
@ -2354,16 +2364,24 @@ class Peer(Logger):
# trampoline- HTLC we are supposed to forward, but haven't forwarded yet
if not self.lnworker.enable_htlc_forwarding:
return None, None, None
if payment_hash in self.lnworker.trampoline_forwardings:
self.logger.info(f"we are already forwarding this.")
# we are already forwarding this payment
return None, True, None
self.maybe_forward_trampoline(
chan=chan,
htlc=htlc,
payment_hash=payment_hash,
cltv_expiry=htlc.cltv_expiry, # TODO: use max or enforce same value across mpp parts
outer_onion=processed_onion,
trampoline_onion=trampoline_onion)
# return True so that this code gets executed only once
return None, True, None
else:
# trampoline- HTLC we are supposed to forward, and have already forwarded
preimage = self.lnworker.get_preimage(payment_hash)
error_reason = self.lnworker.trampoline_forwarding_failures.pop(payment_hash, None)
# get (and not pop) failure because the incoming payment might be multi-part
error_reason = self.lnworker.trampoline_forwarding_failures.get(payment_hash)
if error_reason:
self.logger.info(f'trampoline forwarding failure: {error_reason.code_name()}')
raise error_reason

1
electrum/lnworker.py

@ -682,6 +682,7 @@ class LNWallet(LNWorker):
for payment_hash in self.get_payments(status='inflight').keys():
self.set_invoice_status(payment_hash.hex(), PR_INFLIGHT)
self.trampoline_forwardings = set()
self.trampoline_forwarding_failures = {} # todo: should be persisted
# map forwarded htlcs (fw_info=(scid_hex, htlc_id)) to originating peer pubkeys
self.downstream_htlc_to_upstream_peer_map = {} # type: Dict[Tuple[str, int], bytes]

1
electrum/tests/test_lnpeer.py

@ -169,6 +169,7 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
self.sent_htlcs = defaultdict(asyncio.Queue)
self.sent_htlcs_info = dict()
self.sent_buckets = defaultdict(set)
self.trampoline_forwardings = set()
self.trampoline_forwarding_failures = {}
self.inflight_payments = set()
self.preimages = {}

Loading…
Cancel
Save