Browse Source

Refactor trampoline forwarding and hold invoices.

- maybe_fulfill_htlc returns a forwarding callback that
   covers both cases.
 - previously, the callback of hold invoices was called as a
   side-effect of lnworker.check_mpp_status.
 - the same data structures (lnworker.trampoline_forwardings,
   lnworker.trampoline_forwarding_errors) are used for both
   trampoline forwardings and hold invoices.
 - maybe_fulfill_htlc still recursively calls itself to perform
   checks on trampoline onion. This is ugly, but ugliness is now
   contained to that method.
master
ThomasV 2 years ago
parent
commit
017186d107
  1. 161
      electrum/lnpeer.py
  2. 26
      electrum/lnworker.py
  3. 4
      electrum/tests/test_lnpeer.py

161
electrum/lnpeer.py

@ -9,7 +9,7 @@ from collections import OrderedDict, defaultdict
import asyncio
import os
import time
from typing import Tuple, Dict, TYPE_CHECKING, Optional, Union, Set
from typing import Tuple, Dict, TYPE_CHECKING, Optional, Union, Set, Callable
from datetime import datetime
import functools
@ -1668,7 +1668,8 @@ class Peer(Logger):
next_peer.maybe_send_commitment(next_chan)
return next_chan_scid, next_htlc.htlc_id
def maybe_forward_trampoline(
@log_exceptions
async def maybe_forward_trampoline(
self, *,
payment_hash: bytes,
cltv_expiry: int,
@ -1713,48 +1714,34 @@ class Peer(Logger):
trampoline_fee = total_msat - amt_to_forward
self.logger.info(f'trampoline cltv and fee: {trampoline_cltv_delta, trampoline_fee}')
@log_exceptions
async def forward_trampoline_payment():
try:
await self.lnworker.pay_to_node(
node_pubkey=outgoing_node_id,
payment_hash=payment_hash,
payment_secret=payment_secret,
amount_to_pay=amt_to_forward,
min_cltv_expiry=cltv_from_onion,
r_tags=[],
invoice_features=invoice_features,
fwd_trampoline_onion=next_trampoline_onion,
fwd_trampoline_fee=trampoline_fee,
fwd_trampoline_cltv_delta=trampoline_cltv_delta,
attempts=1)
except OnionRoutingFailure as e:
# FIXME: cannot use payment_hash as key
self.lnworker.trampoline_forwarding_failures[payment_hash] = e
except PaymentFailure as e:
# FIXME: adapt the error code
error_reason = OnionRoutingFailure(code=OnionFailureCode.UNKNOWN_NEXT_PEER, data=b'')
self.lnworker.trampoline_forwarding_failures[payment_hash] = error_reason
# remove from list of payments, so that another attempt can be initiated
self.lnworker.trampoline_forwardings.remove(payment_hash)
# add to list of ongoing payments
self.lnworker.trampoline_forwardings.add(payment_hash)
# clear previous failures
self.lnworker.trampoline_forwarding_failures.pop(payment_hash, None)
# start payment
asyncio.ensure_future(forward_trampoline_payment())
try:
await self.lnworker.pay_to_node(
node_pubkey=outgoing_node_id,
payment_hash=payment_hash,
payment_secret=payment_secret,
amount_to_pay=amt_to_forward,
min_cltv_expiry=cltv_from_onion,
r_tags=[],
invoice_features=invoice_features,
fwd_trampoline_onion=next_trampoline_onion,
fwd_trampoline_fee=trampoline_fee,
fwd_trampoline_cltv_delta=trampoline_cltv_delta,
attempts=1)
except OnionRoutingFailure as e:
raise
except PaymentFailure as e:
# FIXME: adapt the error code
raise OnionRoutingFailure(code=OnionFailureCode.UNKNOWN_NEXT_PEER, data=b'')
def maybe_fulfill_htlc(
self, *,
chan: Channel,
htlc: UpdateAddHtlc,
processed_onion: ProcessedOnionPacket,
is_trampoline: bool = False) -> Optional[bytes]:
onion_packet_bytes: bytes,
is_trampoline: bool = False) -> Tuple[Optional[bytes], Optional[Callable]]:
"""As a final recipient of an HTLC, decide if we should fulfill it.
Return preimage or None
Return (preimage, forwarding_callback) with at most a single element not None
"""
def log_fail_reason(reason: str):
self.logger.info(f"maybe_fulfill_htlc. will FAIL HTLC: chan {chan.short_channel_id}. "
@ -1810,19 +1797,55 @@ class Peer(Logger):
log_fail_reason(f"'payment_secret' missing from onion")
raise exc_incorrect_or_unknown_pd
payment_status = self.lnworker.check_received_htlc(payment_secret_from_onion, chan.short_channel_id, htlc, total_msat)
payment_status = self.lnworker.check_mpp_status(payment_secret_from_onion, chan.short_channel_id, htlc, total_msat)
if payment_status is None:
return None
return None, None
elif payment_status is False:
log_fail_reason(f"MPP_TIMEOUT")
raise OnionRoutingFailure(code=OnionFailureCode.MPP_TIMEOUT, data=b'')
else:
assert payment_status is True
payment_hash = htlc.payment_hash
preimage = self.lnworker.get_preimage(payment_hash)
hold_invoice_callback = self.lnworker.hold_invoice_callbacks.get(payment_hash)
if not preimage and hold_invoice_callback:
if preimage:
return preimage, None
else:
# for hold invoices, trigger callback
cb, timeout = hold_invoice_callback
if int(time.time()) < timeout:
return None, lambda: cb(payment_hash)
else:
raise OnionRoutingFailure(code=OnionFailureCode.MPP_TIMEOUT, data=b'')
# if there is a trampoline_onion, maybe_fulfill_htlc will be called again
if processed_onion.trampoline_onion_packet:
# TODO: we should check that all trampoline_onions are the same
return None
trampoline_onion = self.process_onion_packet(
processed_onion.trampoline_onion_packet,
payment_hash=payment_hash,
onion_packet_bytes=onion_packet_bytes,
is_trampoline=True)
if trampoline_onion.are_we_final:
# trampoline- we are final recipient of HTLC
preimage, cb = self.maybe_fulfill_htlc(
chan=chan,
htlc=htlc,
processed_onion=trampoline_onion,
onion_packet_bytes=onion_packet_bytes,
is_trampoline=True)
assert cb is None
return preimage, None
else:
callback = lambda: self.maybe_forward_trampoline(
payment_hash=payment_hash,
cltv_expiry=htlc.cltv_expiry, # TODO: use max or enforce same value across mpp parts
outer_onion=processed_onion,
trampoline_onion=trampoline_onion)
return None, callback
# TODO don't accept payments twice for same invoice
# TODO check invoice expiry
@ -1845,7 +1868,7 @@ class Peer(Logger):
if preimage:
self.logger.info(f"maybe_fulfill_htlc. will FULFILL HTLC: chan {chan.short_channel_id}. htlc={str(htlc)}")
self.lnworker.set_request_status(htlc.payment_hash, PR_PAID)
return preimage
return preimage, None
def fulfill_htlc(self, chan: Channel, htlc_id: int, preimage: bytes):
self.logger.info(f"_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}")
@ -2340,42 +2363,36 @@ class Peer(Logger):
onion_packet_bytes=onion_packet_bytes)
if processed_onion.are_we_final:
# either we are final recipient; or if trampoline, see cases below
preimage = self.maybe_fulfill_htlc(
preimage, forwarding_callback = self.maybe_fulfill_htlc(
chan=chan,
htlc=htlc,
processed_onion=processed_onion)
processed_onion=processed_onion,
onion_packet_bytes=onion_packet_bytes)
if processed_onion.trampoline_onion_packet:
# trampoline- recipient or forwarding
if forwarding_callback:
if not forwarding_info:
trampoline_onion = self.process_onion_packet(
processed_onion.trampoline_onion_packet,
payment_hash=payment_hash,
onion_packet_bytes=onion_packet_bytes,
is_trampoline=True)
if trampoline_onion.are_we_final:
# trampoline- we are final recipient of HTLC
preimage = self.maybe_fulfill_htlc(
chan=chan,
htlc=htlc,
processed_onion=trampoline_onion,
is_trampoline=True)
# trampoline- HTLC we are supposed to forward, but haven't forwarded yet
if not self.lnworker.enable_htlc_forwarding:
pass
elif payment_hash in self.lnworker.trampoline_forwardings:
# we are already forwarding this payment
self.logger.info(f"we are already forwarding this.")
else:
# trampoline- HTLC we are supposed to forward, but haven't forwarded yet
if not self.lnworker.enable_htlc_forwarding:
return None, None, None
if payment_hash in self.lnworker.trampoline_forwardings:
self.logger.info(f"we are already forwarding this.")
# we are already forwarding this payment
return None, True, None
self.maybe_forward_trampoline(
payment_hash=payment_hash,
cltv_expiry=htlc.cltv_expiry, # TODO: use max or enforce same value across mpp parts
outer_onion=processed_onion,
trampoline_onion=trampoline_onion)
# return True so that this code gets executed only once
# add to list of ongoing payments
self.lnworker.trampoline_forwardings.add(payment_hash)
# clear previous failures
self.lnworker.trampoline_forwarding_failures.pop(payment_hash, None)
async def wrapped_callback():
forwarding_coro = forwarding_callback()
try:
await forwarding_coro
except Exception as e:
# FIXME: cannot use payment_hash as key
self.lnworker.trampoline_forwarding_failures[payment_hash] = e
finally:
# remove from list of payments, so that another attempt can be initiated
self.lnworker.trampoline_forwardings.remove(payment_hash)
asyncio.ensure_future(wrapped_callback())
return None, True, None
else:
# trampoline- HTLC we are supposed to forward, and have already forwarded

26
electrum/lnworker.py

@ -1922,16 +1922,16 @@ class LNWallet(LNWorker):
if write_to_disk:
self.wallet.save_db()
def check_received_htlc(
self, payment_secret: bytes,
short_channel_id: ShortChannelID,
htlc: UpdateAddHtlc,
expected_msat: int,
def check_mpp_status(
self, payment_secret: bytes,
short_channel_id: ShortChannelID,
htlc: UpdateAddHtlc,
expected_msat: int,
) -> Optional[bool]:
""" return MPP status: True (accepted), False (expired) or None (waiting)
"""
payment_hash = htlc.payment_hash
self.update_mpp_with_received_htlc(payment_secret, short_channel_id, htlc, expected_msat)
is_expired, is_accepted = self.get_mpp_status(payment_secret)
if not is_accepted and not is_expired:
@ -1944,19 +1944,7 @@ class LNWallet(LNWorker):
elif self.stopping_soon:
is_expired = True # try to time out pending HTLCs before shutting down
elif all([self.is_mpp_amount_reached(x) for x in payment_secrets]):
preimage = self.get_preimage(payment_hash)
hold_invoice_callback = self.hold_invoice_callbacks.get(payment_hash)
if not preimage and hold_invoice_callback:
# for hold invoices, trigger callback
cb, timeout = hold_invoice_callback
if int(time.time()) < timeout:
cb(payment_hash)
else:
is_expired = True
else:
# note: preimage will be None for outer trampoline onion
is_accepted = True
is_accepted = True
elif time.time() - first_timestamp > self.MPP_EXPIRY:
is_expired = True

4
electrum/tests/test_lnpeer.py

@ -251,7 +251,7 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
set_request_status = LNWallet.set_request_status
set_payment_status = LNWallet.set_payment_status
get_payment_status = LNWallet.get_payment_status
check_received_htlc = LNWallet.check_received_htlc
check_mpp_status = LNWallet.check_mpp_status
htlc_fulfilled = LNWallet.htlc_fulfilled
htlc_failed = LNWallet.htlc_failed
save_preimage = LNWallet.save_preimage
@ -764,7 +764,7 @@ class TestPeer(ElectrumTestCase):
if test_hold_invoice:
payment_hash = lnaddr.paymenthash
preimage = bytes.fromhex(w2.preimages.pop(payment_hash.hex()))
def cb(payment_hash):
async def cb(payment_hash):
if not test_hold_timeout:
w2.save_preimage(payment_hash, preimage)
timeout = 1 if test_hold_timeout else 60

Loading…
Cancel
Save