Browse Source

lnpeer and lnworker cleanup:

- rename trampoline_forwardings -> final_onion_forwardings,
   because this dict is used for both trampoline and hold invoices
 - remove timeout from hold_invoice_callbacks (redundant with invoice)
 - add test_failure boolean parameter to TestPeer._test_simple_payment,
   in order to test correct propagation of OnionRoutingFailures.
 - maybe_fulfill_htlc: raise an OnionRoutingFailure if we do not have
   the preimage for a payment that does not have a hold invoice callback.
   Without this, the above unit tests stall when we use test_failure=True
master
ThomasV 2 years ago
parent
commit
bf86cd6761
  1. 23
      electrum/lnpeer.py
  2. 15
      electrum/lnworker.py
  3. 4
      electrum/submarine_swaps.py
  4. 35
      electrum/tests/test_lnpeer.py

23
electrum/lnpeer.py

@ -1887,12 +1887,7 @@ class Peer(Logger):
if preimage: if preimage:
return preimage, None return preimage, None
else: else:
# for hold invoices, trigger callback return None, lambda: hold_invoice_callback(payment_hash)
cb, timeout = hold_invoice_callback
if int(time.time()) < timeout:
return None, lambda: cb(payment_hash)
else:
raise exc_incorrect_or_unknown_pd
# TODO don't accept payments twice for same invoice # TODO don't accept payments twice for same invoice
# TODO check invoice expiry # TODO check invoice expiry
@ -1903,8 +1898,8 @@ class Peer(Logger):
preimage = self.lnworker.get_preimage(payment_hash) preimage = self.lnworker.get_preimage(payment_hash)
if not preimage: if not preimage:
self.logger.info(f"missing callback {payment_hash.hex()}") self.logger.info(f"missing preimage and no hold invoice callback {payment_hash.hex()}")
return None, None raise exc_incorrect_or_unknown_pd
expected_payment_secrets = [self.lnworker.get_payment_secret(htlc.payment_hash)] expected_payment_secrets = [self.lnworker.get_payment_secret(htlc.payment_hash)]
expected_payment_secrets.append(derive_payment_secret_from_payment_preimage(preimage)) # legacy secret for old invoices expected_payment_secrets.append(derive_payment_secret_from_payment_preimage(preimage)) # legacy secret for old invoices
@ -2424,23 +2419,23 @@ class Peer(Logger):
# trampoline- HTLC we are supposed to forward, but haven't forwarded yet # trampoline- HTLC we are supposed to forward, but haven't forwarded yet
if not self.lnworker.enable_htlc_forwarding: if not self.lnworker.enable_htlc_forwarding:
pass pass
elif payment_key in self.lnworker.trampoline_forwardings: elif payment_key in self.lnworker.final_onion_forwardings:
# we are already forwarding this payment # we are already forwarding this payment
self.logger.info(f"we are already forwarding this.") self.logger.info(f"we are already forwarding this.")
else: else:
# add to list of ongoing payments # add to list of ongoing payments
self.lnworker.trampoline_forwardings.add(payment_key) self.lnworker.final_onion_forwardings.add(payment_key)
# clear previous failures # clear previous failures
self.lnworker.trampoline_forwarding_failures.pop(payment_key, None) self.lnworker.final_onion_forwarding_failures.pop(payment_key, None)
async def wrapped_callback(): async def wrapped_callback():
forwarding_coro = forwarding_callback() forwarding_coro = forwarding_callback()
try: try:
await forwarding_coro await forwarding_coro
except OnionRoutingFailure as e: except OnionRoutingFailure as e:
self.lnworker.trampoline_forwarding_failures[payment_key] = e self.lnworker.final_onion_forwarding_failures[payment_key] = e
finally: finally:
# remove from list of payments, so that another attempt can be initiated # remove from list of payments, so that another attempt can be initiated
self.lnworker.trampoline_forwardings.remove(payment_key) self.lnworker.final_onion_forwardings.remove(payment_key)
asyncio.ensure_future(wrapped_callback()) asyncio.ensure_future(wrapped_callback())
fw_info = payment_key.hex() fw_info = payment_key.hex()
return None, fw_info, None return None, fw_info, None
@ -2449,7 +2444,7 @@ class Peer(Logger):
payment_key = bytes.fromhex(forwarding_info) payment_key = bytes.fromhex(forwarding_info)
preimage = self.lnworker.get_preimage(payment_hash) preimage = self.lnworker.get_preimage(payment_hash)
# get (and not pop) failure because the incoming payment might be multi-part # get (and not pop) failure because the incoming payment might be multi-part
error_reason = self.lnworker.trampoline_forwarding_failures.get(payment_key) error_reason = self.lnworker.final_onion_forwarding_failures.get(payment_key)
if error_reason: if error_reason:
self.logger.info(f'trampoline forwarding failure: {error_reason.code_name()}') self.logger.info(f'trampoline forwarding failure: {error_reason.code_name()}')
raise error_reason raise error_reason

15
electrum/lnworker.py

