Browse Source

lnpeer and lnworker cleanup:

- rename trampoline_forwardings -> final_onion_forwardings,
   because this dict is used for both trampoline and hold invoices
 - remove timeout from hold_invoice_callbacks (redundant with invoice)
 - add test_failure boolean parameter to TestPeer._test_simple_payment,
   in order to test correct propagation of OnionRoutingFailures.
 - maybe_fulfill_htlc: raise an OnionRoutingFailure if we do not have
   the preimage for a payment that does not have a hold invoice callback.
   Without this, the above unit tests stall when we use test_failure=True
master
ThomasV 2 years ago
parent
commit
bf86cd6761
  1. 23
      electrum/lnpeer.py
  2. 15
      electrum/lnworker.py
  3. 4
      electrum/submarine_swaps.py
  4. 35
      electrum/tests/test_lnpeer.py

23
electrum/lnpeer.py

@ -1887,12 +1887,7 @@ class Peer(Logger):
if preimage:
return preimage, None
else:
# for hold invoices, trigger callback
cb, timeout = hold_invoice_callback
if int(time.time()) < timeout:
return None, lambda: cb(payment_hash)
else:
raise exc_incorrect_or_unknown_pd
return None, lambda: hold_invoice_callback(payment_hash)
# TODO don't accept payments twice for same invoice
# TODO check invoice expiry
@ -1903,8 +1898,8 @@ class Peer(Logger):
preimage = self.lnworker.get_preimage(payment_hash)
if not preimage:
self.logger.info(f"missing callback {payment_hash.hex()}")
return None, None
self.logger.info(f"missing preimage and no hold invoice callback {payment_hash.hex()}")
raise exc_incorrect_or_unknown_pd
expected_payment_secrets = [self.lnworker.get_payment_secret(htlc.payment_hash)]
expected_payment_secrets.append(derive_payment_secret_from_payment_preimage(preimage)) # legacy secret for old invoices
@ -2424,23 +2419,23 @@ class Peer(Logger):
# trampoline- HTLC we are supposed to forward, but haven't forwarded yet
if not self.lnworker.enable_htlc_forwarding:
pass
elif payment_key in self.lnworker.trampoline_forwardings:
elif payment_key in self.lnworker.final_onion_forwardings:
# we are already forwarding this payment
self.logger.info(f"we are already forwarding this.")
else:
# add to list of ongoing payments
self.lnworker.trampoline_forwardings.add(payment_key)
self.lnworker.final_onion_forwardings.add(payment_key)
# clear previous failures
self.lnworker.trampoline_forwarding_failures.pop(payment_key, None)
self.lnworker.final_onion_forwarding_failures.pop(payment_key, None)
async def wrapped_callback():
forwarding_coro = forwarding_callback()
try:
await forwarding_coro
except OnionRoutingFailure as e:
self.lnworker.trampoline_forwarding_failures[payment_key] = e
self.lnworker.final_onion_forwarding_failures[payment_key] = e
finally:
# remove from list of payments, so that another attempt can be initiated
self.lnworker.trampoline_forwardings.remove(payment_key)
self.lnworker.final_onion_forwardings.remove(payment_key)
asyncio.ensure_future(wrapped_callback())
fw_info = payment_key.hex()
return None, fw_info, None
@ -2449,7 +2444,7 @@ class Peer(Logger):
payment_key = bytes.fromhex(forwarding_info)
preimage = self.lnworker.get_preimage(payment_hash)
# get (and not pop) failure because the incoming payment might be multi-part
error_reason = self.lnworker.trampoline_forwarding_failures.get(payment_key)
error_reason = self.lnworker.final_onion_forwarding_failures.get(payment_key)
if error_reason:
self.logger.info(f'trampoline forwarding failure: {error_reason.code_name()}')
raise error_reason

15
electrum/lnworker.py

@ -703,8 +703,8 @@ class LNWallet(LNWorker):
for payment_hash in self.get_payments(status='inflight').keys():
self.set_invoice_status(payment_hash.hex(), PR_INFLIGHT)
self.trampoline_forwardings = set()
self.trampoline_forwarding_failures = {} # todo: should be persisted
self.final_onion_forwardings = set()
self.final_onion_forwarding_failures = {} # todo: should be persisted
# map forwarded htlcs (fw_info=(scid_hex, htlc_id)) to originating peer pubkeys
self.downstream_htlc_to_upstream_peer_map = {} # type: Dict[Tuple[str, int], bytes]
# payment_hash -> callback, timeout:
@ -1954,11 +1954,8 @@ class LNWallet(LNWorker):
info = PaymentInfo(payment_hash, lightning_amount_sat * 1000, RECEIVED, PR_UNPAID)
self.save_payment_info(info, write_to_disk=False)
def register_callback_for_hold_invoice(
self, payment_hash: bytes, cb: Callable[[bytes], None], timeout: int,
):
expiry = int(time.time()) + timeout
self.hold_invoice_callbacks[payment_hash] = cb, expiry
def register_callback_for_hold_invoice(self, payment_hash: bytes, cb: Callable[[bytes], None]):
self.hold_invoice_callbacks[payment_hash] = cb
def save_payment_info(self, info: PaymentInfo, *, write_to_disk: bool = True) -> None:
key = info.payment_hash.hex()
@ -2758,7 +2755,7 @@ class LNWallet(LNWorker):
util.trigger_callback('channels_updated', self.wallet)
self.lnwatcher.add_channel(cb.funding_outpoint.to_str(), cb.get_funding_address())
def fail_trampoline_forwarding(self, payment_key):
def fail_final_onion_forwarding(self, payment_key):
""" use this to fail htlcs received for hold invoices"""
e = OnionRoutingFailure(code=OnionFailureCode.UNKNOWN_NEXT_PEER, data=b'')
self.trampoline_forwarding_failures[payment_key] = e
self.final_onion_forwarding_failures[payment_key] = e

