diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 9f3eae07a..b08f8b3df 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -667,6 +667,7 @@ class PaySession(Logger): min_cltv_expiry: int, amount_to_pay: int, # total payment amount final receiver will get invoice_pubkey: bytes, + uses_trampoline: bool, # whether sender uses trampoline or gossip ): assert payment_hash assert payment_secret @@ -684,9 +685,11 @@ class PaySession(Logger): self.sent_htlcs_q = asyncio.Queue() # type: asyncio.Queue[HtlcLog] self.start_time = time.time() + self.uses_trampoline = uses_trampoline self.trampoline_fee_level = initial_trampoline_fee_level self.failed_trampoline_routes = [] self.use_two_trampolines = True + self._sent_buckets = dict() # psecret_bucket -> (amount_sent, amount_failed) self._amount_inflight = 0 # what we sent in htlcs (that receiver gets, without fees) self._nhtlcs_inflight = 0 @@ -742,13 +745,36 @@ class PaySession(Logger): raise Exception(f"amount_inflight={self._amount_inflight}, nhtlcs_inflight={self._nhtlcs_inflight}. both should be >= 0 !") return htlc_log - def add_new_htlc(self, sent_htlc_info: SentHtlcInfo) -> SentHtlcInfo: + def add_new_htlc(self, sent_htlc_info: SentHtlcInfo): self._nhtlcs_inflight += 1 self._amount_inflight += sent_htlc_info.amount_receiver_msat if self._amount_inflight > self.amount_to_pay: # safety belts raise Exception(f"amount_inflight={self._amount_inflight} > amount_to_pay={self.amount_to_pay}") - sent_htlc_info = sent_htlc_info._replace(trampoline_fee_level=self.trampoline_fee_level) - return sent_htlc_info + shi = sent_htlc_info + bkey = shi.payment_secret_bucket + # if we sent MPP to a trampoline, add item to sent_buckets + if self.uses_trampoline and shi.amount_msat != shi.bucket_msat: + if bkey not in self._sent_buckets: + self._sent_buckets[bkey] = (0, 0) + amount_sent, amount_failed = self._sent_buckets[bkey] + amount_sent += shi.amount_receiver_msat + self._sent_buckets[bkey] = amount_sent, amount_failed + + def on_htlc_fail_get_fail_amt_to_propagate(self, sent_htlc_info: SentHtlcInfo) -> Optional[int]: + shi = sent_htlc_info + # check sent_buckets if we use trampoline + bkey = shi.payment_secret_bucket + if self.uses_trampoline and bkey in self._sent_buckets: + amount_sent, amount_failed = self._sent_buckets[bkey] + amount_failed += shi.amount_receiver_msat + self._sent_buckets[bkey] = amount_sent, amount_failed + if amount_sent != amount_failed: + self.logger.info('bucket still active...') + return None + self.logger.info('bucket failed') + return amount_sent + # not using trampoline buckets + return shi.amount_receiver_msat def get_outstanding_amount_to_send(self) -> int: return self.amount_to_pay - self._amount_inflight @@ -795,7 +821,6 @@ class LNWallet(LNWorker): self._paysessions = dict() # type: Dict[bytes, PaySession] self.sent_htlcs_info = dict() # type: Dict[SentHtlcKey, SentHtlcInfo] - self.sent_buckets = dict() # payment_key -> (amount_sent, amount_failed) # TODO move into PaySession self.received_mpp_htlcs = dict() # type: Dict[bytes, ReceivedMPPStatus] # payment_key -> ReceivedMPPStatus # detect inflight payments @@ -1397,6 +1422,7 @@ class LNWallet(LNWorker): min_cltv_expiry=min_cltv_expiry, amount_to_pay=amount_to_pay, invoice_pubkey=node_pubkey, + uses_trampoline=self.uses_trampoline(), ) self.logs[payment_hash.hex()] = log = [] # TODO incl payment_secret in key (re trampoline forwarding) @@ -1417,10 +1443,9 @@ class LNWallet(LNWorker): ) # 2. send htlcs async for sent_htlc_info, cltv_delta, trampoline_onion in routes: - sent_htlc_info = paysession.add_new_htlc(sent_htlc_info) await self.pay_to_route( + paysession=paysession, sent_htlc_info=sent_htlc_info, - payment_hash=payment_hash, min_cltv_expiry=cltv_delta, trampoline_onion=trampoline_onion, ) @@ -1466,8 +1491,8 @@ class LNWallet(LNWorker): async def pay_to_route( self, *, + paysession: PaySession, sent_htlc_info: SentHtlcInfo, - payment_hash: bytes, min_cltv_expiry: int, trampoline_onion: bytes = None, ) -> None: @@ -1486,21 +1511,14 @@ class LNWallet(LNWorker): chan=chan, amount_msat=shi.amount_msat, total_msat=shi.bucket_msat, - payment_hash=payment_hash, + payment_hash=paysession.payment_hash, min_final_cltv_expiry=min_cltv_expiry, payment_secret=shi.payment_secret_bucket, trampoline_onion=trampoline_onion) - key = (payment_hash, short_channel_id, htlc.htlc_id) + key = (paysession.payment_hash, short_channel_id, htlc.htlc_id) self.sent_htlcs_info[key] = shi - payment_key = payment_hash + shi.payment_secret_bucket - # if we sent MPP to a trampoline, add item to sent_buckets - if self.uses_trampoline() and shi.amount_msat != shi.bucket_msat: - if payment_key not in self.sent_buckets: - self.sent_buckets[payment_key] = (0, 0) - amount_sent, amount_failed = self.sent_buckets[payment_key] - amount_sent += shi.amount_receiver_msat - self.sent_buckets[payment_key] = amount_sent, amount_failed + paysession.add_new_htlc(shi) if self.network.path_finder: # add inflight htlcs to liquidity hints self.network.path_finder.update_inflight_htlcs(shi.route, add_htlcs=True) @@ -1807,7 +1825,7 @@ class LNWallet(LNWorker): amount_msat=part_amount_msat_with_fees, bucket_msat=per_trampoline_amount_with_fees, amount_receiver_msat=part_amount_msat, - trampoline_fee_level=None, + trampoline_fee_level=paysession.trampoline_fee_level, trampoline_route=trampoline_route, ) routes.append((shi, per_trampoline_cltv_delta, trampoline_onion)) @@ -2232,7 +2250,6 @@ class LNWallet(LNWorker): # detect if it is part of a bucket # if yes, wait until the bucket completely failed shi = self.sent_htlcs_info[(payment_hash, chan.short_channel_id, htlc_id)] - amount_receiver_msat = shi.amount_receiver_msat route = shi.route if error_bytes: # TODO "decode_onion_error" might raise, catch and maybe blacklist/penalise someone? @@ -2247,18 +2264,9 @@ class LNWallet(LNWorker): sender_idx = None self.logger.info(f"htlc_failed {failure_message}") - # check sent_buckets if we use trampoline - payment_bkey = payment_hash + shi.payment_secret_bucket - if self.uses_trampoline() and payment_bkey in self.sent_buckets: - amount_sent, amount_failed = self.sent_buckets[payment_bkey] - amount_failed += amount_receiver_msat - self.sent_buckets[payment_bkey] = amount_sent, amount_failed - if amount_sent != amount_failed: - self.logger.info('bucket still active...') - return - self.logger.info('bucket failed') - amount_receiver_msat = amount_sent - + amount_receiver_msat = paysession.on_htlc_fail_get_fail_amt_to_propagate(shi) + if amount_receiver_msat is None: + return if shi.trampoline_route: route = shi.trampoline_route htlc_log = HtlcLog( diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index ee6f970a7..c9ed07fbd 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -241,6 +241,7 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): min_cltv_expiry=decoded_invoice.get_min_final_cltv_expiry(), amount_to_pay=amount_msat, invoice_pubkey=decoded_invoice.pubkey.serialize(), + uses_trampoline=False, ) paysession.use_two_trampolines = False payment_key = decoded_invoice.paymenthash + decoded_invoice.payment_secret @@ -861,6 +862,7 @@ class TestPeer(ElectrumTestCase): # alice sends htlc BUT NOT COMMITMENT_SIGNED p1.maybe_send_commitment = lambda x: None route1 = (await w1.create_routes_from_invoice(lnaddr2.get_amount_msat(), decoded_invoice=lnaddr2))[0][0].route + paysession1 = w1._paysessions[lnaddr2.paymenthash + lnaddr2.payment_secret] shi1 = SentHtlcInfo( route=route1, payment_secret_orig=lnaddr2.payment_secret, @@ -873,13 +875,14 @@ class TestPeer(ElectrumTestCase): ) await w1.pay_to_route( sent_htlc_info=shi1, - payment_hash=lnaddr2.paymenthash, + paysession=paysession1, min_cltv_expiry=lnaddr2.get_min_final_cltv_expiry(), ) p1.maybe_send_commitment = _maybe_send_commitment1 # bob sends htlc BUT NOT COMMITMENT_SIGNED p2.maybe_send_commitment = lambda x: None route2 = (await w2.create_routes_from_invoice(lnaddr1.get_amount_msat(), decoded_invoice=lnaddr1))[0][0].route + paysession2 = w2._paysessions[lnaddr1.paymenthash + lnaddr1.payment_secret] shi2 = SentHtlcInfo( route=route2, payment_secret_orig=lnaddr1.payment_secret, @@ -892,7 +895,7 @@ class TestPeer(ElectrumTestCase): ) await w2.pay_to_route( sent_htlc_info=shi2, - payment_hash=lnaddr1.paymenthash, + paysession=paysession2, min_cltv_expiry=lnaddr1.get_min_final_cltv_expiry(), ) p2.maybe_send_commitment = _maybe_send_commitment2 @@ -902,9 +905,9 @@ class TestPeer(ElectrumTestCase): p1.maybe_send_commitment(alice_channel) p2.maybe_send_commitment(bob_channel) - htlc_log1 = await w1._paysessions[lnaddr2.paymenthash + lnaddr2.payment_secret].sent_htlcs_q.get() + htlc_log1 = await paysession1.sent_htlcs_q.get() self.assertTrue(htlc_log1.success) - htlc_log2 = await w2._paysessions[lnaddr1.paymenthash + lnaddr1.payment_secret].sent_htlcs_q.get() + htlc_log2 = await paysession2.sent_htlcs_q.get() self.assertTrue(htlc_log2.success) raise PaymentDone() @@ -1603,9 +1606,10 @@ class TestPeer(ElectrumTestCase): trampoline_fee_level=None, trampoline_route=None, ) + paysession = w1._paysessions[lnaddr.paymenthash + lnaddr.payment_secret] pay = w1.pay_to_route( sent_htlc_info=shi, - payment_hash=lnaddr.paymenthash, + paysession=paysession, min_cltv_expiry=lnaddr.get_min_final_cltv_expiry(), ) await asyncio.gather(pay, p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())