@ -703,8 +703,8 @@ class LNWallet(LNWorker):
for payment_hash in self.get_payments(status='inflight').keys(): for payment_hash in self.get_payments(status='inflight').keys():
self.set_invoice_status(payment_hash.hex(), PR_INFLIGHT) self.set_invoice_status(payment_hash.hex(), PR_INFLIGHT)
self.trampoline_forwardings = set() self.final_onion_forwardings = set()
self.trampoline_forwarding_failures = {} # todo: should be persisted self.final_onion_forwarding_failures = {} # todo: should be persisted
# map forwarded htlcs (fw_info=(scid_hex, htlc_id)) to originating peer pubkeys # map forwarded htlcs (fw_info=(scid_hex, htlc_id)) to originating peer pubkeys
self.downstream_htlc_to_upstream_peer_map = {} # type: Dict[Tuple[str, int], bytes] self.downstream_htlc_to_upstream_peer_map = {} # type: Dict[Tuple[str, int], bytes]
# payment_hash -> callback, timeout: # payment_hash -> callback, timeout:
@ -1954,11 +1954,8 @@ class LNWallet(LNWorker):
info = PaymentInfo(payment_hash, lightning_amount_sat * 1000, RECEIVED, PR_UNPAID) info = PaymentInfo(payment_hash, lightning_amount_sat * 1000, RECEIVED, PR_UNPAID)
self.save_payment_info(info, write_to_disk=False) self.save_payment_info(info, write_to_disk=False)
def register_callback_for_hold_invoice( def register_callback_for_hold_invoice(self, payment_hash: bytes, cb: Callable[[bytes], None]):
self, payment_hash: bytes, cb: Callable[[bytes], None], timeout: int, self.hold_invoice_callbacks[payment_hash] = cb
):
expiry = int(time.time()) + timeout
self.hold_invoice_callbacks[payment_hash] = cb, expiry
def save_payment_info(self, info: PaymentInfo, *, write_to_disk: bool = True) -> None: def save_payment_info(self, info: PaymentInfo, *, write_to_disk: bool = True) -> None:
key = info.payment_hash.hex() key = info.payment_hash.hex()
@ -2758,7 +2755,7 @@ class LNWallet(LNWorker):
util.trigger_callback('channels_updated', self.wallet) util.trigger_callback('channels_updated', self.wallet)
self.lnwatcher.add_channel(cb.funding_outpoint.to_str(), cb.get_funding_address()) self.lnwatcher.add_channel(cb.funding_outpoint.to_str(), cb.get_funding_address())
def fail_trampoline_forwarding(self, payment_key): def fail_final_onion_forwarding(self, payment_key):
""" use this to fail htlcs received for hold invoices""" """ use this to fail htlcs received for hold invoices"""
e = OnionRoutingFailure(code=OnionFailureCode.UNKNOWN_NEXT_PEER, data=b'') e = OnionRoutingFailure(code=OnionFailureCode.UNKNOWN_NEXT_PEER, data=b'')
self.trampoline_forwarding_failures[payment_key] = e self.final_onion_forwarding_failures[payment_key] = e

4
electrum/submarine_swaps.py