4
electrum/submarine_swaps.py

@ -247,7 +247,7 @@ class SwapManager(Logger):
self.logger.info(f'found confirmed refund')
payment_secret = self.lnworker.get_payment_secret(swap.payment_hash)
payment_key = swap.payment_hash + payment_secret
self.lnworker.fail_trampoline_forwarding(payment_key)
self.lnworker.fail_final_onion_forwarding(payment_key)
if delta < 0:
# too early for refund
@ -343,7 +343,7 @@ class SwapManager(Logger):
)
# add payment info to lnworker
self.lnworker.add_payment_info_for_hold_invoice(payment_hash, main_amount_sat)
self.lnworker.register_callback_for_hold_invoice(payment_hash, self.hold_invoice_callback, 60*60*24)
self.lnworker.register_callback_for_hold_invoice(payment_hash, self.hold_invoice_callback)
prepay_hash = self.lnworker.create_payment_info(amount_msat=prepay_amount_sat*1000)
_, prepay_invoice = self.lnworker.get_bolt11_invoice(
payment_hash=prepay_hash,

35
electrum/tests/test_lnpeer.py

@ -36,7 +36,7 @@ from electrum.lnmsg import encode_msg, decode_msg
from electrum import lnmsg
from electrum.logging import console_stderr_handler, Logger
from electrum.lnworker import PaymentInfo, RECEIVED
from electrum.lnonion import OnionFailureCode
from electrum.lnonion import OnionFailureCode, OnionRoutingFailure
from electrum.lnutil import UpdateAddHtlc
from electrum.lnutil import LOCAL, REMOTE
from electrum.invoices import PR_PAID, PR_UNPAID
@ -169,8 +169,8 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
self.sent_htlcs_q = defaultdict(asyncio.Queue)
self.sent_htlcs_info = dict()
self.sent_buckets = defaultdict(set)
self.trampoline_forwardings = set()
self.trampoline_forwarding_failures = {}
self.final_onion_forwardings = set()
self.final_onion_forwarding_failures = {}
self.inflight_payments = set()
self.preimages = {}
self.stopping_soon = False
@ -749,6 +749,7 @@ class TestPeer(ElectrumTestCase):
async def _test_simple_payment(
self,
test_trampoline: bool,
test_failure:bool=False,
test_hold_invoice=False,
test_bundle=False,
test_bundle_timeout=False
@ -765,12 +766,16 @@ class TestPeer(ElectrumTestCase):
else:
raise PaymentFailure()
lnaddr, pay_req = self.prepare_invoice(w2)
if test_hold_invoice:
payment_hash = lnaddr.paymenthash
payment_hash = lnaddr.paymenthash
if test_failure or test_hold_invoice:
preimage = bytes.fromhex(w2.preimages.pop(payment_hash.hex()))
async def cb(payment_hash):
w2.save_preimage(payment_hash, preimage)
w2.register_callback_for_hold_invoice(payment_hash, cb, 60)
if test_hold_invoice:
async def cb(payment_hash):
if not test_failure:
w2.save_preimage(payment_hash, preimage)
else:
raise OnionRoutingFailure(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'')
w2.register_callback_for_hold_invoice(payment_hash, cb)
if test_bundle:
lnaddr2, pay_req2 = self.prepare_invoice(w2)
@ -799,11 +804,16 @@ class TestPeer(ElectrumTestCase):
await f()
@needs_test_with_all_chacha20_implementations
async def test_simple_payment(self):
async def test_simple_payment_success(self):
for test_trampoline in [False, True]:
with self.assertRaises(PaymentDone):
await self._test_simple_payment(test_trampoline=test_trampoline)
async def test_simple_payment_failure(self):
for test_trampoline in [False, True]:
with self.assertRaises(PaymentFailure):
await self._test_simple_payment(test_trampoline=test_trampoline, test_failure=True)
async def test_payment_bundle(self):
for test_trampoline in [False, True]:
with self.assertRaises(PaymentDone):
@ -814,11 +824,16 @@ class TestPeer(ElectrumTestCase):
with self.assertRaises(PaymentFailure):
await self._test_simple_payment(test_trampoline=test_trampoline, test_bundle=True, test_bundle_timeout=True)
async def test_simple_payment_with_hold_invoice(self):
async def test_simple_payment_success_with_hold_invoice(self):
for test_trampoline in [False, True]:
with self.assertRaises(PaymentDone):
await self._test_simple_payment(test_trampoline=test_trampoline, test_hold_invoice=True)
async def test_simple_payment_failure_with_hold_invoice(self):
for test_trampoline in [False, True]:
with self.assertRaises(PaymentFailure):
await self._test_simple_payment(test_trampoline=test_trampoline, test_hold_invoice=True, test_failure=True)
@needs_test_with_all_chacha20_implementations
async def test_payment_race(self):
"""Alice and Bob pay each other simultaneously.

Loading…
Cancel
Save