Browse Source

lnworker: clean-up sent_htlcs_q and sent_htlcs_info

- introduce SentHtlcInfo named tuple
  - some previously unnamed tuples are now much shorter:
    create_routes_for_payment no longer returns an 8-tuple!
- sent_htlcs_q (renamed from sent_htlcs), is now keyed on payment_hash+payment_secret
  (needed for proper trampoline forwarding)
master
SomberNight 2 years ago
parent
commit
afac158c80
No known key found for this signature in database
GPG Key ID: B33B5F232C6271E9
  1. 2
      electrum/lnpeer.py
  2. 137
      electrum/lnworker.py
  3. 85
      electrum/tests/test_lnpeer.py

2
electrum/lnpeer.py

@ -1742,6 +1742,8 @@ class Peer(Logger):
except OnionRoutingFailure as e:
raise
except PaymentFailure as e:
self.logger.debug(
f"maybe_forward_trampoline. PaymentFailure for {payment_hash.hex()=}, {payment_secret.hex()=}: {e!r}")
# FIXME: adapt the error code
raise OnionRoutingFailure(code=OnionFailureCode.UNKNOWN_NEXT_PEER, data=b'')

137
electrum/lnworker.py

@ -69,7 +69,7 @@ from .lnutil import (Outpoint, LNPeerAddr,
NoPathFound, InvalidGossipMsg)
from .lnutil import ln_dummy_address, ln_compare_features, IncompatibleLightningFeatures
from .transaction import PartialTxOutput, PartialTransaction, PartialTxInput
from .lnonion import OnionFailureCode, OnionRoutingFailure
from .lnonion import OnionFailureCode, OnionRoutingFailure, OnionPacket
from .lnmsg import decode_msg
from .i18n import _
from .lnrouter import (RouteEdge, LNPaymentRoute, LNPaymentPath, is_route_sane_to_use,
@ -181,6 +181,20 @@ class ReceivedMPPStatus(NamedTuple):
htlc_set: Set[Tuple[ShortChannelID, UpdateAddHtlc]]
SentHtlcKey = Tuple[bytes, ShortChannelID, int] # RHASH, scid, htlc_id
class SentHtlcInfo(NamedTuple):
route: LNPaymentRoute
payment_secret_orig: bytes
payment_secret_bucket: bytes
amount_msat: int
bucket_msat: int
amount_receiver_msat: int
trampoline_fee_level: Optional[int]
trampoline_route: Optional[LNPaymentRoute]
class ErrorAddingPeer(Exception): pass
@ -678,8 +692,8 @@ class LNWallet(LNWorker):
for channel_id, storage in channel_backups.items():
self._channel_backups[bfh(channel_id)] = ChannelBackup(storage, lnworker=self)
self.sent_htlcs = defaultdict(asyncio.Queue) # type: Dict[bytes, asyncio.Queue[HtlcLog]]
self.sent_htlcs_info = dict() # (RHASH, scid, htlc_id) -> route, payment_secret, amount_msat, bucket_msat, trampoline_fee_level
self.sent_htlcs_q = defaultdict(asyncio.Queue) # type: Dict[bytes, asyncio.Queue[HtlcLog]]
self.sent_htlcs_info = dict() # type: Dict[SentHtlcKey, SentHtlcInfo]
self.sent_buckets = dict() # payment_key -> (amount_sent, amount_failed)
self.received_mpp_htlcs = dict() # type: Dict[bytes, ReceivedMPPStatus] # payment_key -> ReceivedMPPStatus
@ -1268,7 +1282,8 @@ class LNWallet(LNWorker):
if fwd_trampoline_cltv_delta < 576:
raise OnionRoutingFailure(code=OnionFailureCode.TRAMPOLINE_EXPIRY_TOO_SOON, data=b'')
self.logs[payment_hash.hex()] = log = []
payment_key = payment_hash + payment_secret
self.logs[payment_hash.hex()] = log = [] # TODO incl payment_secret in key (re trampoline forwarding)
# when encountering trampoline forwarding difficulties in the legacy case, we
# sometimes need to fall back to a single trampoline forwarder, at the expense
@ -1300,28 +1315,24 @@ class LNWallet(LNWorker):
channels=channels,
)
# 2. send htlcs
async for route, amount_msat, total_msat, amount_receiver_msat, cltv_delta, bucket_payment_secret, trampoline_onion, trampoline_route in routes:
amount_inflight += amount_receiver_msat
async for sent_htlc_info, cltv_delta, trampoline_onion in routes:
amount_inflight += sent_htlc_info.amount_receiver_msat
if amount_inflight > amount_to_pay: # safety belts
raise Exception(f"amount_inflight={amount_inflight} > amount_to_pay={amount_to_pay}")
sent_htlc_info = sent_htlc_info._replace(trampoline_fee_level=self.trampoline_fee_level)
await self.pay_to_route(
route=route,
amount_msat=amount_msat,
total_msat=total_msat,
amount_receiver_msat=amount_receiver_msat,
sent_htlc_info=sent_htlc_info,
payment_hash=payment_hash,
payment_secret=bucket_payment_secret,
min_cltv_expiry=cltv_delta,
trampoline_onion=trampoline_onion,
trampoline_fee_level=self.trampoline_fee_level,
trampoline_route=trampoline_route)
)
# invoice_status is triggered in self.set_invoice_status when it actally changes.
# It is also triggered here to update progress for a lightning payment in the GUI
# (e.g. attempt counter)
util.trigger_callback('invoice_status', self.wallet, payment_hash.hex(), PR_INFLIGHT)
# 3. await a queue
self.logger.info(f"amount inflight {amount_inflight}")
htlc_log = await self.sent_htlcs[payment_hash].get()
htlc_log = await self.sent_htlcs_q[payment_key].get()
amount_inflight -= htlc_log.amount_msat
if amount_inflight < 0:
raise Exception(f"amount_inflight={amount_inflight} < 0")
@ -1394,48 +1405,44 @@ class LNWallet(LNWorker):
async def pay_to_route(
self, *,
route: LNPaymentRoute,
amount_msat: int,
total_msat: int,
amount_receiver_msat:int,
sent_htlc_info: SentHtlcInfo,
payment_hash: bytes,
payment_secret: bytes,
min_cltv_expiry: int,
trampoline_onion: bytes = None,
trampoline_fee_level: int,
trampoline_route: Optional[List]) -> None:
# send a single htlc
short_channel_id = route[0].short_channel_id
) -> None:
"""Sends a single HTLC."""
shi = sent_htlc_info
del sent_htlc_info # just renamed
short_channel_id = shi.route[0].short_channel_id
chan = self.get_channel_by_short_id(short_channel_id)
assert chan, ShortChannelID(short_channel_id)
peer = self._peers.get(route[0].node_id)
peer = self._peers.get(shi.route[0].node_id)
if not peer:
raise PaymentFailure('Dropped peer')
await peer.initialized
htlc = peer.pay(
route=route,
route=shi.route,
chan=chan,
amount_msat=amount_msat,
total_msat=total_msat,
amount_msat=shi.amount_msat,
total_msat=shi.bucket_msat,
payment_hash=payment_hash,
min_final_cltv_expiry=min_cltv_expiry,
payment_secret=payment_secret,
payment_secret=shi.payment_secret_bucket,
trampoline_onion=trampoline_onion)
key = (payment_hash, short_channel_id, htlc.htlc_id)
self.sent_htlcs_info[key] = route, payment_secret, amount_msat, total_msat, amount_receiver_msat, trampoline_fee_level, trampoline_route
payment_key = payment_hash + payment_secret
self.sent_htlcs_info[key] = shi
payment_key = payment_hash + shi.payment_secret_bucket
# if we sent MPP to a trampoline, add item to sent_buckets
if self.uses_trampoline() and amount_msat != total_msat:
if self.uses_trampoline() and shi.amount_msat != shi.bucket_msat:
if payment_key not in self.sent_buckets:
self.sent_buckets[payment_key] = (0, 0)
amount_sent, amount_failed = self.sent_buckets[payment_key]
amount_sent += amount_receiver_msat
amount_sent += shi.amount_receiver_msat
self.sent_buckets[payment_key] = amount_sent, amount_failed
if self.network.path_finder:
# add inflight htlcs to liquidity hints
self.network.path_finder.update_inflight_htlcs(route, add_htlcs=True)
self.network.path_finder.update_inflight_htlcs(shi.route, add_htlcs=True)
util.trigger_callback('htlc_added', chan, htlc, SENT)
def handle_error_code_from_failed_htlc(
@ -1633,7 +1640,7 @@ class LNWallet(LNWorker):
fwd_trampoline_onion=None,
full_path: LNPaymentPath = None,
channels: Optional[Sequence[Channel]] = None,
) -> AsyncGenerator[Tuple[LNPaymentRoute, int], None]:
) -> AsyncGenerator[Tuple[SentHtlcInfo, int, Optional[OnionPacket]], None]:
"""Creates multiple routes for splitting a payment over the available
private channels.
@ -1719,7 +1726,17 @@ class LNWallet(LNWorker):
node_features=trampoline_features)
]
self.logger.info(f'adding route {part_amount_msat} {delta_fee} {margin}')
routes.append((route, part_amount_msat_with_fees, per_trampoline_amount_with_fees, part_amount_msat, per_trampoline_cltv_delta, per_trampoline_secret, trampoline_onion, trampoline_route))
shi = SentHtlcInfo(
route=route,
payment_secret_orig=payment_secret,
payment_secret_bucket=per_trampoline_secret,
amount_msat=part_amount_msat_with_fees,
bucket_msat=per_trampoline_amount_with_fees,
amount_receiver_msat=part_amount_msat,
trampoline_fee_level=None,
trampoline_route=trampoline_route,
)
routes.append((shi, per_trampoline_cltv_delta, trampoline_onion))
if per_trampoline_fees != 0:
self.logger.info('not enough margin to pay trampoline fee')
raise NoPathFound()
@ -1741,7 +1758,17 @@ class LNWallet(LNWorker):
full_path=full_path,
)
)
routes.append((route, part_amount_msat, final_total_msat, part_amount_msat, min_cltv_expiry, payment_secret, fwd_trampoline_onion, None))
shi = SentHtlcInfo(
route=route,
payment_secret_orig=payment_secret,
payment_secret_bucket=payment_secret,
amount_msat=part_amount_msat,
bucket_msat=final_total_msat,
amount_receiver_msat=part_amount_msat,
trampoline_fee_level=None,
trampoline_route=None,
)
routes.append((shi, min_cltv_expiry, fwd_trampoline_onion))
except NoPathFound:
continue
for route in routes:
@ -2096,14 +2123,16 @@ class LNWallet(LNWorker):
def htlc_fulfilled(self, chan: Channel, payment_hash: bytes, htlc_id: int):
util.trigger_callback('htlc_fulfilled', payment_hash, chan, htlc_id)
self._on_maybe_forwarded_htlc_resolved(chan=chan, htlc_id=htlc_id)
q = self.sent_htlcs.get(payment_hash)
q = None
if shi := self.sent_htlcs_info.get((payment_hash, chan.short_channel_id, htlc_id)):
payment_key = payment_hash + shi.payment_secret_orig
q = self.sent_htlcs_q.get(payment_key)
if q:
route, payment_secret, amount_msat, bucket_msat, amount_receiver_msat, trampoline_fee_level, trampoline_route = self.sent_htlcs_info[(payment_hash, chan.short_channel_id, htlc_id)]
htlc_log = HtlcLog(
success=True,
route=route,
amount_msat=amount_receiver_msat,
trampoline_fee_level=trampoline_fee_level)
route=shi.route,
amount_msat=shi.amount_receiver_msat,
trampoline_fee_level=shi.trampoline_fee_level)
q.put_nowait(htlc_log)
else:
key = payment_hash.hex()
@ -2120,12 +2149,16 @@ class LNWallet(LNWorker):
util.trigger_callback('htlc_failed', payment_hash, chan, htlc_id)
self._on_maybe_forwarded_htlc_resolved(chan=chan, htlc_id=htlc_id)
q = self.sent_htlcs.get(payment_hash)
q = None
if shi := self.sent_htlcs_info.get((payment_hash, chan.short_channel_id, htlc_id)):
payment_okey = payment_hash + shi.payment_secret_orig
q = self.sent_htlcs_q.get(payment_okey)
if q:
# detect if it is part of a bucket
# if yes, wait until the bucket completely failed
key = (payment_hash, chan.short_channel_id, htlc_id)
route, payment_secret, amount_msat, bucket_msat, amount_receiver_msat, trampoline_fee_level, trampoline_route = self.sent_htlcs_info[key]
shi = self.sent_htlcs_info[(payment_hash, chan.short_channel_id, htlc_id)]
amount_receiver_msat = shi.amount_receiver_msat
route = shi.route
if error_bytes:
# TODO "decode_onion_error" might raise, catch and maybe blacklist/penalise someone?
try:
@ -2140,19 +2173,19 @@ class LNWallet(LNWorker):
self.logger.info(f"htlc_failed {failure_message}")
# check sent_buckets if we use trampoline
payment_key = payment_hash + payment_secret
if self.uses_trampoline() and payment_key in self.sent_buckets:
amount_sent, amount_failed = self.sent_buckets[payment_key]
payment_bkey = payment_hash + shi.payment_secret_bucket
if self.uses_trampoline() and payment_bkey in self.sent_buckets:
amount_sent, amount_failed = self.sent_buckets[payment_bkey]
amount_failed += amount_receiver_msat
self.sent_buckets[payment_key] = amount_sent, amount_failed
self.sent_buckets[payment_bkey] = amount_sent, amount_failed
if amount_sent != amount_failed:
self.logger.info('bucket still active...')
return
self.logger.info('bucket failed')
amount_receiver_msat = amount_sent
if trampoline_route:
route = trampoline_route
if shi.trampoline_route:
route = shi.trampoline_route
htlc_log = HtlcLog(
success=False,
route=route,
@ -2160,7 +2193,7 @@ class LNWallet(LNWorker):
error_bytes=error_bytes,
failure_msg=failure_message,
sender_idx=sender_idx,
trampoline_fee_level=trampoline_fee_level)
trampoline_fee_level=shi.trampoline_fee_level)
q.put_nowait(htlc_log)
else:
self.logger.info(f"received unknown htlc_failed, probably from previous session")

85
electrum/tests/test_lnpeer.py

@ -31,7 +31,7 @@ from electrum.lnutil import PaymentFailure, LnFeatures, HTLCOwner
from electrum.lnchannel import ChannelState, PeerState, Channel
from electrum.lnrouter import LNPathFinder, PathEdge, LNPathInconsistent
from electrum.channel_db import ChannelDB
from electrum.lnworker import LNWallet, NoPathFound
from electrum.lnworker import LNWallet, NoPathFound, SentHtlcInfo
from electrum.lnmsg import encode_msg, decode_msg
from electrum import lnmsg
from electrum.logging import console_stderr_handler, Logger
@ -166,7 +166,7 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
self.enable_htlc_settle = True
self.enable_htlc_forwarding = True
self.received_mpp_htlcs = dict()
self.sent_htlcs = defaultdict(asyncio.Queue)
self.sent_htlcs_q = defaultdict(asyncio.Queue)
self.sent_htlcs_info = dict()
self.sent_buckets = defaultdict(set)
self.trampoline_forwardings = set()
@ -740,7 +740,7 @@ class TestPeer(ElectrumTestCase):
with self.assertRaises(SuccessfulTest):
await f()
async def _activate_trampoline(self, w):
async def _activate_trampoline(self, w: MockLNWallet):
if w.network.channel_db:
w.network.channel_db.stop()
await w.network.channel_db.stopped_event.wait()
@ -837,38 +837,44 @@ class TestPeer(ElectrumTestCase):
lnaddr2, pay_req2 = self.prepare_invoice(w2)
lnaddr1, pay_req1 = self.prepare_invoice(w1)
# create the htlc queues now (side-effecting defaultdict)
q1 = w1.sent_htlcs[lnaddr2.paymenthash]
q2 = w2.sent_htlcs[lnaddr1.paymenthash]
q1 = w1.sent_htlcs_q[lnaddr2.paymenthash + lnaddr2.payment_secret]
q2 = w2.sent_htlcs_q[lnaddr1.paymenthash + lnaddr1.payment_secret]
# alice sends htlc BUT NOT COMMITMENT_SIGNED
p1.maybe_send_commitment = lambda x: None
route1 = (await w1.create_routes_from_invoice(lnaddr2.get_amount_msat(), decoded_invoice=lnaddr2))[0][0]
amount_msat = lnaddr2.get_amount_msat()
await w1.pay_to_route(
route1 = (await w1.create_routes_from_invoice(lnaddr2.get_amount_msat(), decoded_invoice=lnaddr2))[0][0].route
shi1 = SentHtlcInfo(
route=route1,
amount_msat=amount_msat,
total_msat=amount_msat,
amount_receiver_msat=amount_msat,
payment_secret_orig=lnaddr2.payment_secret,
payment_secret_bucket=lnaddr2.payment_secret,
amount_msat=lnaddr2.get_amount_msat(),
bucket_msat=lnaddr2.get_amount_msat(),
amount_receiver_msat=lnaddr2.get_amount_msat(),
trampoline_fee_level=None,
trampoline_route=None,
)
await w1.pay_to_route(
sent_htlc_info=shi1,
payment_hash=lnaddr2.paymenthash,
min_cltv_expiry=lnaddr2.get_min_final_cltv_expiry(),
payment_secret=lnaddr2.payment_secret,
trampoline_fee_level=0,
trampoline_route=None,
)
p1.maybe_send_commitment = _maybe_send_commitment1
# bob sends htlc BUT NOT COMMITMENT_SIGNED
p2.maybe_send_commitment = lambda x: None
route2 = (await w2.create_routes_from_invoice(lnaddr1.get_amount_msat(), decoded_invoice=lnaddr1))[0][0]
amount_msat = lnaddr1.get_amount_msat()
await w2.pay_to_route(
route2 = (await w2.create_routes_from_invoice(lnaddr1.get_amount_msat(), decoded_invoice=lnaddr1))[0][0].route
shi2 = SentHtlcInfo(
route=route2,
amount_msat=amount_msat,
total_msat=amount_msat,
amount_receiver_msat=amount_msat,
payment_secret_orig=lnaddr1.payment_secret,
payment_secret_bucket=lnaddr1.payment_secret,
amount_msat=lnaddr1.get_amount_msat(),
bucket_msat=lnaddr1.get_amount_msat(),
amount_receiver_msat=lnaddr1.get_amount_msat(),
trampoline_fee_level=None,
trampoline_route=None,
)
await w2.pay_to_route(
sent_htlc_info=shi2,
payment_hash=lnaddr1.paymenthash,
min_cltv_expiry=lnaddr1.get_min_final_cltv_expiry(),
payment_secret=lnaddr1.payment_secret,
trampoline_fee_level=0,
trampoline_route=None,
)
p2.maybe_send_commitment = _maybe_send_commitment2
# sleep a bit so that they both receive msgs sent so far
@ -878,9 +884,9 @@ class TestPeer(ElectrumTestCase):
p2.maybe_send_commitment(bob_channel)
htlc_log1 = await q1.get()
assert htlc_log1.success
self.assertTrue(htlc_log1.success)
htlc_log2 = await q2.get()
assert htlc_log2.success
self.assertTrue(htlc_log2.success)
raise PaymentDone()
async def f():
@ -1184,10 +1190,7 @@ class TestPeer(ElectrumTestCase):
if not bob_forwarding:
graph.workers['bob'].enable_htlc_forwarding = False
if alice_uses_trampoline:
if graph.workers['alice'].network.channel_db:
graph.workers['alice'].network.channel_db.stop()
await graph.workers['alice'].network.channel_db.stopped_event.wait()
graph.workers['alice'].network.channel_db = None
await self._activate_trampoline(graph.workers['alice'])
else:
assert graph.workers['alice'].network.channel_db is not None
lnaddr, pay_req = self.prepare_invoice(graph.workers['dave'], include_routing_hints=True, amount_msat=amount_to_pay)
@ -1433,7 +1436,7 @@ class TestPeer(ElectrumTestCase):
await util.wait_for2(p1.initialized, 1)
await util.wait_for2(p2.initialized, 1)
# alice sends htlc
route, amount_msat = (await w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr))[0][0:2]
route = (await w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr))[0][0].route
p1.pay(route=route,
chan=alice_channel,
amount_msat=lnaddr.get_amount_msat(),
@ -1556,7 +1559,8 @@ class TestPeer(ElectrumTestCase):
lnaddr, pay_req = self.prepare_invoice(w2)
lnaddr = w1._check_invoice(pay_req)
route, amount_msat = (await w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr))[0][0:2]
shi = (await w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr))[0][0]
route, amount_msat = shi.route, shi.amount_msat
assert amount_msat == lnaddr.get_amount_msat()
await w1.force_close_channel(alice_channel.channel_id)
@ -1570,20 +1574,21 @@ class TestPeer(ElectrumTestCase):
# AssertionError is ok since we shouldn't use old routes, and the
# route finding should fail when channel is closed
async def f():
min_cltv_expiry = lnaddr.get_min_final_cltv_expiry()
payment_hash = lnaddr.paymenthash
payment_secret = lnaddr.payment_secret
pay = w1.pay_to_route(
shi = SentHtlcInfo(
route=route,
payment_secret_orig=lnaddr.payment_secret,
payment_secret_bucket=lnaddr.payment_secret,
amount_msat=amount_msat,
total_msat=amount_msat,
bucket_msat=amount_msat,
amount_receiver_msat=amount_msat,
payment_hash=payment_hash,
payment_secret=payment_secret,
min_cltv_expiry=min_cltv_expiry,
trampoline_fee_level=0,
trampoline_fee_level=None,
trampoline_route=None,
)
pay = w1.pay_to_route(
sent_htlc_info=shi,
payment_hash=lnaddr.paymenthash,
min_cltv_expiry=lnaddr.get_min_final_cltv_expiry(),
)
await asyncio.gather(pay, p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
with self.assertRaises(PaymentFailure):
await f()

Loading…
Cancel
Save