Browse Source

lnworker: move sent_buckets into PaySession

master
SomberNight 2 years ago
parent
commit
98bda60c01
No known key found for this signature in database
GPG Key ID: B33B5F232C6271E9
  1. 70
      electrum/lnworker.py
  2. 14
      electrum/tests/test_lnpeer.py

70
electrum/lnworker.py

@ -667,6 +667,7 @@ class PaySession(Logger):
min_cltv_expiry: int,
amount_to_pay: int, # total payment amount final receiver will get
invoice_pubkey: bytes,
uses_trampoline: bool, # whether sender uses trampoline or gossip
):
assert payment_hash
assert payment_secret
@ -684,9 +685,11 @@ class PaySession(Logger):
self.sent_htlcs_q = asyncio.Queue() # type: asyncio.Queue[HtlcLog]
self.start_time = time.time()
self.uses_trampoline = uses_trampoline
self.trampoline_fee_level = initial_trampoline_fee_level
self.failed_trampoline_routes = []
self.use_two_trampolines = True
self._sent_buckets = dict() # psecret_bucket -> (amount_sent, amount_failed)
self._amount_inflight = 0 # what we sent in htlcs (that receiver gets, without fees)
self._nhtlcs_inflight = 0
@ -742,13 +745,36 @@ class PaySession(Logger):
raise Exception(f"amount_inflight={self._amount_inflight}, nhtlcs_inflight={self._nhtlcs_inflight}. both should be >= 0 !")
return htlc_log
def add_new_htlc(self, sent_htlc_info: SentHtlcInfo) -> SentHtlcInfo:
def add_new_htlc(self, sent_htlc_info: SentHtlcInfo):
self._nhtlcs_inflight += 1
self._amount_inflight += sent_htlc_info.amount_receiver_msat
if self._amount_inflight > self.amount_to_pay: # safety belts
raise Exception(f"amount_inflight={self._amount_inflight} > amount_to_pay={self.amount_to_pay}")
sent_htlc_info = sent_htlc_info._replace(trampoline_fee_level=self.trampoline_fee_level)
return sent_htlc_info
shi = sent_htlc_info
bkey = shi.payment_secret_bucket
# if we sent MPP to a trampoline, add item to sent_buckets
if self.uses_trampoline and shi.amount_msat != shi.bucket_msat:
if bkey not in self._sent_buckets:
self._sent_buckets[bkey] = (0, 0)
amount_sent, amount_failed = self._sent_buckets[bkey]
amount_sent += shi.amount_receiver_msat
self._sent_buckets[bkey] = amount_sent, amount_failed
def on_htlc_fail_get_fail_amt_to_propagate(self, sent_htlc_info: SentHtlcInfo) -> Optional[int]:
shi = sent_htlc_info
# check sent_buckets if we use trampoline
bkey = shi.payment_secret_bucket
if self.uses_trampoline and bkey in self._sent_buckets:
amount_sent, amount_failed = self._sent_buckets[bkey]
amount_failed += shi.amount_receiver_msat
self._sent_buckets[bkey] = amount_sent, amount_failed
if amount_sent != amount_failed:
self.logger.info('bucket still active...')
return None
self.logger.info('bucket failed')
return amount_sent
# not using trampoline buckets
return shi.amount_receiver_msat
def get_outstanding_amount_to_send(self) -> int:
return self.amount_to_pay - self._amount_inflight
@ -795,7 +821,6 @@ class LNWallet(LNWorker):
self._paysessions = dict() # type: Dict[bytes, PaySession]
self.sent_htlcs_info = dict() # type: Dict[SentHtlcKey, SentHtlcInfo]
self.sent_buckets = dict() # payment_key -> (amount_sent, amount_failed) # TODO move into PaySession
self.received_mpp_htlcs = dict() # type: Dict[bytes, ReceivedMPPStatus] # payment_key -> ReceivedMPPStatus
# detect inflight payments
@ -1397,6 +1422,7 @@ class LNWallet(LNWorker):
min_cltv_expiry=min_cltv_expiry,
amount_to_pay=amount_to_pay,
invoice_pubkey=node_pubkey,
uses_trampoline=self.uses_trampoline(),
)
self.logs[payment_hash.hex()] = log = [] # TODO incl payment_secret in key (re trampoline forwarding)
@ -1417,10 +1443,9 @@ class LNWallet(LNWorker):
)
# 2. send htlcs
async for sent_htlc_info, cltv_delta, trampoline_onion in routes:
sent_htlc_info = paysession.add_new_htlc(sent_htlc_info)
await self.pay_to_route(
paysession=paysession,
sent_htlc_info=sent_htlc_info,
payment_hash=payment_hash,
min_cltv_expiry=cltv_delta,
trampoline_onion=trampoline_onion,
)
@ -1466,8 +1491,8 @@ class LNWallet(LNWorker):
async def pay_to_route(
self, *,
paysession: PaySession,
sent_htlc_info: SentHtlcInfo,
payment_hash: bytes,
min_cltv_expiry: int,
trampoline_onion: bytes = None,
) -> None:
@ -1486,21 +1511,14 @@ class LNWallet(LNWorker):
chan=chan,
amount_msat=shi.amount_msat,
total_msat=shi.bucket_msat,
payment_hash=payment_hash,
payment_hash=paysession.payment_hash,
min_final_cltv_expiry=min_cltv_expiry,
payment_secret=shi.payment_secret_bucket,
trampoline_onion=trampoline_onion)
key = (payment_hash, short_channel_id, htlc.htlc_id)
key = (paysession.payment_hash, short_channel_id, htlc.htlc_id)
self.sent_htlcs_info[key] = shi
payment_key = payment_hash + shi.payment_secret_bucket
# if we sent MPP to a trampoline, add item to sent_buckets
if self.uses_trampoline() and shi.amount_msat != shi.bucket_msat:
if payment_key not in self.sent_buckets:
self.sent_buckets[payment_key] = (0, 0)
amount_sent, amount_failed = self.sent_buckets[payment_key]
amount_sent += shi.amount_receiver_msat
self.sent_buckets[payment_key] = amount_sent, amount_failed
paysession.add_new_htlc(shi)
if self.network.path_finder:
# add inflight htlcs to liquidity hints
self.network.path_finder.update_inflight_htlcs(shi.route, add_htlcs=True)
@ -1807,7 +1825,7 @@ class LNWallet(LNWorker):
amount_msat=part_amount_msat_with_fees,
bucket_msat=per_trampoline_amount_with_fees,
amount_receiver_msat=part_amount_msat,
trampoline_fee_level=None,
trampoline_fee_level=paysession.trampoline_fee_level,
trampoline_route=trampoline_route,
)
routes.append((shi, per_trampoline_cltv_delta, trampoline_onion))
@ -2232,7 +2250,6 @@ class LNWallet(LNWorker):
# detect if it is part of a bucket
# if yes, wait until the bucket completely failed
shi = self.sent_htlcs_info[(payment_hash, chan.short_channel_id, htlc_id)]
amount_receiver_msat = shi.amount_receiver_msat
route = shi.route
if error_bytes:
# TODO "decode_onion_error" might raise, catch and maybe blacklist/penalise someone?
@ -2247,18 +2264,9 @@ class LNWallet(LNWorker):
sender_idx = None
self.logger.info(f"htlc_failed {failure_message}")
# check sent_buckets if we use trampoline
payment_bkey = payment_hash + shi.payment_secret_bucket
if self.uses_trampoline() and payment_bkey in self.sent_buckets:
amount_sent, amount_failed = self.sent_buckets[payment_bkey]
amount_failed += amount_receiver_msat
self.sent_buckets[payment_bkey] = amount_sent, amount_failed
if amount_sent != amount_failed:
self.logger.info('bucket still active...')
return
self.logger.info('bucket failed')
amount_receiver_msat = amount_sent
amount_receiver_msat = paysession.on_htlc_fail_get_fail_amt_to_propagate(shi)
if amount_receiver_msat is None:
return
if shi.trampoline_route:
route = shi.trampoline_route
htlc_log = HtlcLog(

14
electrum/tests/test_lnpeer.py

@ -241,6 +241,7 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
min_cltv_expiry=decoded_invoice.get_min_final_cltv_expiry(),
amount_to_pay=amount_msat,
invoice_pubkey=decoded_invoice.pubkey.serialize(),
uses_trampoline=False,
)
paysession.use_two_trampolines = False
payment_key = decoded_invoice.paymenthash + decoded_invoice.payment_secret
@ -861,6 +862,7 @@ class TestPeer(ElectrumTestCase):
# alice sends htlc BUT NOT COMMITMENT_SIGNED
p1.maybe_send_commitment = lambda x: None
route1 = (await w1.create_routes_from_invoice(lnaddr2.get_amount_msat(), decoded_invoice=lnaddr2))[0][0].route
paysession1 = w1._paysessions[lnaddr2.paymenthash + lnaddr2.payment_secret]
shi1 = SentHtlcInfo(
route=route1,
payment_secret_orig=lnaddr2.payment_secret,
@ -873,13 +875,14 @@ class TestPeer(ElectrumTestCase):
)
await w1.pay_to_route(
sent_htlc_info=shi1,
payment_hash=lnaddr2.paymenthash,
paysession=paysession1,
min_cltv_expiry=lnaddr2.get_min_final_cltv_expiry(),
)
p1.maybe_send_commitment = _maybe_send_commitment1
# bob sends htlc BUT NOT COMMITMENT_SIGNED
p2.maybe_send_commitment = lambda x: None
route2 = (await w2.create_routes_from_invoice(lnaddr1.get_amount_msat(), decoded_invoice=lnaddr1))[0][0].route
paysession2 = w2._paysessions[lnaddr1.paymenthash + lnaddr1.payment_secret]
shi2 = SentHtlcInfo(
route=route2,
payment_secret_orig=lnaddr1.payment_secret,
@ -892,7 +895,7 @@ class TestPeer(ElectrumTestCase):
)
await w2.pay_to_route(
sent_htlc_info=shi2,
payment_hash=lnaddr1.paymenthash,
paysession=paysession2,
min_cltv_expiry=lnaddr1.get_min_final_cltv_expiry(),
)
p2.maybe_send_commitment = _maybe_send_commitment2
@ -902,9 +905,9 @@ class TestPeer(ElectrumTestCase):
p1.maybe_send_commitment(alice_channel)
p2.maybe_send_commitment(bob_channel)
htlc_log1 = await w1._paysessions[lnaddr2.paymenthash + lnaddr2.payment_secret].sent_htlcs_q.get()
htlc_log1 = await paysession1.sent_htlcs_q.get()
self.assertTrue(htlc_log1.success)
htlc_log2 = await w2._paysessions[lnaddr1.paymenthash + lnaddr1.payment_secret].sent_htlcs_q.get()
htlc_log2 = await paysession2.sent_htlcs_q.get()
self.assertTrue(htlc_log2.success)
raise PaymentDone()
@ -1603,9 +1606,10 @@ class TestPeer(ElectrumTestCase):
trampoline_fee_level=None,
trampoline_route=None,
)
paysession = w1._paysessions[lnaddr.paymenthash + lnaddr.payment_secret]
pay = w1.pay_to_route(
sent_htlc_info=shi,
payment_hash=lnaddr.paymenthash,
paysession=paysession,
min_cltv_expiry=lnaddr.get_min_final_cltv_expiry(),
)
await asyncio.gather(pay, p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())

Loading…
Cancel
Save