Browse Source

Refactor payment forwarding:

- all forwarding types use the same flow
 - forwarding callback returns a htlc_key or None
 - forwarding info is persisted in lnworker:
   - ongoing_forwardings
   - downstream to upstream htlc_key
   - htlc_key -> error_bytes
master
ThomasV 2 years ago
parent
commit
9b1c40e396
  1. 19
      electrum/lnchannel.py
  2. 225
      electrum/lnpeer.py
  3. 7
      electrum/lnutil.py
  4. 86
      electrum/lnworker.py
  5. 4
      electrum/submarine_swaps.py
  6. 14
      electrum/tests/test_lnpeer.py

19
electrum/lnchannel.py

@ -657,7 +657,6 @@ class Channel(AbstractChannel):
self.onion_keys = state['onion_keys'] # type: Dict[int, bytes] self.onion_keys = state['onion_keys'] # type: Dict[int, bytes]
self.data_loss_protect_remote_pcp = state['data_loss_protect_remote_pcp'] self.data_loss_protect_remote_pcp = state['data_loss_protect_remote_pcp']
self.hm = HTLCManager(log=state['log'], initial_feerate=initial_feerate) self.hm = HTLCManager(log=state['log'], initial_feerate=initial_feerate)
self.fail_htlc_reasons = state["fail_htlc_reasons"]
self.unfulfilled_htlcs = state["unfulfilled_htlcs"] self.unfulfilled_htlcs = state["unfulfilled_htlcs"]
self._state = ChannelState[state['state']] self._state = ChannelState[state['state']]
self.peer_state = PeerState.DISCONNECTED self.peer_state = PeerState.DISCONNECTED
@ -1222,26 +1221,8 @@ class Channel(AbstractChannel):
error_bytes, failure_message = self._receive_fail_reasons.pop(htlc.htlc_id) error_bytes, failure_message = self._receive_fail_reasons.pop(htlc.htlc_id)
except KeyError: except KeyError:
error_bytes, failure_message = None, None error_bytes, failure_message = None, None
# if we are forwarding, save error message to disk
if self.lnworker.get_payment_info(htlc.payment_hash) is None:
self.save_fail_htlc_reason(htlc.htlc_id, error_bytes, failure_message)
self.lnworker.htlc_failed(self, htlc.payment_hash, htlc.htlc_id, error_bytes, failure_message) self.lnworker.htlc_failed(self, htlc.payment_hash, htlc.htlc_id, error_bytes, failure_message)
def save_fail_htlc_reason(
self,
htlc_id: int,
error_bytes: Optional[bytes],
failure_message: Optional['OnionRoutingFailure']):
error_hex = error_bytes.hex() if error_bytes else None
failure_hex = failure_message.to_bytes().hex() if failure_message else None
self.fail_htlc_reasons[htlc_id] = (error_hex, failure_hex)
def pop_fail_htlc_reason(self, htlc_id):
error_hex, failure_hex = self.fail_htlc_reasons.pop(htlc_id, (None, None))
error_bytes = bytes.fromhex(error_hex) if error_hex else None
failure_message = OnionRoutingFailure.from_bytes(bytes.fromhex(failure_hex)) if failure_hex else None
return error_bytes, failure_message
def extract_preimage_from_htlc_txin(self, txin: TxInput) -> None: def extract_preimage_from_htlc_txin(self, txin: TxInput) -> None:
witness = txin.witness_elements() witness = txin.witness_elements()
if len(witness) == 5: # HTLC success tx if len(witness) == 5: # HTLC success tx

225
electrum/lnpeer.py

