Browse Source

lnonion: add total_msat to onion payment data

master
bitromortac 5 years ago
parent
commit
9f84fa9580
No known key found for this signature in database
GPG Key ID: 1965063FC13BEBE2
  1. 10
      electrum/lnonion.py
  2. 11
      electrum/lnpeer.py
  3. 8
      electrum/lnworker.py
  4. 5
      electrum/tests/test_lnpeer.py

10
electrum/lnonion.py

@ -261,7 +261,7 @@ def new_onion_packet(payment_path_pubkeys: Sequence[bytes], session_key: bytes,
hmac=next_hmac) hmac=next_hmac)
def calc_hops_data_for_payment(route: 'LNPaymentRoute', amount_msat: int, def calc_hops_data_for_payment(route: 'LNPaymentRoute', amount_msat: int, total_msat: int,
final_cltv: int, *, payment_secret: bytes = None) \ final_cltv: int, *, payment_secret: bytes = None) \
-> Tuple[List[OnionHopsDataSingle], int, int]: -> Tuple[List[OnionHopsDataSingle], int, int]:
"""Returns the hops_data to be used for constructing an onion packet, """Returns the hops_data to be used for constructing an onion packet,
@ -277,8 +277,14 @@ def calc_hops_data_for_payment(route: 'LNPaymentRoute', amount_msat: int,
"amt_to_forward": {"amt_to_forward": amt}, "amt_to_forward": {"amt_to_forward": amt},
"outgoing_cltv_value": {"outgoing_cltv_value": cltv}, "outgoing_cltv_value": {"outgoing_cltv_value": cltv},
} }
# for multipart payments we need to tell the reciever about the total and
# partial amounts
if payment_secret is not None: if payment_secret is not None:
hop_payload["payment_data"] = {"payment_secret": payment_secret, "total_msat": amt} hop_payload["payment_data"] = {
"payment_secret": payment_secret,
"total_msat": total_msat,
"amount_msat": amt
}
hops_data = [OnionHopsDataSingle(is_tlv_payload=route[-1].has_feature_varonion(), hops_data = [OnionHopsDataSingle(is_tlv_payload=route[-1].has_feature_varonion(),
payload=hop_payload)] payload=hop_payload)]
# payloads, backwards from last hop (but excluding the first edge): # payloads, backwards from last hop (but excluding the first edge):

11
electrum/lnpeer.py

@ -1196,7 +1196,7 @@ class Peer(Logger):
self.send_message("commitment_signed", channel_id=chan.channel_id, signature=sig_64, num_htlcs=len(htlc_sigs), htlc_signature=b"".join(htlc_sigs)) self.send_message("commitment_signed", channel_id=chan.channel_id, signature=sig_64, num_htlcs=len(htlc_sigs), htlc_signature=b"".join(htlc_sigs))
def pay(self, *, route: 'LNPaymentRoute', chan: Channel, amount_msat: int, def pay(self, *, route: 'LNPaymentRoute', chan: Channel, amount_msat: int,
payment_hash: bytes, min_final_cltv_expiry: int, total_msat: int, payment_hash: bytes, min_final_cltv_expiry: int,
payment_secret: bytes = None, fwd_trampoline_onion=None) -> UpdateAddHtlc: payment_secret: bytes = None, fwd_trampoline_onion=None) -> UpdateAddHtlc:
assert amount_msat > 0, "amount_msat is not greater zero" assert amount_msat > 0, "amount_msat is not greater zero"
assert len(route) > 0 assert len(route) > 0
@ -1206,8 +1206,13 @@ class Peer(Logger):
route[0].node_features |= self.features route[0].node_features |= self.features
local_height = self.network.get_local_height() local_height = self.network.get_local_height()
final_cltv = local_height + min_final_cltv_expiry final_cltv = local_height + min_final_cltv_expiry
hops_data, amount_msat, cltv = calc_hops_data_for_payment(route, amount_msat, final_cltv, hops_data, amount_msat, cltv = calc_hops_data_for_payment(
payment_secret=payment_secret) route,
amount_msat,
total_msat,
final_cltv,
payment_secret=payment_secret
)
self.logger.info(f"lnpeer.pay len(route)={len(route)}") self.logger.info(f"lnpeer.pay len(route)={len(route)}")
for i in range(len(route)): for i in range(len(route)):
self.logger.info(f" {i}: edge={route[i].short_channel_id} hop_data={hops_data[i]!r}") self.logger.info(f" {i}: edge={route[i].short_channel_id} hop_data={hops_data[i]!r}")

8
electrum/lnworker.py

