Browse Source

lnworker: move sent_buckets into PaySession

master
SomberNight 2 years ago
parent
commit
98bda60c01
No known key found for this signature in database
GPG Key ID: B33B5F232C6271E9
  1. 70
      electrum/lnworker.py
  2. 14
      electrum/tests/test_lnpeer.py

70
electrum/lnworker.py

@ -667,6 +667,7 @@ class PaySession(Logger):
min_cltv_expiry: int, min_cltv_expiry: int,
amount_to_pay: int, # total payment amount final receiver will get amount_to_pay: int, # total payment amount final receiver will get
invoice_pubkey: bytes, invoice_pubkey: bytes,
uses_trampoline: bool, # whether sender uses trampoline or gossip
): ):
assert payment_hash assert payment_hash
assert payment_secret assert payment_secret
@ -684,9 +685,11 @@ class PaySession(Logger):
self.sent_htlcs_q = asyncio.Queue() # type: asyncio.Queue[HtlcLog] self.sent_htlcs_q = asyncio.Queue() # type: asyncio.Queue[HtlcLog]
self.start_time = time.time() self.start_time = time.time()
self.uses_trampoline = uses_trampoline
self.trampoline_fee_level = initial_trampoline_fee_level self.trampoline_fee_level = initial_trampoline_fee_level
self.failed_trampoline_routes = [] self.failed_trampoline_routes = []
self.use_two_trampolines = True self.use_two_trampolines = True
self._sent_buckets = dict() # psecret_bucket -> (amount_sent, amount_failed)
self._amount_inflight = 0 # what we sent in htlcs (that receiver gets, without fees) self._amount_inflight = 0 # what we sent in htlcs (that receiver gets, without fees)
self._nhtlcs_inflight = 0 self._nhtlcs_inflight = 0
@ -742,13 +745,36 @@ class PaySession(Logger):
raise Exception(f"amount_inflight={self._amount_inflight}, nhtlcs_inflight={self._nhtlcs_inflight}. both should be >= 0 !") raise Exception(f"amount_inflight={self._amount_inflight}, nhtlcs_inflight={self._nhtlcs_inflight}. both should be >= 0 !")
return htlc_log return htlc_log
def add_new_htlc(self, sent_htlc_info: SentHtlcInfo) -> SentHtlcInfo: def add_new_htlc(self, sent_htlc_info: SentHtlcInfo):
self._nhtlcs_inflight += 1 self._nhtlcs_inflight += 1
self._amount_inflight += sent_htlc_info.amount_receiver_msat self._amount_inflight += sent_htlc_info.amount_receiver_msat
if self._amount_inflight > self.amount_to_pay: # safety belts if self._amount_inflight > self.amount_to_pay: # safety belts
raise Exception(f"amount_inflight={self._amount_inflight} > amount_to_pay={self.amount_to_pay}") raise Exception(f"amount_inflight={self._amount_inflight} > amount_to_pay={self.amount_to_pay}")
sent_htlc_info = sent_htlc_info._replace(trampoline_fee_level=self.trampoline_fee_level) shi = sent_htlc_info
return sent_htlc_info bkey = shi.payment_secret_bucket
# if we sent MPP to a trampoline, add item to sent_buckets
if self.uses_trampoline and shi.amount_msat != shi.bucket_msat:
if bkey not in self._sent_buckets:
self._sent_buckets[bkey] = (0, 0)
amount_sent, amount_failed = self._sent_buckets[bkey]
amount_sent += shi.amount_receiver_msat
self._sent_buckets[bkey] = amount_sent, amount_failed
def on_htlc_fail_get_fail_amt_to_propagate(self, sent_htlc_info: SentHtlcInfo) -> Optional[int]:
shi = sent_htlc_info
# check sent_buckets if we use trampoline
bkey = shi.payment_secret_bucket
if self.uses_trampoline and bkey in self._sent_buckets:
amount_sent, amount_failed = self._sent_buckets[bkey]
amount_failed += shi.amount_receiver_msat
self._sent_buckets[bkey] = amount_sent, amount_failed
if amount_sent != amount_failed:
self.logger.info('bucket still active...')
return None
self.logger.info('bucket failed')
return amount_sent
# not using trampoline buckets
return shi.amount_receiver_msat
def get_outstanding_amount_to_send(self) -> int: def get_outstanding_amount_to_send(self) -> int:
return self.amount_to_pay - self._amount_inflight return self.amount_to_pay - self._amount_inflight
@ -795,7 +821,6 @@ class LNWallet(LNWorker):
self._paysessions = dict() # type: Dict[bytes, PaySession] self._paysessions = dict() # type: Dict[bytes, PaySession]
self.sent_htlcs_info = dict() # type: Dict[SentHtlcKey, SentHtlcInfo] self.sent_htlcs_info = dict() # type: Dict[SentHtlcKey, SentHtlcInfo]
self.sent_buckets = dict() # payment_key -> (amount_sent, amount_failed) # TODO move into PaySession
self.received_mpp_htlcs = dict() # type: Dict[bytes, ReceivedMPPStatus] # payment_key -> ReceivedMPPStatus self.received_mpp_htlcs = dict() # type: Dict[bytes, ReceivedMPPStatus] # payment_key -> ReceivedMPPStatus
# detect inflight payments # detect inflight payments
@ -1397,6 +1422,7 @@ class LNWallet(LNWorker):
min_cltv_expiry=min_cltv_expiry, min_cltv_expiry=min_cltv_expiry,
amount_to_pay=amount_to_pay, amount_to_pay=amount_to_pay,
invoice_pubkey=node_pubkey, invoice_pubkey=node_pubkey,
uses_trampoline=self.uses_trampoline(),
) )
self.logs[payment_hash.hex()] = log = [] # TODO incl payment_secret in key (re trampoline forwarding) self.logs[payment_hash.hex()] = log = [] # TODO incl payment_secret in key (re trampoline forwarding)
@ -1417,10 +1443,9 @@ class LNWallet(LNWorker):
) )
# 2. send htlcs # 2. send htlcs
async for sent_htlc_info, cltv_delta, trampoline_onion in routes: async for sent_htlc_info, cltv_delta, trampoline_onion in routes:
sent_htlc_info = paysession.add_new_htlc(sent_htlc_info)
await self.pay_to_route( await self.pay_to_route(
paysession=paysession,
sent_htlc_info=sent_htlc_info, sent_htlc_info=sent_htlc_info,
payment_hash=payment_hash,
min_cltv_expiry=cltv_delta, min_cltv_expiry=cltv_delta,
trampoline_onion=trampoline_onion, trampoline_onion=trampoline_onion,
) )
@ -1466,8 +1491,8 @@ class LNWallet(LNWorker):
async def pay_to_route( async def pay_to_route(
self, *, self, *,
paysession: PaySession,
sent_htlc_info: SentHtlcInfo, sent_htlc_info: SentHtlcInfo,
payment_hash: bytes,
min_cltv_expiry: int, min_cltv_expiry: int,
trampoline_onion: bytes = None, trampoline_onion: bytes = None,
) -> None: ) -> None:
@ -1486,21 +1511,14 @@ class LNWallet(LNWorker):
chan=chan, chan=chan,
amount_msat=shi.amount_msat, amount_msat=shi.amount_msat,
total_msat=shi.bucket_msat, total_msat=shi.bucket_msat,
payment_hash=payment_hash, payment_hash=paysession.payment_hash,
min_final_cltv_expiry=min_cltv_expiry, min_final_cltv_expiry=min_cltv_expiry,
payment_secret=shi.payment_secret_bucket, payment_secret=shi.payment_secret_bucket,
trampoline_onion=trampoline_onion) trampoline_onion=trampoline_onion)
key = (payment_hash, short_channel_id, htlc.htlc_id) key = (paysession.payment_hash, short_channel_id, htlc.htlc_id)
self.sent_htlcs_info[key] = shi self.sent_htlcs_info[key] = shi
payment_key = payment_hash + shi.payment_secret_bucket paysession.add_new_htlc(shi)
# if we sent MPP to a trampoline, add item to sent_buckets
if self.uses_trampoline() and shi.amount_msat != shi.bucket_msat:
if payment_key not in self.sent_buckets:
self.sent_buckets[payment_key] = (0, 0)
amount_sent, amount_failed = self.sent_buckets[payment_key]
amount_sent += shi.amount_receiver_msat
self.sent_buckets[payment_key] = amount_sent, amount_failed
if self.network.path_finder: if self.network.path_finder:
# add inflight htlcs to liquidity hints # add inflight htlcs to liquidity hints
self.network.path_finder.update_inflight_htlcs(shi.route, add_htlcs=True) self.network.path_finder.update_inflight_htlcs(shi.route, add_htlcs=True)
@ -1807,7 +1825,7 @@ class LNWallet(LNWorker):
amount_msat=part_amount_msat_with_fees, amount_msat=part_amount_msat_with_fees,
bucket_msat=per_trampoline_amount_with_fees, bucket_msat=per_trampoline_amount_with_fees,
amount_receiver_msat=part_amount_msat, amount_receiver_msat=part_amount_msat,
trampoline_fee_level=None, trampoline_fee_level=paysession.trampoline_fee_level,
trampoline_route=trampoline_route, trampoline_route=trampoline_route,
) )
routes.append((shi, per_trampoline_cltv_delta, trampoline_onion)) routes.append((shi, per_trampoline_cltv_delta, trampoline_onion))
@ -2232,7 +2250,6 @@ class LNWallet(LNWorker):
# detect if it is part of a bucket # detect if it is part of a bucket
# if yes, wait until the bucket completely failed # if yes, wait until the bucket completely failed
shi = self.sent_htlcs_info[(payment_hash, chan.short_channel_id, htlc_id)] shi = self.sent_htlcs_info[(payment_hash, chan.short_channel_id, htlc_id)]
amount_receiver_msat = shi.amount_receiver_msat
route = shi.route route = shi.route
if error_bytes: if error_bytes:
# TODO "decode_onion_error" might raise, catch and maybe blacklist/penalise someone? # TODO "decode_onion_error" might raise, catch and maybe blacklist/penalise someone?
@ -2247,18 +2264,9 @@ class LNWallet(LNWorker):
sender_idx = None sender_idx = None
self.logger.info(f"htlc_failed {failure_message}") self.logger.info(f"htlc_failed {failure_message}")
# check sent_buckets if we use trampoline amount_receiver_msat = paysession.on_htlc_fail_get_fail_amt_to_propagate(shi)
payment_bkey = payment_hash + shi.payment_secret_bucket if amount_receiver_msat is None:
if self.uses_trampoline() and payment_bkey in self.sent_buckets: return
amount_sent, amount_failed = self.sent_buckets[payment_bkey]
amount_failed += amount_receiver_msat
self.sent_buckets[payment_bkey] = amount_sent, amount_failed
if amount_sent != amount_failed:
self.logger.info('bucket still active...')
return
self.logger.info('bucket failed')
amount_receiver_msat = amount_sent
if shi.trampoline_route: if shi.trampoline_route:
route = shi.trampoline_route route = shi.trampoline_route
htlc_log = HtlcLog( htlc_log = HtlcLog(

14
electrum/tests/test_lnpeer.py

@ -241,6 +241,7 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
min_cltv_expiry=decoded_invoice.get_min_final_cltv_expiry(), min_cltv_expiry=decoded_invoice.get_min_final_cltv_expiry(),
amount_to_pay=amount_msat, amount_to_pay=amount_msat,
invoice_pubkey=decoded_invoice.pubkey.serialize(), invoice_pubkey=decoded_invoice.pubkey.serialize(),
uses_trampoline=False,
) )
paysession.use_two_trampolines = False paysession.use_two_trampolines = False
payment_key = decoded_invoice.paymenthash + decoded_invoice.payment_secret payment_key = decoded_invoice.paymenthash + decoded_invoice.payment_secret
@ -861,6 +862,7 @@ class TestPeer(ElectrumTestCase):
# alice sends htlc BUT NOT COMMITMENT_SIGNED # alice sends htlc BUT NOT COMMITMENT_SIGNED
p1.maybe_send_commitment = lambda x: None p1.maybe_send_commitment = lambda x: None
route1 = (await w1.create_routes_from_invoice(lnaddr2.get_amount_msat(), decoded_invoice=lnaddr2))[0][0].route route1 = (await w1.create_routes_from_invoice(lnaddr2.get_amount_msat(), decoded_invoice=lnaddr2))[0][0].route
paysession1 = w1._paysessions[lnaddr2.paymenthash + lnaddr2.payment_secret]
shi1 = SentHtlcInfo( shi1 = SentHtlcInfo(
route=route1, route=route1,
payment_secret_orig=lnaddr2.payment_secret, payment_secret_orig=lnaddr2.payment_secret,
@ -873,13 +875,14 @@ class TestPeer(ElectrumTestCase):
) )
await w1.pay_to_route( await w1.pay_to_route(
sent_htlc_info=shi1, sent_htlc_info=shi1,
payment_hash=lnaddr2.paymenthash, paysession=paysession1,
min_cltv_expiry=lnaddr2.get_min_final_cltv_expiry(), min_cltv_expiry=lnaddr2.get_min_final_cltv_expiry(),
) )
p1.maybe_send_commitment = _maybe_send_commitment1 p1.maybe_send_commitment = _maybe_send_commitment1
# bob sends htlc BUT NOT COMMITMENT_SIGNED # bob sends htlc BUT NOT COMMITMENT_SIGNED
p2.maybe_send_commitment = lambda x: None p2.maybe_send_commitment = lambda x: None
route2 = (await w2.create_routes_from_invoice(lnaddr1.get_amount_msat(), decoded_invoice=lnaddr1))[0][0].route route2 = (await w2.create_routes_from_invoice(lnaddr1.get_amount_msat(), decoded_invoice=lnaddr1))[0][0].route
paysession2 = w2._paysessions[lnaddr1.paymenthash + lnaddr1.payment_secret]
shi2 = SentHtlcInfo( shi2 = SentHtlcInfo(
route=route2, route=route2,
payment_secret_orig=lnaddr1.payment_secret, payment_secret_orig=lnaddr1.payment_secret,
@ -892,7 +895,7 @@ class TestPeer(ElectrumTestCase):
) )
await w2.pay_to_route( await w2.pay_to_route(
sent_htlc_info=shi2, sent_htlc_info=shi2,
payment_hash=lnaddr1.paymenthash, paysession=paysession2,
min_cltv_expiry=lnaddr1.get_min_final_cltv_expiry(), min_cltv_expiry=lnaddr1.get_min_final_cltv_expiry(),
) )
p2.maybe_send_commitment = _maybe_send_commitment2 p2.maybe_send_commitment = _maybe_send_commitment2
@ -902,9 +905,9 @@ class TestPeer(ElectrumTestCase):
p1.maybe_send_commitment(alice_channel) p1.maybe_send_commitment(alice_channel)
p2.maybe_send_commitment(bob_channel) p2.maybe_send_commitment(bob_channel)
htlc_log1 = await w1._paysessions[lnaddr2.paymenthash + lnaddr2.payment_secret].sent_htlcs_q.get() htlc_log1 = await paysession1.sent_htlcs_q.get()
self.assertTrue(htlc_log1.success) self.assertTrue(htlc_log1.success)
htlc_log2 = await w2._paysessions[lnaddr1.paymenthash + lnaddr1.payment_secret].sent_htlcs_q.get() htlc_log2 = await paysession2.sent_htlcs_q.get()
self.assertTrue(htlc_log2.success) self.assertTrue(htlc_log2.success)
raise PaymentDone() raise PaymentDone()
@ -1603,9 +1606,10 @@ class TestPeer(ElectrumTestCase):
trampoline_fee_level=None, trampoline_fee_level=None,
trampoline_route=None, trampoline_route=None,
) )
paysession = w1._paysessions[lnaddr.paymenthash + lnaddr.payment_secret]
pay = w1.pay_to_route( pay = w1.pay_to_route(
sent_htlc_info=shi, sent_htlc_info=shi,
payment_hash=lnaddr.paymenthash, paysession=paysession,
min_cltv_expiry=lnaddr.get_min_final_cltv_expiry(), min_cltv_expiry=lnaddr.get_min_final_cltv_expiry(),
) )
await asyncio.gather(pay, p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch()) await asyncio.gather(pay, p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())

Loading…
Cancel
Save