@ -45,6 +45,7 @@ from .lnutil import (Outpoint, LocalConfig, RECEIVED, UpdateAddHtlc, ChannelConf
IncompatibleLightningFeatures, derive_payment_secret_from_payment_preimage, IncompatibleLightningFeatures, derive_payment_secret_from_payment_preimage,
ChannelType, LNProtocolWarning, validate_features, IncompatibleOrInsaneFeatures) ChannelType, LNProtocolWarning, validate_features, IncompatibleOrInsaneFeatures)
from .lnutil import FeeUpdate, channel_id_from_funding_tx, PaymentFeeBudget from .lnutil import FeeUpdate, channel_id_from_funding_tx, PaymentFeeBudget
from .lnutil import serialize_htlc_key
from .lntransport import LNTransport, LNTransportBase from .lntransport import LNTransport, LNTransportBase
from .lnmsg import encode_msg, decode_msg, UnknownOptionalMsgType, FailedToParseMsg from .lnmsg import encode_msg, decode_msg, UnknownOptionalMsgType, FailedToParseMsg
from .interface import GracefulDisconnect from .interface import GracefulDisconnect
@ -120,7 +121,6 @@ class Peer(Logger):
self._received_revack_event = asyncio.Event() self._received_revack_event = asyncio.Event()
self.received_commitsig_event = asyncio.Event() self.received_commitsig_event = asyncio.Event()
self.downstream_htlc_resolved_event = asyncio.Event() self.downstream_htlc_resolved_event = asyncio.Event()
self.jit_failures = {}
def send_message(self, message_name: str, **kwargs): def send_message(self, message_name: str, **kwargs):
assert util.get_running_loop() == util.get_asyncio_loop(), f"this must be run on the asyncio thread!" assert util.get_running_loop() == util.get_asyncio_loop(), f"this must be run on the asyncio thread!"
@ -1688,7 +1688,8 @@ class Peer(Logger):
chan.receive_htlc(htlc, onion_packet) chan.receive_htlc(htlc, onion_packet)
util.trigger_callback('htlc_added', chan, htlc, RECEIVED) util.trigger_callback('htlc_added', chan, htlc, RECEIVED)
def maybe_forward_htlc(
async def maybe_forward_htlc(
self, *, self, *,
incoming_chan: Channel, incoming_chan: Channel,
htlc: UpdateAddHtlc, htlc: UpdateAddHtlc,
@ -1742,20 +1743,12 @@ class Peer(Logger):
if next_chan.can_pay(next_amount_msat_htlc): if next_chan.can_pay(next_amount_msat_htlc):
break break
else: else:
async def wrapped_callback(): return await self.lnworker.open_channel_just_in_time(
coro = self.lnworker.open_channel_just_in_time( next_peer,
next_peer, next_amount_msat_htlc,
next_amount_msat_htlc, next_cltv_abs,
next_cltv_abs, htlc.payment_hash,
htlc.payment_hash, processed_onion.next_packet)
processed_onion.next_packet)
try:
await coro
except OnionRoutingFailure as e:
self.jit_failures[next_chan_scid.hex()] = e
asyncio.ensure_future(wrapped_callback())
return next_chan_scid, -1
local_height = chain.height() local_height = chain.height()
if next_chan is None: if next_chan is None:
@ -1812,7 +1805,8 @@ class Peer(Logger):
except BaseException as e: except BaseException as e:
log_fail_reason(f"error sending message to next_peer={next_chan.node_id.hex()}") log_fail_reason(f"error sending message to next_peer={next_chan.node_id.hex()}")
raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, data=outgoing_chan_upd_message) raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_CHANNEL_FAILURE, data=outgoing_chan_upd_message)
return next_chan_scid, next_htlc.htlc_id htlc_key = serialize_htlc_key(next_chan.get_scid_or_local_alias(), next_htlc.htlc_id)
return htlc_key
@log_exceptions @log_exceptions
async def maybe_forward_trampoline( async def maybe_forward_trampoline(
@ -1961,10 +1955,25 @@ class Peer(Logger):
htlc: UpdateAddHtlc, htlc: UpdateAddHtlc,
processed_onion: ProcessedOnionPacket, processed_onion: ProcessedOnionPacket,
onion_packet_bytes: bytes, onion_packet_bytes: bytes,
already_forwarded = False,
) -> Tuple[Optional[bytes], Optional[Callable]]: ) -> Tuple[Optional[bytes], Optional[Callable]]:
"""As a final recipient of an HTLC, decide if we should fulfill it.
Return (preimage, forwarding_callback) with at most a single element not None
""" """
Decide what to do with an HTLC: return preimage if it can be fulfilled, forwarding callback if it can be forwarded.
Return (payment_key, preimage, callback) with at most a single element of the last two not None
Side effect: populates lnworker.received_mpp (which is not persisted, needs to be re-populated after restart)
"""
if not processed_onion.are_we_final:
if not self.lnworker.enable_htlc_forwarding:
return None, None, None
# use the htlc key if we are forwarding
payment_key = serialize_htlc_key(chan.get_scid_or_local_alias(), htlc.htlc_id)
callback = lambda: self.maybe_forward_htlc(
incoming_chan=chan,
htlc=htlc,
processed_onion=processed_onion)
return payment_key, None, callback
def log_fail_reason(reason: str): def log_fail_reason(reason: str):
self.logger.info( self.logger.info(
f"maybe_fulfill_htlc. will FAIL HTLC: chan {chan.short_channel_id}. " f"maybe_fulfill_htlc. will FAIL HTLC: chan {chan.short_channel_id}. "
@ -1986,9 +1995,6 @@ class Peer(Logger):
exc_incorrect_or_unknown_pd = OnionRoutingFailure( exc_incorrect_or_unknown_pd = OnionRoutingFailure(
code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS,
data=amt_to_forward.to_bytes(8, byteorder="big") + local_height.to_bytes(4, byteorder="big")) data=amt_to_forward.to_bytes(8, byteorder="big") + local_height.to_bytes(4, byteorder="big"))
if local_height + MIN_FINAL_CLTV_DELTA_ACCEPTED > htlc.cltv_abs:
log_fail_reason(f"htlc.cltv_abs is unreasonably close")
raise exc_incorrect_or_unknown_pd
try: try:
cltv_abs_from_onion = processed_onion.hop_data.payload["outgoing_cltv_value"]["outgoing_cltv_value"] cltv_abs_from_onion = processed_onion.hop_data.payload["outgoing_cltv_value"]["outgoing_cltv_value"]
except Exception: except Exception:
@ -2025,15 +2031,19 @@ class Peer(Logger):
log_fail_reason(f"'payment_secret' missing from onion") log_fail_reason(f"'payment_secret' missing from onion")
raise exc_incorrect_or_unknown_pd raise exc_incorrect_or_unknown_pd
# payment key for final onions
payment_hash = htlc.payment_hash
payment_key = (payment_hash + payment_secret_from_onion).hex()
from .lnworker import RecvMPPResolution from .lnworker import RecvMPPResolution
mpp_resolution = self.lnworker.check_mpp_status( mpp_resolution = self.lnworker.check_mpp_status(
payment_secret=payment_secret_from_onion, payment_secret=payment_secret_from_onion,
short_channel_id=chan.short_channel_id, short_channel_id=chan.get_scid_or_local_alias(),
htlc=htlc, htlc=htlc,
expected_msat=total_msat, expected_msat=total_msat,
) )
if mpp_resolution == RecvMPPResolution.WAITING: if mpp_resolution == RecvMPPResolution.WAITING:
return None, None return payment_key, None, None
elif mpp_resolution == RecvMPPResolution.EXPIRED: elif mpp_resolution == RecvMPPResolution.EXPIRED:
log_fail_reason(f"MPP_TIMEOUT") log_fail_reason(f"MPP_TIMEOUT")
raise OnionRoutingFailure(code=OnionFailureCode.MPP_TIMEOUT, data=b'') raise OnionRoutingFailure(code=OnionFailureCode.MPP_TIMEOUT, data=b'')
@ -2045,7 +2055,12 @@ class Peer(Logger):
else: else:
raise Exception(f"unexpected {mpp_resolution=}") raise Exception(f"unexpected {mpp_resolution=}")
payment_hash = htlc.payment_hash if local_height + MIN_FINAL_CLTV_DELTA_ACCEPTED > htlc.cltv_abs:
if not already_forwarded:
log_fail_reason(f"htlc.cltv_abs is unreasonably close")
raise exc_incorrect_or_unknown_pd
else:
return payment_key, None, None
# detect callback # detect callback
# if there is a trampoline_onion, maybe_fulfill_htlc will be called again # if there is a trampoline_onion, maybe_fulfill_htlc will be called again
@ -2060,23 +2075,19 @@ class Peer(Logger):
is_trampoline=True) is_trampoline=True)
if trampoline_onion.are_we_final: if trampoline_onion.are_we_final:
# trampoline- we are final recipient of HTLC # trampoline- we are final recipient of HTLC
preimage, cb = self.maybe_fulfill_htlc( return self.maybe_fulfill_htlc(
chan=chan, chan=chan,
htlc=htlc, htlc=htlc,
processed_onion=trampoline_onion, processed_onion=trampoline_onion,
onion_packet_bytes=onion_packet_bytes, onion_packet_bytes=onion_packet_bytes,
) )
if preimage:
return preimage, None
else:
return None, cb
else: else:
callback = lambda: self.maybe_forward_trampoline( callback = lambda: self.maybe_forward_trampoline(
payment_hash=payment_hash, payment_hash=payment_hash,
inc_cltv_abs=htlc.cltv_abs, # TODO: use max or enforce same value across mpp parts inc_cltv_abs=htlc.cltv_abs, # TODO: use max or enforce same value across mpp parts
outer_onion=processed_onion, outer_onion=processed_onion,
trampoline_onion=trampoline_onion) trampoline_onion=trampoline_onion)
return None, callback return payment_key, None, callback
# TODO don't accept payments twice for same invoice # TODO don't accept payments twice for same invoice
# TODO check invoice expiry # TODO check invoice expiry
@ -2102,15 +2113,19 @@ class Peer(Logger):
hold_invoice_callback = self.lnworker.hold_invoice_callbacks.get(payment_hash) hold_invoice_callback = self.lnworker.hold_invoice_callbacks.get(payment_hash)
if hold_invoice_callback and not preimage: if hold_invoice_callback and not preimage:
return None, lambda: hold_invoice_callback(payment_hash) callback = lambda: hold_invoice_callback(payment_hash)
return payment_key, None, callback
if not preimage: if not preimage:
self.logger.info(f"missing preimage and no hold invoice callback {payment_hash.hex()}") if not already_forwarded:
raise exc_incorrect_or_unknown_pd log_fail_reason(f"missing preimage and no hold invoice callback {payment_hash.hex()}")
raise exc_incorrect_or_unknown_pd
else:
return payment_key, None, None
chan.opening_fee = None chan.opening_fee = None
self.logger.info(f"maybe_fulfill_htlc. will FULFILL HTLC: chan {chan.short_channel_id}. htlc={str(htlc)}") self.logger.info(f"maybe_fulfill_htlc. will FULFILL HTLC: chan {chan.short_channel_id}. htlc={str(htlc)}")
return preimage, None return payment_key, preimage, None
def fulfill_htlc(self, chan: Channel, htlc_id: int, preimage: bytes): def fulfill_htlc(self, chan: Channel, htlc_id: int, preimage: bytes):
self.logger.info(f"_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}") self.logger.info(f"_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}")
@ -2512,17 +2527,13 @@ class Peer(Logger):
self.maybe_send_commitment(chan) self.maybe_send_commitment(chan)
done = set() done = set()
unfulfilled = chan.unfulfilled_htlcs unfulfilled = chan.unfulfilled_htlcs
for htlc_id, (local_ctn, remote_ctn, onion_packet_hex, forwarding_info) in unfulfilled.items(): for htlc_id, (local_ctn, remote_ctn, onion_packet_hex, forwarding_key) in unfulfilled.items():
if forwarding_info:
forwarding_info = tuple(forwarding_info) # storage converts to list
self.lnworker.downstream_htlc_to_upstream_peer_map[forwarding_info] = self.pubkey
if not chan.hm.is_htlc_irrevocably_added_yet(htlc_proposer=REMOTE, htlc_id=htlc_id): if not chan.hm.is_htlc_irrevocably_added_yet(htlc_proposer=REMOTE, htlc_id=htlc_id):
continue continue
htlc = chan.hm.get_htlc_by_id(REMOTE, htlc_id) htlc = chan.hm.get_htlc_by_id(REMOTE, htlc_id)
error_reason = None # type: Optional[OnionRoutingFailure] error_reason = None # type: Optional[OnionRoutingFailure]
error_bytes = None # type: Optional[bytes] error_bytes = None # type: Optional[bytes]
preimage = None preimage = None
fw_info = None
onion_packet_bytes = bytes.fromhex(onion_packet_hex) onion_packet_bytes = bytes.fromhex(onion_packet_hex)
onion_packet = None onion_packet = None
try: try:
@ -2531,20 +2542,21 @@ class Peer(Logger):
error_reason = e error_reason = e
else: else:
try: try:
preimage, fw_info, error_bytes = self.process_unfulfilled_htlc( preimage, _forwarding_key, error_bytes = self.process_unfulfilled_htlc(
chan=chan, chan=chan,
htlc=htlc, htlc=htlc,
forwarding_info=forwarding_info, forwarding_key=forwarding_key,
onion_packet_bytes=onion_packet_bytes, onion_packet_bytes=onion_packet_bytes,
onion_packet=onion_packet) onion_packet=onion_packet)
if _forwarding_key:
assert forwarding_key is False
unfulfilled[htlc_id] = local_ctn, remote_ctn, onion_packet_hex, _forwarding_key
except OnionRoutingFailure as e: except OnionRoutingFailure as e:
error_bytes = construct_onion_error(e, onion_packet.public_key, our_onion_private_key=self.privkey) error_bytes = construct_onion_error(e, onion_packet.public_key, our_onion_private_key=self.privkey)
if error_bytes: if error_bytes:
error_bytes = obfuscate_onion_error(error_bytes, onion_packet.public_key, our_onion_private_key=self.privkey) error_bytes = obfuscate_onion_error(error_bytes, onion_packet.public_key, our_onion_private_key=self.privkey)
if fw_info:
unfulfilled[htlc_id] = local_ctn, remote_ctn, onion_packet_hex, fw_info if preimage or error_reason or error_bytes:
self.lnworker.downstream_htlc_to_upstream_peer_map[fw_info] = self.pubkey
elif preimage or error_reason or error_bytes:
if preimage: if preimage:
self.lnworker.set_request_status(htlc.payment_hash, PR_PAID) self.lnworker.set_request_status(htlc.payment_hash, PR_PAID)
if not self.lnworker.enable_htlc_settle: if not self.lnworker.enable_htlc_settle:
@ -2563,10 +2575,7 @@ class Peer(Logger):
done.add(htlc_id) done.add(htlc_id)
# cleanup # cleanup
for htlc_id in done: for htlc_id in done:
local_ctn, remote_ctn, onion_packet_hex, forwarding_info = unfulfilled.pop(htlc_id) unfulfilled.pop(htlc_id)
if forwarding_info:
forwarding_info = tuple(forwarding_info) # storage converts to list
self.lnworker.downstream_htlc_to_upstream_peer_map.pop(forwarding_info, None)
self.maybe_send_commitment(chan) self.maybe_send_commitment(chan)
def _maybe_cleanup_received_htlcs_pending_removal(self) -> None: def _maybe_cleanup_received_htlcs_pending_removal(self) -> None:
@ -2596,11 +2605,11 @@ class Peer(Logger):
self, *, self, *,
chan: Channel, chan: Channel,
htlc: UpdateAddHtlc, htlc: UpdateAddHtlc,
forwarding_info: Tuple[str, int], forwarding_key: str,
onion_packet_bytes: bytes, onion_packet_bytes: bytes,
onion_packet: OnionPacket) -> Tuple[Optional[bytes], Union[bool, None, Tuple[str, int]], Optional[bytes]]: onion_packet: OnionPacket) -> Tuple[Optional[bytes], Union[bool, None, Tuple[str, int]], Optional[bytes]]:
""" """
return (preimage, fw_info, error_bytes) with at most a single element that is not None return (preimage, payment_key, error_bytes) with at most a single element that is not None
raise an OnionRoutingFailure if we need to fail the htlc raise an OnionRoutingFailure if we need to fail the htlc
""" """
payment_hash = htlc.payment_hash payment_hash = htlc.payment_hash
@ -2608,80 +2617,50 @@ class Peer(Logger):
onion_packet, onion_packet,
payment_hash=payment_hash, payment_hash=payment_hash,
onion_packet_bytes=onion_packet_bytes) onion_packet_bytes=onion_packet_bytes)
if processed_onion.are_we_final:
# either we are final recipient; or if trampoline, see cases below htlc_key = serialize_htlc_key(chan.get_scid_or_local_alias(), htlc.htlc_id)
if not forwarding_info: error_bytes = error_reason = None
preimage, forwarding_callback = self.maybe_fulfill_htlc(
chan=chan, # fixme: do we need the outer key?
htlc=htlc, payment_key, preimage, forwarding_callback = self.maybe_fulfill_htlc(
processed_onion=processed_onion, chan=chan,
onion_packet_bytes=onion_packet_bytes) htlc=htlc,
if forwarding_callback: processed_onion=processed_onion,
payment_secret = processed_onion.hop_data.payload["payment_data"]["payment_secret"] onion_packet_bytes=onion_packet_bytes,
payment_key = payment_hash + payment_secret already_forwarded=bool(forwarding_key))
# trampoline- HTLC we are supposed to forward, but haven't forwarded yet
if not self.lnworker.enable_htlc_forwarding: if not forwarding_key:
return None, None, None if forwarding_callback:
elif payment_key in self.lnworker.final_onion_forwardings: # HTLC we are supposed to forward, but haven't forwarded yet
# we are already forwarding this payment if not self.lnworker.enable_htlc_forwarding:
self.logger.info(f"we are already forwarding this.") return None, None, None
else: if payment_key not in self.lnworker.active_forwardings:
# add to list of ongoing payments async def wrapped_callback():
self.lnworker.final_onion_forwardings.add(payment_key) forwarding_coro = forwarding_callback()
# clear previous failures try:
self.lnworker.final_onion_forwarding_failures.pop(payment_key, None) next_htlc = await forwarding_coro
async def wrapped_callback(): if next_htlc:
forwarding_coro = forwarding_callback() self.lnworker.downstream_to_upstream_htlc[next_htlc] = htlc_key
try: except OnionRoutingFailure as e:
await forwarding_coro self.lnworker.save_forwarding_failure(payment_key, failure_message=e)
except OnionRoutingFailure as e: # add to list
self.lnworker.final_onion_forwarding_failures[payment_key] = e self.lnworker.active_forwardings[payment_key] = True
finally: fut = asyncio.ensure_future(wrapped_callback())
# remove from list of payments, so that another attempt can be initiated # return payment_key so this branch will not be executed again
self.lnworker.final_onion_forwardings.remove(payment_key) return None, payment_key, None
asyncio.ensure_future(wrapped_callback())
# return fw_info so that maybe_fulfill_htlc will not be called again
fw_info = (payment_key.hex(), -1)
return None, fw_info, None
else:
# trampoline- HTLC we are supposed to forward, and have already forwarded
payment_key_outer_onion = bytes.fromhex(forwarding_info[0])
preimage = self.lnworker.get_preimage(payment_hash)
payment_secret_inner_onion = self.lnworker.get_payment_secret(payment_hash)
payment_key_inner_onion = payment_hash + payment_secret_inner_onion
for payment_key in [payment_key_inner_onion, payment_key_outer_onion]:
# get (and not pop) failure because the incoming payment might be multi-part
error_reason = self.lnworker.final_onion_forwarding_failures.get(payment_key)
if error_reason:
self.logger.info(f'trampoline forwarding failure: {error_reason.code_name()}')
raise error_reason
elif not forwarding_info:
# HTLC we are supposed to forward, but haven't forwarded yet
if not self.lnworker.enable_htlc_forwarding:
return None, None, None
next_chan_id, next_htlc_id = self.maybe_forward_htlc(
incoming_chan=chan,
htlc=htlc,
processed_onion=processed_onion)
fw_info = (next_chan_id.hex(), next_htlc_id)
return None, fw_info, None
else: else:
assert payment_key == forwarding_key
# HTLC we are supposed to forward, and have already forwarded # HTLC we are supposed to forward, and have already forwarded
preimage = self.lnworker.get_preimage(payment_hash) preimage = self.lnworker.get_preimage(payment_hash)
next_chan_id_hex, htlc_id = forwarding_info error_bytes, error_reason = self.lnworker.get_forwarding_failure(payment_key)
next_chan = self.lnworker.get_channel_by_short_id(bytes.fromhex(next_chan_id_hex))
if next_chan: if error_bytes or error_reason or preimage:
error_bytes, error_reason = next_chan.pop_fail_htlc_reason(htlc_id) self.lnworker.maybe_cleanup_forwarding(payment_key, chan.get_scid_or_local_alias(), htlc)
if error_bytes:
return None, None, error_bytes if error_bytes:
if error_reason: return None, None, error_bytes
raise error_reason if error_reason:
# just-in-time channel raise error_reason
if htlc_id == -1:
error_reason = self.jit_failures.pop(next_chan_id_hex, None)
if error_reason:
raise error_reason
if preimage: if preimage:
return preimage, None, None return preimage, None, None
return None, None, None return None, None, None