@ -247,7 +247,7 @@ class SwapManager(Logger):
self.logger.info(f'found confirmed refund') self.logger.info(f'found confirmed refund')
payment_secret = self.lnworker.get_payment_secret(swap.payment_hash) payment_secret = self.lnworker.get_payment_secret(swap.payment_hash)
payment_key = swap.payment_hash + payment_secret payment_key = swap.payment_hash + payment_secret
self.lnworker.fail_trampoline_forwarding(payment_key) self.lnworker.fail_final_onion_forwarding(payment_key)
if delta < 0: if delta < 0:
# too early for refund # too early for refund
@ -343,7 +343,7 @@ class SwapManager(Logger):
) )
# add payment info to lnworker # add payment info to lnworker
self.lnworker.add_payment_info_for_hold_invoice(payment_hash, main_amount_sat) self.lnworker.add_payment_info_for_hold_invoice(payment_hash, main_amount_sat)
self.lnworker.register_callback_for_hold_invoice(payment_hash, self.hold_invoice_callback, 60*60*24) self.lnworker.register_callback_for_hold_invoice(payment_hash, self.hold_invoice_callback)
prepay_hash = self.lnworker.create_payment_info(amount_msat=prepay_amount_sat*1000) prepay_hash = self.lnworker.create_payment_info(amount_msat=prepay_amount_sat*1000)
_, prepay_invoice = self.lnworker.get_bolt11_invoice( _, prepay_invoice = self.lnworker.get_bolt11_invoice(
payment_hash=prepay_hash, payment_hash=prepay_hash,

35
electrum/tests/test_lnpeer.py

@ -36,7 +36,7 @@ from electrum.lnmsg import encode_msg, decode_msg
from electrum import lnmsg from electrum import lnmsg
from electrum.logging import console_stderr_handler, Logger from electrum.logging import console_stderr_handler, Logger
from electrum.lnworker import PaymentInfo, RECEIVED from electrum.lnworker import PaymentInfo, RECEIVED
from electrum.lnonion import OnionFailureCode from electrum.lnonion import OnionFailureCode, OnionRoutingFailure
from electrum.lnutil import UpdateAddHtlc from electrum.lnutil import UpdateAddHtlc
from electrum.lnutil import LOCAL, REMOTE from electrum.lnutil import LOCAL, REMOTE
from electrum.invoices import PR_PAID, PR_UNPAID from electrum.invoices import PR_PAID, PR_UNPAID
@ -169,8 +169,8 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
self.sent_htlcs_q = defaultdict(asyncio.Queue) self.sent_htlcs_q = defaultdict(asyncio.Queue)
self.sent_htlcs_info = dict() self.sent_htlcs_info = dict()
self.sent_buckets = defaultdict(set) self.sent_buckets = defaultdict(set)
self.trampoline_forwardings = set() self.final_onion_forwardings = set()
self.trampoline_forwarding_failures = {} self.final_onion_forwarding_failures = {}
self.inflight_payments = set() self.inflight_payments = set()
self.preimages = {} self.preimages = {}
self.stopping_soon = False self.stopping_soon = False
@ -749,6 +749,7 @@ class TestPeer(ElectrumTestCase):
async def _test_simple_payment( async def _test_simple_payment(
self, self,
test_trampoline: bool, test_trampoline: bool,
test_failure:bool=False,
test_hold_invoice=False, test_hold_invoice=False,
test_bundle=False, test_bundle=False,
test_bundle_timeout=False test_bundle_timeout=False
@ -765,12 +766,16 @@ class TestPeer(ElectrumTestCase):
else: else:
raise PaymentFailure() raise PaymentFailure()
lnaddr, pay_req = self.prepare_invoice(w2) lnaddr, pay_req = self.prepare_invoice(w2)
if test_hold_invoice: payment_hash = lnaddr.paymenthash
payment_hash = lnaddr.paymenthash if test_failure or test_hold_invoice:
preimage = bytes.fromhex(w2.preimages.pop(payment_hash.hex())) preimage = bytes.fromhex(w2.preimages.pop(payment_hash.hex()))
async def cb(payment_hash): if test_hold_invoice:
w2.save_preimage(payment_hash, preimage) async def cb(payment_hash):
w2.register_callback_for_hold_invoice(payment_hash, cb, 60) if not test_failure:
w2.save_preimage(payment_hash, preimage)
else:
raise OnionRoutingFailure(code=OnionFailureCode.INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS, data=b'')
w2.register_callback_for_hold_invoice(payment_hash, cb)
if test_bundle: if test_bundle:
lnaddr2, pay_req2 = self.prepare_invoice(w2) lnaddr2, pay_req2 = self.prepare_invoice(w2)
@ -799,11 +804,16 @@ class TestPeer(ElectrumTestCase):
await f() await f()
@needs_test_with_all_chacha20_implementations @needs_test_with_all_chacha20_implementations
async def test_simple_payment(self): async def test_simple_payment_success(self):
for test_trampoline in [False, True]: for test_trampoline in [False, True]:
with self.assertRaises(PaymentDone): with self.assertRaises(PaymentDone):
await self._test_simple_payment(test_trampoline=test_trampoline) await self._test_simple_payment(test_trampoline=test_trampoline)
async def test_simple_payment_failure(self):
for test_trampoline in [False, True]:
with self.assertRaises(PaymentFailure):
await self._test_simple_payment(test_trampoline=test_trampoline, test_failure=True)
async def test_payment_bundle(self): async def test_payment_bundle(self):
for test_trampoline in [False, True]: for test_trampoline in [False, True]:
with self.assertRaises(PaymentDone): with self.assertRaises(PaymentDone):
@ -814,11 +824,16 @@ class TestPeer(ElectrumTestCase):
with self.assertRaises(PaymentFailure): with self.assertRaises(PaymentFailure):
await self._test_simple_payment(test_trampoline=test_trampoline, test_bundle=True, test_bundle_timeout=True) await self._test_simple_payment(test_trampoline=test_trampoline, test_bundle=True, test_bundle_timeout=True)
async def test_simple_payment_with_hold_invoice(self): async def test_simple_payment_success_with_hold_invoice(self):
for test_trampoline in [False, True]: for test_trampoline in [False, True]:
with self.assertRaises(PaymentDone): with self.assertRaises(PaymentDone):
await self._test_simple_payment(test_trampoline=test_trampoline, test_hold_invoice=True) await self._test_simple_payment(test_trampoline=test_trampoline, test_hold_invoice=True)
async def test_simple_payment_failure_with_hold_invoice(self):
for test_trampoline in [False, True]:
with self.assertRaises(PaymentFailure):
await self._test_simple_payment(test_trampoline=test_trampoline, test_hold_invoice=True, test_failure=True)
@needs_test_with_all_chacha20_implementations @needs_test_with_all_chacha20_implementations
async def test_payment_race(self): async def test_payment_race(self):
"""Alice and Bob pay each other simultaneously. """Alice and Bob pay each other simultaneously.

Loading…
Cancel
Save