Browse Source

lnworker: introduce PaySession cls, refactor pay_to_node

master
SomberNight 2 years ago
parent
commit
00e88c4e50
No known key found for this signature in database
GPG Key ID: B33B5F232C6271E9
  1. 274
      electrum/lnworker.py
  2. 35
      electrum/tests/test_lnpeer.py

274
electrum/lnworker.py

@ -17,7 +17,7 @@ import socket
import aiohttp
import json
from datetime import datetime, timezone
from functools import partial
from functools import partial, cached_property
from collections import defaultdict
import concurrent
from concurrent import futures
@ -655,6 +655,105 @@ class LNGossip(LNWorker):
self.logger.debug(f'process_gossip: {len(categorized_chan_upds.good)}/{len(chan_upds)}')
class PaySession(Logger):
def __init__(
self,
*,
payment_hash: bytes,
payment_secret: bytes,
initial_trampoline_fee_level: int,
invoice_features: int,
r_tags,
min_cltv_expiry: int,
amount_to_pay: int, # total payment amount final receiver will get
invoice_pubkey: bytes,
):
assert payment_hash
assert payment_secret
self.payment_hash = payment_hash
self.payment_secret = payment_secret
self.payment_key = payment_hash + payment_secret
Logger.__init__(self)
self.invoice_features = LnFeatures(invoice_features)
self.r_tags = r_tags
self.min_cltv_expiry = min_cltv_expiry
self.amount_to_pay = amount_to_pay
self.invoice_pubkey = invoice_pubkey
self.sent_htlcs_q = asyncio.Queue() # type: asyncio.Queue[HtlcLog]
self.start_time = time.time()
self.trampoline_fee_level = initial_trampoline_fee_level
self.failed_trampoline_routes = []
self.use_two_trampolines = True
self._amount_inflight = 0 # what we sent in htlcs (that receiver gets, without fees)
self._nhtlcs_inflight = 0
def diagnostic_name(self):
pkey = sha256(self.payment_key)
return f"{self.payment_hash[:4].hex()}-{pkey[:2].hex()}"
def maybe_raise_trampoline_fee(self, htlc_log: HtlcLog):
if htlc_log.trampoline_fee_level == self.trampoline_fee_level:
self.trampoline_fee_level += 1
self.failed_trampoline_routes = []
self.logger.info(f'raising trampoline fee level {self.trampoline_fee_level}')
else:
self.logger.info(f'NOT raising trampoline fee level, already at {self.trampoline_fee_level}')
def handle_failed_trampoline_htlc(self, *, htlc_log: HtlcLog, failure_msg: OnionRoutingFailure):
# FIXME The trampoline nodes in the path are chosen randomly.
# Some of the errors might depend on how we have chosen them.
# Having more attempts is currently useful in part because of the randomness,
# instead we should give feedback to create_routes_for_payment.
# Sometimes the trampoline node fails to send a payment and returns
# TEMPORARY_CHANNEL_FAILURE, while it succeeds with a higher trampoline fee.
if failure_msg.code in (
OnionFailureCode.TRAMPOLINE_FEE_INSUFFICIENT,
OnionFailureCode.TRAMPOLINE_EXPIRY_TOO_SOON,
OnionFailureCode.TEMPORARY_CHANNEL_FAILURE):
# TODO: parse the node policy here (not returned by eclair yet)
# TODO: erring node is always the first trampoline even if second
# trampoline demands more fees, we can't influence this
self.maybe_raise_trampoline_fee(htlc_log)
elif self.use_two_trampolines:
self.use_two_trampolines = False
elif failure_msg.code in (
OnionFailureCode.UNKNOWN_NEXT_PEER,
OnionFailureCode.TEMPORARY_NODE_FAILURE):
trampoline_route = htlc_log.route
r = [hop.end_node.hex() for hop in trampoline_route]
self.logger.info(f'failed trampoline route: {r}')
if r not in self.failed_trampoline_routes:
self.failed_trampoline_routes.append(r)
else:
pass # maybe the route was reused between different MPP parts
else:
raise PaymentFailure(failure_msg.code_name())
async def wait_for_one_htlc_to_resolve(self) -> HtlcLog:
self.logger.info(f"waiting... amount_inflight={self._amount_inflight}. nhtlcs_inflight={self._nhtlcs_inflight}")
htlc_log = await self.sent_htlcs_q.get()
self._amount_inflight -= htlc_log.amount_msat
self._nhtlcs_inflight -= 1
if self._amount_inflight < 0 or self._nhtlcs_inflight < 0:
raise Exception(f"amount_inflight={self._amount_inflight}, nhtlcs_inflight={self._nhtlcs_inflight}. both should be >= 0 !")
return htlc_log
def add_new_htlc(self, sent_htlc_info: SentHtlcInfo) -> SentHtlcInfo:
self._nhtlcs_inflight += 1
self._amount_inflight += sent_htlc_info.amount_receiver_msat
if self._amount_inflight > self.amount_to_pay: # safety belts
raise Exception(f"amount_inflight={self._amount_inflight} > amount_to_pay={self.amount_to_pay}")
sent_htlc_info = sent_htlc_info._replace(trampoline_fee_level=self.trampoline_fee_level)
return sent_htlc_info
def get_outstanding_amount_to_send(self) -> int:
return self.amount_to_pay - self._amount_inflight
class LNWallet(LNWorker):
lnwatcher: Optional['LNWalletWatcher']
@ -694,9 +793,9 @@ class LNWallet(LNWorker):
for channel_id, storage in channel_backups.items():
self._channel_backups[bfh(channel_id)] = ChannelBackup(storage, lnworker=self)
self.sent_htlcs_q = defaultdict(asyncio.Queue) # type: Dict[bytes, asyncio.Queue[HtlcLog]]
self._paysessions = dict() # type: Dict[bytes, PaySession]
self.sent_htlcs_info = dict() # type: Dict[SentHtlcKey, SentHtlcInfo]
self.sent_buckets = dict() # payment_key -> (amount_sent, amount_failed)
self.sent_buckets = dict() # payment_key -> (amount_sent, amount_failed) # TODO move into PaySession
self.received_mpp_htlcs = dict() # type: Dict[bytes, ReceivedMPPStatus] # payment_key -> ReceivedMPPStatus
# detect inflight payments
@ -1274,9 +1373,9 @@ class LNWallet(LNWorker):
invoice_features: int,
attempts: int = None,
full_path: LNPaymentPath = None,
fwd_trampoline_onion=None,
fwd_trampoline_fee=None,
fwd_trampoline_cltv_delta=None,
fwd_trampoline_onion: OnionPacket = None,
fwd_trampoline_fee: int = None,
fwd_trampoline_cltv_delta: int = None,
channels: Optional[Sequence[Channel]] = None,
) -> None:
@ -1288,46 +1387,37 @@ class LNWallet(LNWorker):
raise OnionRoutingFailure(code=OnionFailureCode.TRAMPOLINE_EXPIRY_TOO_SOON, data=b'')
payment_key = payment_hash + payment_secret
#assert payment_key not in self._paysessions # FIXME
self._paysessions[payment_key] = paysession = PaySession(
payment_hash=payment_hash,
payment_secret=payment_secret,
initial_trampoline_fee_level=self.INITIAL_TRAMPOLINE_FEE_LEVEL,
invoice_features=invoice_features,
r_tags=r_tags,
min_cltv_expiry=min_cltv_expiry,
amount_to_pay=amount_to_pay,
invoice_pubkey=node_pubkey,
)
self.logs[payment_hash.hex()] = log = [] # TODO incl payment_secret in key (re trampoline forwarding)
# when encountering trampoline forwarding difficulties in the legacy case, we
# sometimes need to fall back to a single trampoline forwarder, at the expense
# of privacy
use_two_trampolines = True
trampoline_fee_level = self.INITIAL_TRAMPOLINE_FEE_LEVEL
failed_trampoline_routes = []
start_time = time.time()
amount_inflight = 0 # what we sent in htlcs (that receiver gets, without fees)
nhtlcs_inflight = 0
while True:
amount_to_send = amount_to_pay - amount_inflight
if amount_to_send > 0:
if (amount_to_send := paysession.get_outstanding_amount_to_send()) > 0:
# 1. create a set of routes for remaining amount.
# note: path-finding runs in a separate thread so that we don't block the asyncio loop
# graph updates might occur during the computation
routes = self.create_routes_for_payment(
paysession=paysession,
amount_msat=amount_to_send,
final_total_msat=amount_to_pay,
invoice_pubkey=node_pubkey,
min_cltv_expiry=min_cltv_expiry,
r_tags=r_tags,
invoice_features=invoice_features,
full_path=full_path,
payment_hash=payment_hash,
payment_secret=payment_secret,
trampoline_fee_level=trampoline_fee_level,
failed_trampoline_routes=failed_trampoline_routes,
use_two_trampolines=use_two_trampolines,
fwd_trampoline_onion=fwd_trampoline_onion,
channels=channels,
)
# 2. send htlcs
async for sent_htlc_info, cltv_delta, trampoline_onion in routes:
nhtlcs_inflight += 1
amount_inflight += sent_htlc_info.amount_receiver_msat
if amount_inflight > amount_to_pay: # safety belts
raise Exception(f"amount_inflight={amount_inflight} > amount_to_pay={amount_to_pay}")
sent_htlc_info = sent_htlc_info._replace(trampoline_fee_level=trampoline_fee_level)
sent_htlc_info = paysession.add_new_htlc(sent_htlc_info)
await self.pay_to_route(
sent_htlc_info=sent_htlc_info,
payment_hash=payment_hash,
@ -1339,12 +1429,7 @@ class LNWallet(LNWorker):
# (e.g. attempt counter)
util.trigger_callback('invoice_status', self.wallet, payment_hash.hex(), PR_INFLIGHT)
# 3. await a queue
self.logger.info(f"paysession for RHASH {payment_hash.hex()} waiting... {amount_inflight=}. {nhtlcs_inflight=}")
htlc_log = await self.sent_htlcs_q[payment_key].get() # TODO maybe wait a bit, more failures might come
amount_inflight -= htlc_log.amount_msat
nhtlcs_inflight -= 1
if amount_inflight < 0 or nhtlcs_inflight < 0:
raise Exception(f"{amount_inflight=}, {nhtlcs_inflight=}. both should be >= 0 !")
htlc_log = await paysession.wait_for_one_htlc_to_resolve() # TODO maybe wait a bit, more failures might come
log.append(htlc_log)
if htlc_log.success:
if self.network.path_finder:
@ -1357,7 +1442,7 @@ class LNWallet(LNWorker):
self.network.path_finder.update_inflight_htlcs(htlc_log.route, add_htlcs=False)
return
# htlc failed
if (attempts is not None and len(log) >= attempts) or (attempts is None and time.time() - start_time > self.PAYMENT_TIMEOUT):
if (attempts is not None and len(log) >= attempts) or (attempts is None and time.time() - paysession.start_time > self.PAYMENT_TIMEOUT):
raise PaymentFailure('Giving up after %d attempts'%len(log))
# if we get a tmp channel failure, it might work to split the amount and try more routes
# if we get a channel update, we might retry the same route and amount
@ -1373,45 +1458,8 @@ class LNWallet(LNWorker):
raise PaymentFailure(failure_msg.code_name())
# trampoline
if self.uses_trampoline():
def maybe_raise_trampoline_fee(htlc_log):
nonlocal trampoline_fee_level
nonlocal failed_trampoline_routes
if htlc_log.trampoline_fee_level == trampoline_fee_level:
trampoline_fee_level += 1
failed_trampoline_routes = []
self.logger.info(f'raising trampoline fee level {trampoline_fee_level}')
else:
self.logger.info(f'NOT raising trampoline fee level, already at {trampoline_fee_level}')
# FIXME The trampoline nodes in the path are chosen randomly.
# Some of the errors might depend on how we have chosen them.
# Having more attempts is currently useful in part because of the randomness,
# instead we should give feedback to create_routes_for_payment.
# Sometimes the trampoline node fails to send a payment and returns
# TEMPORARY_CHANNEL_FAILURE, while it succeeds with a higher trampoline fee.
if code in (
OnionFailureCode.TRAMPOLINE_FEE_INSUFFICIENT,
OnionFailureCode.TRAMPOLINE_EXPIRY_TOO_SOON,
OnionFailureCode.TEMPORARY_CHANNEL_FAILURE):
# TODO: parse the node policy here (not returned by eclair yet)
# TODO: erring node is always the first trampoline even if second
# trampoline demands more fees, we can't influence this
maybe_raise_trampoline_fee(htlc_log)
continue
elif use_two_trampolines:
use_two_trampolines = False
elif code in (
OnionFailureCode.UNKNOWN_NEXT_PEER,
OnionFailureCode.TEMPORARY_NODE_FAILURE):
trampoline_route = htlc_log.route
r = [hop.end_node.hex() for hop in trampoline_route]
self.logger.info(f'failed trampoline route: {r}')
if r not in failed_trampoline_routes:
failed_trampoline_routes.append(r)
else:
pass # maybe the route was reused between different MPP parts
continue
else:
raise PaymentFailure(failure_msg.code_name())
paysession.handle_failed_trampoline_htlc(
htlc_log=htlc_log, failure_msg=failure_msg)
else:
self.handle_error_code_from_failed_htlc(
route=route, sender_idx=sender_idx, failure_msg=failure_msg, amount=htlc_log.amount_msat)
@ -1654,18 +1702,9 @@ class LNWallet(LNWorker):
async def create_routes_for_payment(
self, *,
paysession: PaySession,
amount_msat: int, # part of payment amount we want routes for now
final_total_msat: int, # total payment amount final receiver will get
invoice_pubkey,
min_cltv_expiry,
r_tags,
invoice_features: int,
payment_hash: bytes,
payment_secret: bytes,
trampoline_fee_level: int,
failed_trampoline_routes: Iterable[Sequence[str]],
use_two_trampolines: bool,
fwd_trampoline_onion=None,
fwd_trampoline_onion: OnionPacket = None,
full_path: LNPaymentPath = None,
channels: Optional[Sequence[Channel]] = None,
) -> AsyncGenerator[Tuple[SentHtlcInfo, int, Optional[OnionPacket]], None]:
@ -1675,7 +1714,6 @@ class LNWallet(LNWorker):
We first try to conduct the payment over a single channel. If that fails
and mpp is supported by the receiver, we will split the payment."""
invoice_features = LnFeatures(invoice_features)
trampoline_features = LnFeatures.VAR_ONION_OPT
local_height = self.network.get_local_height()
if channels:
@ -1688,15 +1726,15 @@ class LNWallet(LNWorker):
random.shuffle(my_active_channels)
split_configurations = self.suggest_splits(
amount_msat=amount_msat,
final_total_msat=final_total_msat,
final_total_msat=paysession.amount_to_pay,
my_active_channels=my_active_channels,
invoice_features=invoice_features,
r_tags=r_tags,
invoice_features=paysession.invoice_features,
r_tags=paysession.r_tags,
)
for sc in split_configurations:
is_multichan_mpp = len(sc.config.items()) > 1
is_mpp = sum(len(x) for x in list(sc.config.values())) > 1
if is_mpp and not invoice_features.supports(LnFeatures.BASIC_MPP_OPT):
if is_mpp and not paysession.invoice_features.supports(LnFeatures.BASIC_MPP_OPT):
continue
if not is_mpp and self.config.TEST_FORCE_MPP:
continue
@ -1715,33 +1753,33 @@ class LNWallet(LNWorker):
# for each trampoline forwarder, construct mpp trampoline
for trampoline_node_id, trampoline_parts in per_trampoline_channel_amounts.items():
per_trampoline_amount = sum([x[1] for x in trampoline_parts])
if trampoline_node_id == invoice_pubkey:
if trampoline_node_id == paysession.invoice_pubkey:
trampoline_route = None
trampoline_onion = None
per_trampoline_secret = payment_secret
per_trampoline_secret = paysession.payment_secret
per_trampoline_amount_with_fees = amount_msat
per_trampoline_cltv_delta = min_cltv_expiry
per_trampoline_cltv_delta = paysession.min_cltv_expiry
per_trampoline_fees = 0
else:
trampoline_route, trampoline_onion, per_trampoline_amount_with_fees, per_trampoline_cltv_delta = create_trampoline_route_and_onion(
amount_msat=per_trampoline_amount,
total_msat=final_total_msat,
min_cltv_expiry=min_cltv_expiry,
total_msat=paysession.amount_to_pay,
min_cltv_expiry=paysession.min_cltv_expiry,
my_pubkey=self.node_keypair.pubkey,
invoice_pubkey=invoice_pubkey,
invoice_features=invoice_features,
invoice_pubkey=paysession.invoice_pubkey,
invoice_features=paysession.invoice_features,
node_id=trampoline_node_id,
r_tags=r_tags,
payment_hash=payment_hash,
payment_secret=payment_secret,
r_tags=paysession.r_tags,
payment_hash=paysession.payment_hash,
payment_secret=paysession.payment_secret,
local_height=local_height,
trampoline_fee_level=trampoline_fee_level,
use_two_trampolines=use_two_trampolines,
failed_routes=failed_trampoline_routes)
trampoline_fee_level=paysession.trampoline_fee_level,
use_two_trampolines=paysession.use_two_trampolines,
failed_routes=paysession.failed_trampoline_routes)
# node_features is only used to determine is_tlv
per_trampoline_secret = os.urandom(32)
per_trampoline_fees = per_trampoline_amount_with_fees - per_trampoline_amount
self.logger.info(f'created route with trampoline fee level={trampoline_fee_level}')
self.logger.info(f'created route with trampoline fee level={paysession.trampoline_fee_level}')
self.logger.info(f'trampoline hops: {[hop.end_node.hex() for hop in trampoline_route]}')
self.logger.info(f'per trampoline fees: {per_trampoline_fees}')
for chan_id, part_amount_msat in trampoline_parts:
@ -1764,7 +1802,7 @@ class LNWallet(LNWorker):
self.logger.info(f'adding route {part_amount_msat} {delta_fee} {margin}')
shi = SentHtlcInfo(
route=route,
payment_secret_orig=payment_secret,
payment_secret_orig=paysession.payment_secret,
payment_secret_bucket=per_trampoline_secret,
amount_msat=part_amount_msat_with_fees,
bucket_msat=per_trampoline_amount_with_fees,
@ -1786,25 +1824,25 @@ class LNWallet(LNWorker):
partial(
self.create_route_for_payment,
amount_msat=part_amount_msat,
invoice_pubkey=invoice_pubkey,
min_cltv_expiry=min_cltv_expiry,
r_tags=r_tags,
invoice_features=invoice_features,
invoice_pubkey=paysession.invoice_pubkey,
min_cltv_expiry=paysession.min_cltv_expiry,
r_tags=paysession.r_tags,
invoice_features=paysession.invoice_features,
my_sending_channels=[channel] if is_multichan_mpp else my_active_channels,
full_path=full_path,
)
)
shi = SentHtlcInfo(
route=route,
payment_secret_orig=payment_secret,
payment_secret_bucket=payment_secret,
payment_secret_orig=paysession.payment_secret,
payment_secret_bucket=paysession.payment_secret,
amount_msat=part_amount_msat,
bucket_msat=final_total_msat,
bucket_msat=paysession.amount_to_pay,
amount_receiver_msat=part_amount_msat,
trampoline_fee_level=None,
trampoline_route=None,
)
routes.append((shi, min_cltv_expiry, fwd_trampoline_onion))
routes.append((shi, paysession.min_cltv_expiry, fwd_trampoline_onion))
except NoPathFound:
continue
for route in routes:
@ -2159,7 +2197,9 @@ class LNWallet(LNWorker):
q = None
if shi := self.sent_htlcs_info.get((payment_hash, chan.short_channel_id, htlc_id)):
payment_key = payment_hash + shi.payment_secret_orig
q = self.sent_htlcs_q.get(payment_key)
paysession = self._paysessions.get(payment_key)
if paysession:
q = paysession.sent_htlcs_q
if q:
htlc_log = HtlcLog(
success=True,
@ -2185,7 +2225,9 @@ class LNWallet(LNWorker):
q = None
if shi := self.sent_htlcs_info.get((payment_hash, chan.short_channel_id, htlc_id)):
payment_okey = payment_hash + shi.payment_secret_orig
q = self.sent_htlcs_q.get(payment_okey)
paysession = self._paysessions.get(payment_okey)
if paysession:
q = paysession.sent_htlcs_q
if q:
# detect if it is part of a bucket
# if yes, wait until the bucket completely failed

35
electrum/tests/test_lnpeer.py

@ -31,7 +31,7 @@ from electrum.lnutil import PaymentFailure, LnFeatures, HTLCOwner
from electrum.lnchannel import ChannelState, PeerState, Channel
from electrum.lnrouter import LNPathFinder, PathEdge, LNPathInconsistent
from electrum.channel_db import ChannelDB
from electrum.lnworker import LNWallet, NoPathFound, SentHtlcInfo
from electrum.lnworker import LNWallet, NoPathFound, SentHtlcInfo, PaySession
from electrum.lnmsg import encode_msg, decode_msg
from electrum import lnmsg
from electrum.logging import console_stderr_handler, Logger
@ -168,7 +168,7 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
self.enable_htlc_settle = True
self.enable_htlc_forwarding = True
self.received_mpp_htlcs = dict()
self.sent_htlcs_q = defaultdict(asyncio.Queue)
self._paysessions = dict()
self.sent_htlcs_info = dict()
self.sent_buckets = defaultdict(set)
self.final_onion_forwardings = set()
@ -232,18 +232,22 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
await self.channel_db.stopped_event.wait()
async def create_routes_from_invoice(self, amount_msat: int, decoded_invoice: LnAddr, *, full_path=None):
return [r async for r in self.create_routes_for_payment(
amount_msat=amount_msat,
final_total_msat=amount_msat,
invoice_pubkey=decoded_invoice.pubkey.serialize(),
min_cltv_expiry=decoded_invoice.get_min_final_cltv_expiry(),
r_tags=decoded_invoice.get_routing_info('r'),
invoice_features=decoded_invoice.get_features(),
trampoline_fee_level=0,
failed_trampoline_routes=[],
use_two_trampolines=False,
paysession = PaySession(
payment_hash=decoded_invoice.paymenthash,
payment_secret=decoded_invoice.payment_secret,
initial_trampoline_fee_level=0,
invoice_features=decoded_invoice.get_features(),
r_tags=decoded_invoice.get_routing_info('r'),
min_cltv_expiry=decoded_invoice.get_min_final_cltv_expiry(),
amount_to_pay=amount_msat,
invoice_pubkey=decoded_invoice.pubkey.serialize(),
)
paysession.use_two_trampolines = False
payment_key = decoded_invoice.paymenthash + decoded_invoice.payment_secret
self._paysessions[payment_key] = paysession
return [r async for r in self.create_routes_for_payment(
amount_msat=amount_msat,
paysession=paysession,
full_path=full_path)]
get_payments = LNWallet.get_payments
@ -854,9 +858,6 @@ class TestPeer(ElectrumTestCase):
_maybe_send_commitment2 = p2.maybe_send_commitment
lnaddr2, pay_req2 = self.prepare_invoice(w2)
lnaddr1, pay_req1 = self.prepare_invoice(w1)
# create the htlc queues now (side-effecting defaultdict)
q1 = w1.sent_htlcs_q[lnaddr2.paymenthash + lnaddr2.payment_secret]
q2 = w2.sent_htlcs_q[lnaddr1.paymenthash + lnaddr1.payment_secret]
# alice sends htlc BUT NOT COMMITMENT_SIGNED
p1.maybe_send_commitment = lambda x: None
route1 = (await w1.create_routes_from_invoice(lnaddr2.get_amount_msat(), decoded_invoice=lnaddr2))[0][0].route
@ -901,9 +902,9 @@ class TestPeer(ElectrumTestCase):
p1.maybe_send_commitment(alice_channel)
p2.maybe_send_commitment(bob_channel)
htlc_log1 = await q1.get()
htlc_log1 = await w1._paysessions[lnaddr2.paymenthash + lnaddr2.payment_secret].sent_htlcs_q.get()
self.assertTrue(htlc_log1.success)
htlc_log2 = await q2.get()
htlc_log2 = await w2._paysessions[lnaddr1.paymenthash + lnaddr1.payment_secret].sent_htlcs_q.get()
self.assertTrue(htlc_log2.success)
raise PaymentDone()

Loading…
Cancel
Save