@ -1090,7 +1090,7 @@ class LNWallet(LNWorker):
routes = [(route, amount_to_send)] routes = [(route, amount_to_send)]
# 2. send htlcs # 2. send htlcs
for route, amount_msat in routes: for route, amount_msat in routes:
await self.pay_to_route(route, amount_msat, payment_hash, payment_secret, min_cltv_expiry, trampoline_onion) await self.pay_to_route(route, amount_msat, amount_to_pay, payment_hash, payment_secret, min_cltv_expiry, trampoline_onion)
amount_inflight += amount_msat amount_inflight += amount_msat
util.trigger_callback('invoice_status', self.wallet, payment_hash.hex()) util.trigger_callback('invoice_status', self.wallet, payment_hash.hex())
# 3. await a queue # 3. await a queue
@ -1106,8 +1106,9 @@ class LNWallet(LNWorker):
# if we get a channel update, we might retry the same route and amount # if we get a channel update, we might retry the same route and amount
self.handle_error_code_from_failed_htlc(htlc_log) self.handle_error_code_from_failed_htlc(htlc_log)
async def pay_to_route(self, route: LNPaymentRoute, amount_msat: int,
async def pay_to_route(self, route: LNPaymentRoute, amount_msat:int, payment_hash:bytes, payment_secret:bytes, min_cltv_expiry:int, trampoline_onion:bytes =None): total_msat: int, payment_hash: bytes, payment_secret: bytes,
min_cltv_expiry: int, trampoline_onion: bytes=None):
# send a single htlc # send a single htlc
short_channel_id = route[0].short_channel_id short_channel_id = route[0].short_channel_id
chan = self.get_channel_by_short_id(short_channel_id) chan = self.get_channel_by_short_id(short_channel_id)
@ -1119,6 +1120,7 @@ class LNWallet(LNWorker):
route=route, route=route,
chan=chan, chan=chan,
amount_msat=amount_msat, amount_msat=amount_msat,
total_msat=total_msat,
payment_hash=payment_hash, payment_hash=payment_hash,
min_final_cltv_expiry=min_cltv_expiry, min_final_cltv_expiry=min_cltv_expiry,
payment_secret=payment_secret, payment_secret=payment_secret,

5
electrum/tests/test_lnpeer.py

@ -497,6 +497,7 @@ class TestPeer(ElectrumTestCase):
route=route1, route=route1,
chan=alice_channel, chan=alice_channel,
amount_msat=lnaddr2.get_amount_msat(), amount_msat=lnaddr2.get_amount_msat(),
total_msat=lnaddr2.get_amount_msat(),
payment_hash=lnaddr2.paymenthash, payment_hash=lnaddr2.paymenthash,
min_final_cltv_expiry=lnaddr2.get_min_final_cltv_expiry(), min_final_cltv_expiry=lnaddr2.get_min_final_cltv_expiry(),
payment_secret=lnaddr2.payment_secret, payment_secret=lnaddr2.payment_secret,
@ -509,6 +510,7 @@ class TestPeer(ElectrumTestCase):
route=route2, route=route2,
chan=bob_channel, chan=bob_channel,
amount_msat=lnaddr1.get_amount_msat(), amount_msat=lnaddr1.get_amount_msat(),
total_msat=lnaddr1.get_amount_msat(),
payment_hash=lnaddr1.paymenthash, payment_hash=lnaddr1.paymenthash,
min_final_cltv_expiry=lnaddr1.get_min_final_cltv_expiry(), min_final_cltv_expiry=lnaddr1.get_min_final_cltv_expiry(),
payment_secret=lnaddr1.payment_secret, payment_secret=lnaddr1.payment_secret,
@ -663,6 +665,7 @@ class TestPeer(ElectrumTestCase):
htlc = p1.pay(route=route, htlc = p1.pay(route=route,
chan=alice_channel, chan=alice_channel,
amount_msat=lnaddr.get_amount_msat(), amount_msat=lnaddr.get_amount_msat(),
total_msat=lnaddr.get_amount_msat(),
payment_hash=lnaddr.paymenthash, payment_hash=lnaddr.paymenthash,
min_final_cltv_expiry=lnaddr.get_min_final_cltv_expiry(), min_final_cltv_expiry=lnaddr.get_min_final_cltv_expiry(),
payment_secret=lnaddr.payment_secret) payment_secret=lnaddr.payment_secret)
@ -771,7 +774,7 @@ class TestPeer(ElectrumTestCase):
min_cltv_expiry = lnaddr.get_min_final_cltv_expiry() min_cltv_expiry = lnaddr.get_min_final_cltv_expiry()
payment_hash = lnaddr.paymenthash payment_hash = lnaddr.paymenthash
payment_secret = lnaddr.payment_secret payment_secret = lnaddr.payment_secret
pay = w1.pay_to_route(route, amount_msat, payment_hash, payment_secret, min_cltv_expiry) pay = w1.pay_to_route(route, amount_msat, amount_msat, payment_hash, payment_secret, min_cltv_expiry)
await asyncio.gather(pay, p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch()) await asyncio.gather(pay, p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
with self.assertRaises(PaymentFailure): with self.assertRaises(PaymentFailure):
run(f()) run(f())

Loading…
Cancel
Save