From 9f84fa9580eeb63547b9bb84d3e09a0dc0606484 Mon Sep 17 00:00:00 2001 From: bitromortac Date: Fri, 19 Feb 2021 08:34:03 +0100 Subject: [PATCH] lnonion: add total_msat to onion payment data --- electrum/lnonion.py | 10 ++++++++-- electrum/lnpeer.py | 11 ++++++++--- electrum/lnworker.py | 8 +++++--- electrum/tests/test_lnpeer.py | 5 ++++- 4 files changed, 25 insertions(+), 9 deletions(-) diff --git a/electrum/lnonion.py b/electrum/lnonion.py index 359e76e3d..5fc226de1 100644 --- a/electrum/lnonion.py +++ b/electrum/lnonion.py @@ -261,7 +261,7 @@ def new_onion_packet(payment_path_pubkeys: Sequence[bytes], session_key: bytes, hmac=next_hmac) -def calc_hops_data_for_payment(route: 'LNPaymentRoute', amount_msat: int, +def calc_hops_data_for_payment(route: 'LNPaymentRoute', amount_msat: int, total_msat: int, final_cltv: int, *, payment_secret: bytes = None) \ -> Tuple[List[OnionHopsDataSingle], int, int]: """Returns the hops_data to be used for constructing an onion packet, @@ -277,8 +277,14 @@ def calc_hops_data_for_payment(route: 'LNPaymentRoute', amount_msat: int, "amt_to_forward": {"amt_to_forward": amt}, "outgoing_cltv_value": {"outgoing_cltv_value": cltv}, } + # for multipart payments we need to tell the reciever about the total and + # partial amounts if payment_secret is not None: - hop_payload["payment_data"] = {"payment_secret": payment_secret, "total_msat": amt} + hop_payload["payment_data"] = { + "payment_secret": payment_secret, + "total_msat": total_msat, + "amount_msat": amt + } hops_data = [OnionHopsDataSingle(is_tlv_payload=route[-1].has_feature_varonion(), payload=hop_payload)] # payloads, backwards from last hop (but excluding the first edge): diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index a86321302..22d24c008 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -1196,7 +1196,7 @@ class Peer(Logger): self.send_message("commitment_signed", channel_id=chan.channel_id, signature=sig_64, num_htlcs=len(htlc_sigs), htlc_signature=b"".join(htlc_sigs)) def pay(self, *, route: 'LNPaymentRoute', chan: Channel, amount_msat: int, - payment_hash: bytes, min_final_cltv_expiry: int, + total_msat: int, payment_hash: bytes, min_final_cltv_expiry: int, payment_secret: bytes = None, fwd_trampoline_onion=None) -> UpdateAddHtlc: assert amount_msat > 0, "amount_msat is not greater zero" assert len(route) > 0 @@ -1206,8 +1206,13 @@ class Peer(Logger): route[0].node_features |= self.features local_height = self.network.get_local_height() final_cltv = local_height + min_final_cltv_expiry - hops_data, amount_msat, cltv = calc_hops_data_for_payment(route, amount_msat, final_cltv, - payment_secret=payment_secret) + hops_data, amount_msat, cltv = calc_hops_data_for_payment( + route, + amount_msat, + total_msat, + final_cltv, + payment_secret=payment_secret + ) self.logger.info(f"lnpeer.pay len(route)={len(route)}") for i in range(len(route)): self.logger.info(f" {i}: edge={route[i].short_channel_id} hop_data={hops_data[i]!r}") diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 120c83026..003cdbe67 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -1090,7 +1090,7 @@ class LNWallet(LNWorker): routes = [(route, amount_to_send)] # 2. send htlcs for route, amount_msat in routes: - await self.pay_to_route(route, amount_msat, payment_hash, payment_secret, min_cltv_expiry, trampoline_onion) + await self.pay_to_route(route, amount_msat, amount_to_pay, payment_hash, payment_secret, min_cltv_expiry, trampoline_onion) amount_inflight += amount_msat util.trigger_callback('invoice_status', self.wallet, payment_hash.hex()) # 3. await a queue @@ -1106,8 +1106,9 @@ class LNWallet(LNWorker): # if we get a channel update, we might retry the same route and amount self.handle_error_code_from_failed_htlc(htlc_log) - - async def pay_to_route(self, route: LNPaymentRoute, amount_msat:int, payment_hash:bytes, payment_secret:bytes, min_cltv_expiry:int, trampoline_onion:bytes =None): + async def pay_to_route(self, route: LNPaymentRoute, amount_msat: int, + total_msat: int, payment_hash: bytes, payment_secret: bytes, + min_cltv_expiry: int, trampoline_onion: bytes=None): # send a single htlc short_channel_id = route[0].short_channel_id chan = self.get_channel_by_short_id(short_channel_id) @@ -1119,6 +1120,7 @@ class LNWallet(LNWorker): route=route, chan=chan, amount_msat=amount_msat, + total_msat=total_msat, payment_hash=payment_hash, min_final_cltv_expiry=min_cltv_expiry, payment_secret=payment_secret, diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index 668056813..1a091080d 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -497,6 +497,7 @@ class TestPeer(ElectrumTestCase): route=route1, chan=alice_channel, amount_msat=lnaddr2.get_amount_msat(), + total_msat=lnaddr2.get_amount_msat(), payment_hash=lnaddr2.paymenthash, min_final_cltv_expiry=lnaddr2.get_min_final_cltv_expiry(), payment_secret=lnaddr2.payment_secret, @@ -509,6 +510,7 @@ class TestPeer(ElectrumTestCase): route=route2, chan=bob_channel, amount_msat=lnaddr1.get_amount_msat(), + total_msat=lnaddr1.get_amount_msat(), payment_hash=lnaddr1.paymenthash, min_final_cltv_expiry=lnaddr1.get_min_final_cltv_expiry(), payment_secret=lnaddr1.payment_secret, @@ -663,6 +665,7 @@ class TestPeer(ElectrumTestCase): htlc = p1.pay(route=route, chan=alice_channel, amount_msat=lnaddr.get_amount_msat(), + total_msat=lnaddr.get_amount_msat(), payment_hash=lnaddr.paymenthash, min_final_cltv_expiry=lnaddr.get_min_final_cltv_expiry(), payment_secret=lnaddr.payment_secret) @@ -771,7 +774,7 @@ class TestPeer(ElectrumTestCase): min_cltv_expiry = lnaddr.get_min_final_cltv_expiry() payment_hash = lnaddr.paymenthash payment_secret = lnaddr.payment_secret - pay = w1.pay_to_route(route, amount_msat, payment_hash, payment_secret, min_cltv_expiry) + pay = w1.pay_to_route(route, amount_msat, amount_msat, payment_hash, payment_secret, min_cltv_expiry) await asyncio.gather(pay, p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch()) with self.assertRaises(PaymentFailure): run(f())