7
electrum/lnutil.py

@ -66,6 +66,13 @@ hex_to_bytes = lambda v: v if isinstance(v, bytes) else bytes.fromhex(v) if v is
json_to_keypair = lambda v: v if isinstance(v, OnlyPubkeyKeypair) else Keypair(**v) if len(v)==2 else OnlyPubkeyKeypair(**v) json_to_keypair = lambda v: v if isinstance(v, OnlyPubkeyKeypair) else Keypair(**v) if len(v)==2 else OnlyPubkeyKeypair(**v)
def serialize_htlc_key(scid:bytes, htlc_id: int):
return scid.hex() + ':%d'%htlc_id
def deserialize_htlc_key(htlc_key:str):
scid, htlc_id = htlc_key.split(':')
return bytes.fromhex(scid), int(htlc_id)
@attr.s @attr.s
class OnlyPubkeyKeypair(StoredObject): class OnlyPubkeyKeypair(StoredObject):
pubkey = attr.ib(type=bytes, converter=hex_to_bytes) pubkey = attr.ib(type=bytes, converter=hex_to_bytes)

86
electrum/lnworker.py

@ -56,6 +56,7 @@ from .lnchannel import ChannelState, PeerState, HTLCWithStatus
from .lnrater import LNRater from .lnrater import LNRater
from . import lnutil from . import lnutil
from .lnutil import funding_output_script from .lnutil import funding_output_script
from .lnutil import serialize_htlc_key, deserialize_htlc_key
from .bitcoin import redeem_script_to_address, DummyAddress from .bitcoin import redeem_script_to_address, DummyAddress
from .lnutil import (Outpoint, LNPeerAddr, from .lnutil import (Outpoint, LNPeerAddr,
get_compressed_pubkey_from_bech32, extract_nodeid, get_compressed_pubkey_from_bech32, extract_nodeid,
@ -840,10 +841,11 @@ class LNWallet(LNWorker):
for payment_hash in self.get_payments(status='inflight').keys(): for payment_hash in self.get_payments(status='inflight').keys():
self.set_invoice_status(payment_hash.hex(), PR_INFLIGHT) self.set_invoice_status(payment_hash.hex(), PR_INFLIGHT)
self.final_onion_forwardings = set() # payment forwarding
self.final_onion_forwarding_failures = {} # todo: should be persisted self.active_forwardings = self.db.get_dict('active_forwardings') # list of payment_keys
# map forwarded htlcs (fw_info=(scid_hex, htlc_id)) to originating peer pubkeys self.forwarding_failures = self.db.get_dict('forwarding_failures') # payment_key -> (error_bytes, error_message)
self.downstream_htlc_to_upstream_peer_map = {} # type: Dict[Tuple[str, int], bytes] self.downstream_to_upstream_htlc = {} # Dict: htlc_key -> htlc_key (not persisted)
# payment_hash -> callback: # payment_hash -> callback:
self.hold_invoice_callbacks = {} # type: Dict[bytes, Callable[[bytes], Awaitable[None]]] self.hold_invoice_callbacks = {} # type: Dict[bytes, Callable[[bytes], Awaitable[None]]]
self.payment_bundles = [] # lists of hashes. todo:persist self.payment_bundles = [] # lists of hashes. todo:persist
@ -1275,9 +1277,10 @@ class LNWallet(LNWorker):
except Exception: except Exception:
raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_NODE_FAILURE, data=b'') raise OnionRoutingFailure(code=OnionFailureCode.TEMPORARY_NODE_FAILURE, data=b'')
# We have been paid and can broadcast # We have been paid and can broadcast
# if broadcasting raise an exception, we should try to rebroadcast # todo: if broadcasting raise an exception, we should try to rebroadcast
await self.network.broadcast_transaction(funding_tx) await self.network.broadcast_transaction(funding_tx)
return next_chan, funding_tx htlc_key = serialize_htlc_key(next_chan.get_scid_or_local_alias(), htlc.htlc_id)
return htlc_key
@log_exceptions @log_exceptions
async def open_channel_with_peer( async def open_channel_with_peer(
@ -2267,7 +2270,6 @@ class LNWallet(LNWorker):
if pkey in self.received_mpp_htlcs: if pkey in self.received_mpp_htlcs:
self.set_mpp_resolution(payment_key=pkey, resolution=mpp_resolution) self.set_mpp_resolution(payment_key=pkey, resolution=mpp_resolution)
self.maybe_cleanup_mpp_status(payment_key, short_channel_id, htlc)
return mpp_resolution return mpp_resolution
def update_mpp_with_received_htlc( def update_mpp_with_received_htlc(
@ -2312,20 +2314,31 @@ class LNWallet(LNWorker):
return int(time.time()) return int(time.time())
return min([_htlc.timestamp for scid, _htlc in mpp_status.htlc_set]) return min([_htlc.timestamp for scid, _htlc in mpp_status.htlc_set])
def maybe_cleanup_mpp_status( def maybe_cleanup_forwarding(
self, self,
payment_key: bytes, payment_key_hex: str,
short_channel_id: ShortChannelID, short_channel_id: ShortChannelID,
htlc: UpdateAddHtlc, htlc: UpdateAddHtlc,
) -> None: ) -> None:
mpp_status = self.received_mpp_htlcs[payment_key]
if mpp_status.resolution == RecvMPPResolution.WAITING: is_htlc_key = ':' in payment_key_hex
return if not is_htlc_key:
key = (short_channel_id, htlc) payment_key = bytes.fromhex(payment_key_hex)
mpp_status.htlc_set.remove(key) # side-effecting htlc_set mpp_status = self.received_mpp_htlcs[payment_key]
if not mpp_status.htlc_set and payment_key in self.received_mpp_htlcs: if mpp_status.resolution == RecvMPPResolution.WAITING:
# reconstructing the MPP after restart
self.logger.info(f'cannot cleanup mpp, still waiting')
return
htlc_key = (short_channel_id, htlc)
mpp_status.htlc_set.remove(htlc_key) # side-effecting htlc_set
if mpp_status.htlc_set:
return
self.logger.info('cleaning up mpp')
self.received_mpp_htlcs.pop(payment_key) self.received_mpp_htlcs.pop(payment_key)
self.active_forwardings.pop(payment_key_hex, None)
self.forwarding_failures.pop(payment_key_hex, None)
def get_payment_status(self, payment_hash: bytes) -> int: def get_payment_status(self, payment_hash: bytes) -> int:
info = self.get_payment_info(payment_hash) info = self.get_payment_info(payment_hash)
return info.status if info else PR_UNPAID return info.status if info else PR_UNPAID
@ -2370,16 +2383,23 @@ class LNWallet(LNWorker):
info = info._replace(status=status) info = info._replace(status=status)
self.save_payment_info(info) self.save_payment_info(info)
def is_forwarded_htlc_notify(self, chan: Channel, htlc_id: int) -> bool: def is_forwarded_htlc_notify(
self, chan: Channel, htlc_id: int, *,
error_bytes: Optional[bytes] = None,
failure_message: Optional['OnionRoutingFailure'] = None
) -> bool:
"""Called when an HTLC we offered on chan gets irrevocably fulfilled or failed. """Called when an HTLC we offered on chan gets irrevocably fulfilled or failed.
If we find this was a forwarded HTLC, the upstream peer is notified. If we find this was a forwarded HTLC, the upstream peer is notified.
Returns whether this was a forwarded HTLC. Returns whether this was a forwarded HTLC.
""" """
fw_info = chan.get_scid_or_local_alias().hex(), htlc_id htlc_key = serialize_htlc_key(chan.get_scid_or_local_alias(), htlc_id)
upstream_peer_pubkey = self.downstream_htlc_to_upstream_peer_map.get(fw_info) upstream_key = self.downstream_to_upstream_htlc.pop(htlc_key, None)
if not upstream_peer_pubkey: if not upstream_key:
return False return False
upstream_peer = self.peers.get(upstream_peer_pubkey) self.save_forwarding_failure(upstream_key, error_bytes=error_bytes, failure_message=failure_message)
upstream_chan_scid, _ = deserialize_htlc_key(upstream_key)
upstream_chan = self.get_channel_by_short_id(upstream_chan_scid)
upstream_peer = self.peers.get(upstream_chan.node_id) if upstream_chan else None
if upstream_peer: if upstream_peer:
upstream_peer.downstream_htlc_resolved_event.set() upstream_peer.downstream_htlc_resolved_event.set()
upstream_peer.downstream_htlc_resolved_event.clear() upstream_peer.downstream_htlc_resolved_event.clear()
@ -2416,7 +2436,7 @@ class LNWallet(LNWorker):
failure_message: Optional['OnionRoutingFailure']): failure_message: Optional['OnionRoutingFailure']):
util.trigger_callback('htlc_failed', payment_hash, chan, htlc_id) util.trigger_callback('htlc_failed', payment_hash, chan, htlc_id)
if self.is_forwarded_htlc_notify(chan=chan, htlc_id=htlc_id): if self.is_forwarded_htlc_notify(chan=chan, htlc_id=htlc_id, error_bytes=error_bytes, failure_message=failure_message):
return return
if shi := self.sent_htlcs_info.get((payment_hash, chan.short_channel_id, htlc_id)): if shi := self.sent_htlcs_info.get((payment_hash, chan.short_channel_id, htlc_id)):
onion_key = chan.pop_onion_key(htlc_id) onion_key = chan.pop_onion_key(htlc_id)
@ -3026,7 +3046,17 @@ class LNWallet(LNWorker):
util.trigger_callback('channels_updated', self.wallet) util.trigger_callback('channels_updated', self.wallet)
self.lnwatcher.add_channel(cb.funding_outpoint.to_str(), cb.get_funding_address()) self.lnwatcher.add_channel(cb.funding_outpoint.to_str(), cb.get_funding_address())
def fail_final_onion_forwarding(self, payment_key): def save_forwarding_failure(
""" use this to fail htlcs received for hold invoices""" self, payment_key:str, *,
e = OnionRoutingFailure(code=OnionFailureCode.UNKNOWN_NEXT_PEER, data=b'') error_bytes: Optional[bytes] = None,
self.final_onion_forwarding_failures[payment_key] = e failure_message: Optional['OnionRoutingFailure'] = None):
error_hex = error_bytes.hex() if error_bytes else None
failure_hex = failure_message.to_bytes().hex() if failure_message else None
self.forwarding_failures[payment_key] = (error_hex, failure_hex)
def get_forwarding_failure(self, payment_key: str):
error_hex, failure_hex = self.forwarding_failures.get(payment_key, (None, None))
error_bytes = bytes.fromhex(error_hex) if error_hex else None
failure_message = OnionRoutingFailure.from_bytes(bytes.fromhex(failure_hex)) if failure_hex else None
return error_bytes, failure_message

4
electrum/submarine_swaps.py

@ -30,6 +30,7 @@ from .bitcoin import construct_script
from .crypto import ripemd from .crypto import ripemd
from .invoices import Invoice from .invoices import Invoice
from .network import TxBroadcastServerReturnedError from .network import TxBroadcastServerReturnedError
from .lnonion import OnionRoutingFailure, OnionFailureCode
if TYPE_CHECKING: if TYPE_CHECKING:
@ -234,7 +235,8 @@ class SwapManager(Logger):
self.lnworker.unregister_hold_invoice(swap.payment_hash) self.lnworker.unregister_hold_invoice(swap.payment_hash)
payment_secret = self.lnworker.get_payment_secret(swap.payment_hash) payment_secret = self.lnworker.get_payment_secret(swap.payment_hash)
payment_key = swap.payment_hash + payment_secret payment_key = swap.payment_hash + payment_secret
self.lnworker.fail_final_onion_forwarding(payment_key) e = OnionRoutingFailure(code=OnionFailureCode.UNKNOWN_NEXT_PEER, data=b'')
self.lnworker.save_forwarding_failure(payment_key.hex(), failure_message=e)
self.lnwatcher.remove_callback(swap.lockup_address) self.lnwatcher.remove_callback(swap.lockup_address)
if swap.funding_txid is None: if swap.funding_txid is None:
self.swaps.pop(swap.payment_hash.hex()) self.swaps.pop(swap.payment_hash.hex())

14
electrum/tests/test_lnpeer.py

@ -171,12 +171,12 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
self._paysessions = dict() self._paysessions = dict()
self.sent_htlcs_info = dict() self.sent_htlcs_info = dict()
self.sent_buckets = defaultdict(set) self.sent_buckets = defaultdict(set)
self.final_onion_forwardings = set() self.active_forwardings = {}
self.final_onion_forwarding_failures = {} self.forwarding_failures = {}
self.inflight_payments = set() self.inflight_payments = set()
self.preimages = {} self.preimages = {}
self.stopping_soon = False self.stopping_soon = False
self.downstream_htlc_to_upstream_peer_map = {} self.downstream_to_upstream_htlc = {}
self.hold_invoice_callbacks = {} self.hold_invoice_callbacks = {}
self.payment_bundles = [] # lists of hashes. todo:persist self.payment_bundles = [] # lists of hashes. todo:persist
self.config.INITIAL_TRAMPOLINE_FEE_LEVEL = 0 self.config.INITIAL_TRAMPOLINE_FEE_LEVEL = 0
@ -298,10 +298,12 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
set_mpp_resolution = LNWallet.set_mpp_resolution set_mpp_resolution = LNWallet.set_mpp_resolution
is_mpp_amount_reached = LNWallet.is_mpp_amount_reached is_mpp_amount_reached = LNWallet.is_mpp_amount_reached
get_first_timestamp_of_mpp = LNWallet.get_first_timestamp_of_mpp get_first_timestamp_of_mpp = LNWallet.get_first_timestamp_of_mpp
maybe_cleanup_mpp_status = LNWallet.maybe_cleanup_mpp_status
bundle_payments = LNWallet.bundle_payments bundle_payments = LNWallet.bundle_payments
get_payment_bundle = LNWallet.get_payment_bundle get_payment_bundle = LNWallet.get_payment_bundle
_get_payment_key = LNWallet._get_payment_key _get_payment_key = LNWallet._get_payment_key
save_forwarding_failure = LNWallet.save_forwarding_failure
get_forwarding_failure = LNWallet.get_forwarding_failure
maybe_cleanup_forwarding = LNWallet.maybe_cleanup_forwarding
class MockTransport: class MockTransport:
@ -1821,6 +1823,8 @@ class TestPeerForwarding(TestPeer):
test_failure=False, test_failure=False,
attempts=2): attempts=2):
bob_w = graph.workers['bob']
carol_w = graph.workers['carol']
dave_w = graph.workers['dave'] dave_w = graph.workers['dave']
async def pay(lnaddr, pay_req): async def pay(lnaddr, pay_req):
@ -1828,6 +1832,8 @@ class TestPeerForwarding(TestPeer):
result, log = await graph.workers['alice'].pay_invoice(pay_req, attempts=attempts) result, log = await graph.workers['alice'].pay_invoice(pay_req, attempts=attempts)
if result: if result:
self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash)) self.assertEqual(PR_PAID, graph.workers['dave'].get_payment_status(lnaddr.paymenthash))
self.assertFalse(bool(bob_w.active_forwardings))
self.assertFalse(bool(carol_w.active_forwardings))
raise PaymentDone() raise PaymentDone()
else: else:
raise NoPathFound() raise NoPathFound()

Loading…
Cancel
Save