Browse Source

HTLC switch:

- fulfills/fails/forwards HTLCs
- onion_packets are saved when update_add_htlc is received
master
ThomasV 6 years ago
parent
commit
da67fda92a
  1. 7
      electrum/lnchannel.py
  2. 1
      electrum/lnhtlc.py
  3. 44
      electrum/lnpeer.py
  4. 42
      electrum/lnworker.py
  5. 12
      electrum/tests/test_lnpeer.py

7
electrum/lnchannel.py

@ -408,7 +408,7 @@ class Channel(Logger):
self.logger.info("add_htlc") self.logger.info("add_htlc")
return htlc return htlc
def receive_htlc(self, htlc: UpdateAddHtlc) -> UpdateAddHtlc: def receive_htlc(self, htlc: UpdateAddHtlc, onion_packet:bytes = None) -> UpdateAddHtlc:
""" """
ReceiveHTLC adds an HTLC to the state machine's remote update log. This ReceiveHTLC adds an HTLC to the state machine's remote update log. This
method should be called in response to receiving a new HTLC from the remote method should be called in response to receiving a new HTLC from the remote
@ -427,6 +427,11 @@ class Channel(Logger):
f' HTLC amount: {htlc.amount_msat}') f' HTLC amount: {htlc.amount_msat}')
with self.db_lock: with self.db_lock:
self.hm.recv_htlc(htlc) self.hm.recv_htlc(htlc)
local_ctn = self.get_latest_ctn(LOCAL)
remote_ctn = self.get_latest_ctn(REMOTE)
if onion_packet:
self.hm.log['unfulfilled_htlcs'][htlc.htlc_id] = local_ctn, remote_ctn, onion_packet.hex()
self.logger.info("receive_htlc") self.logger.info("receive_htlc")
return htlc return htlc

1
electrum/lnhtlc.py

@ -25,6 +25,7 @@ class HTLCManager:
log[LOCAL] = deepcopy(initial) log[LOCAL] = deepcopy(initial)
log[REMOTE] = deepcopy(initial) log[REMOTE] = deepcopy(initial)
log['unacked_local_updates2'] = {} log['unacked_local_updates2'] = {}
log['unfulfilled_htlcs'] = {} # htlc_id -> onion_packet
# maybe bootstrap fee_updates if initial_feerate was provided # maybe bootstrap fee_updates if initial_feerate was provided
if initial_feerate is not None: if initial_feerate is not None:

44
electrum/lnpeer.py

@ -1131,45 +1131,25 @@ class Peer(Logger):
self.logger.info(f"on_update_add_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}") self.logger.info(f"on_update_add_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}")
cltv_expiry = int.from_bytes(payload["cltv_expiry"], 'big') cltv_expiry = int.from_bytes(payload["cltv_expiry"], 'big')
amount_msat_htlc = int.from_bytes(payload["amount_msat"], 'big') amount_msat_htlc = int.from_bytes(payload["amount_msat"], 'big')
onion_packet = OnionPacket.from_bytes(payload["onion_routing_packet"]) onion_packet = payload["onion_routing_packet"]
processed_onion = process_onion_packet(onion_packet, associated_data=payment_hash, our_onion_private_key=self.privkey)
if chan.get_state() != channel_states.OPEN: if chan.get_state() != channel_states.OPEN:
raise RemoteMisbehaving(f"received update_add_htlc while chan.get_state() != OPEN. state was {chan.get_state()}") raise RemoteMisbehaving(f"received update_add_htlc while chan.get_state() != OPEN. state was {chan.get_state()}")
if cltv_expiry > bitcoin.NLOCKTIME_BLOCKHEIGHT_MAX: if cltv_expiry > bitcoin.NLOCKTIME_BLOCKHEIGHT_MAX:
asyncio.ensure_future(self.lnworker.force_close_channel(channel_id)) asyncio.ensure_future(self.lnworker.force_close_channel(channel_id))
raise RemoteMisbehaving(f"received update_add_htlc with cltv_expiry > BLOCKHEIGHT_MAX. value was {cltv_expiry}") raise RemoteMisbehaving(f"received update_add_htlc with cltv_expiry > BLOCKHEIGHT_MAX. value was {cltv_expiry}")
# add htlc # add htlc
htlc = UpdateAddHtlc(amount_msat=amount_msat_htlc, htlc = UpdateAddHtlc(
payment_hash=payment_hash, amount_msat=amount_msat_htlc,
cltv_expiry=cltv_expiry, payment_hash=payment_hash,
timestamp=int(time.time()), cltv_expiry=cltv_expiry,
htlc_id=htlc_id) timestamp=int(time.time()),
htlc = chan.receive_htlc(htlc) htlc_id=htlc_id)
# TODO: fulfilling/failing/forwarding of htlcs should be robust to going offline. chan.receive_htlc(htlc, onion_packet)
# instead of storing state implicitly in coroutines, we could decouple it from receiving the htlc.
# maybe persist the required details, and have a long-running task that makes these decisions.
local_ctn = chan.get_latest_ctn(LOCAL)
remote_ctn = chan.get_latest_ctn(REMOTE)
if processed_onion.are_we_final:
asyncio.ensure_future(self._maybe_fulfill_htlc(chan=chan,
htlc=htlc,
local_ctn=local_ctn,
remote_ctn=remote_ctn,
onion_packet=onion_packet,
processed_onion=processed_onion))
else:
asyncio.ensure_future(self._maybe_forward_htlc(chan=chan,
htlc=htlc,
local_ctn=local_ctn,
remote_ctn=remote_ctn,
onion_packet=onion_packet,
processed_onion=processed_onion))
@log_exceptions @log_exceptions
async def _maybe_forward_htlc(self, chan: Channel, htlc: UpdateAddHtlc, *, local_ctn: int, remote_ctn: int, async def _maybe_forward_htlc(self, chan: Channel, htlc: UpdateAddHtlc, *,
onion_packet: OnionPacket, processed_onion: ProcessedOnionPacket): onion_packet: OnionPacket, processed_onion: ProcessedOnionPacket):
await self.await_local(chan, local_ctn)
await self.await_remote(chan, remote_ctn)
# Forward HTLC # Forward HTLC
# FIXME: this is not robust to us going offline before payment is fulfilled # FIXME: this is not robust to us going offline before payment is fulfilled
# FIXME: there are critical safety checks MISSING here # FIXME: there are critical safety checks MISSING here
@ -1250,10 +1230,8 @@ class Peer(Logger):
await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason) await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason)
@log_exceptions @log_exceptions
async def _maybe_fulfill_htlc(self, chan: Channel, htlc: UpdateAddHtlc, *, local_ctn: int, remote_ctn: int, async def _maybe_fulfill_htlc(self, chan: Channel, htlc: UpdateAddHtlc, *,
onion_packet: OnionPacket, processed_onion: ProcessedOnionPacket): onion_packet: OnionPacket, processed_onion: ProcessedOnionPacket):
await self.await_local(chan, local_ctn)
await self.await_remote(chan, remote_ctn)
try: try:
info = self.lnworker.get_payment_info(htlc.payment_hash) info = self.lnworker.get_payment_info(htlc.payment_hash)
preimage = self.lnworker.get_preimage(htlc.payment_hash) preimage = self.lnworker.get_preimage(htlc.payment_hash)

42
electrum/lnworker.py

@ -55,7 +55,7 @@ from .lnutil import (Outpoint, LNPeerAddr,
ShortChannelID, PaymentAttemptLog, PaymentAttemptFailureDetails) ShortChannelID, PaymentAttemptLog, PaymentAttemptFailureDetails)
from .lnutil import ln_dummy_address, ln_compare_features from .lnutil import ln_dummy_address, ln_compare_features
from .transaction import PartialTxOutput, PartialTransaction, PartialTxInput from .transaction import PartialTxOutput, PartialTransaction, PartialTxInput
from .lnonion import OnionFailureCode from .lnonion import OnionFailureCode, process_onion_packet, OnionPacket
from .lnmsg import decode_msg from .lnmsg import decode_msg
from .i18n import _ from .i18n import _
from .lnrouter import RouteEdge, LNPaymentRoute, is_route_sane_to_use from .lnrouter import RouteEdge, LNPaymentRoute, is_route_sane_to_use
@ -470,6 +470,7 @@ class LNWallet(LNWorker):
self.reestablish_peers_and_channels(), self.reestablish_peers_and_channels(),
self.sync_with_local_watchtower(), self.sync_with_local_watchtower(),
self.sync_with_remote_watchtower(), self.sync_with_remote_watchtower(),
self.htlc_switch(),
]: ]:
tg_coro = self.taskgroup.spawn(coro) tg_coro = self.taskgroup.spawn(coro)
asyncio.run_coroutine_threadsafe(tg_coro, self.network.asyncio_loop) asyncio.run_coroutine_threadsafe(tg_coro, self.network.asyncio_loop)
@ -1325,3 +1326,42 @@ class LNWallet(LNWorker):
if feerate_per_kvbyte is None: if feerate_per_kvbyte is None:
feerate_per_kvbyte = FEERATE_FALLBACK_STATIC_FEE feerate_per_kvbyte = FEERATE_FALLBACK_STATIC_FEE
return max(253, feerate_per_kvbyte // 4) return max(253, feerate_per_kvbyte // 4)
@log_exceptions
async def htlc_switch(self):
while True:
await asyncio.sleep(0.1)
for chan_id, chan in self.channels.items():
if chan.peer_state != peer_states.GOOD:
continue
peer = self.peers[chan.node_id]
done = set()
unfulfilled = chan.hm.log['unfulfilled_htlcs']
for htlc_id, (local_ctn, remote_ctn, onion_packet_hex) in unfulfilled.items():
# todo: decouple this from processing.
await peer.await_local(chan, local_ctn)
await peer.await_remote(chan, remote_ctn)
#
chan.logger.info(f'found unfulfilled htlc: {htlc_id}')
onion_packet = OnionPacket.from_bytes(bytes.fromhex(onion_packet_hex))
htlc = chan.hm.log[REMOTE]['adds'][htlc_id]
payment_hash = htlc.payment_hash
processed_onion = process_onion_packet(onion_packet, associated_data=payment_hash, our_onion_private_key=peer.privkey)
if processed_onion.are_we_final:
await peer._maybe_fulfill_htlc(
chan=chan,
htlc=htlc,
onion_packet=onion_packet,
processed_onion=processed_onion)
else:
# todo: if we are forwarding we need to test next peer's state
# we should dissociate forwarding and fulfillment
await peer._maybe_forward_htlc(
chan=chan,
htlc=htlc,
onion_packet=onion_packet,
processed_onion=processed_onion)
done.add(htlc_id)
# cleanup
for htlc_id in done:
unfulfilled.pop(htlc_id)

12
electrum/tests/test_lnpeer.py

@ -238,7 +238,7 @@ class TestPeer(ElectrumTestCase):
self.assertEqual(alice_channel.peer_state, peer_states.GOOD) self.assertEqual(alice_channel.peer_state, peer_states.GOOD)
self.assertEqual(bob_channel.peer_state, peer_states.GOOD) self.assertEqual(bob_channel.peer_state, peer_states.GOOD)
gath.cancel() gath.cancel()
gath = asyncio.gather(reestablish(), p1._message_loop(), p2._message_loop()) gath = asyncio.gather(reestablish(), p1._message_loop(), p2._message_loop(), LNWallet.htlc_switch(w1), LNWallet.htlc_switch(w2))
async def f(): async def f():
await gath await gath
with self.assertRaises(concurrent.futures.CancelledError): with self.assertRaises(concurrent.futures.CancelledError):
@ -253,7 +253,7 @@ class TestPeer(ElectrumTestCase):
result = await LNWallet._pay(w1, pay_req) result = await LNWallet._pay(w1, pay_req)
self.assertEqual(result, True) self.assertEqual(result, True)
gath.cancel() gath.cancel()
gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop()) gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop(), LNWallet.htlc_switch(w1), LNWallet.htlc_switch(w2))
async def f(): async def f():
await gath await gath
with self.assertRaises(concurrent.futures.CancelledError): with self.assertRaises(concurrent.futures.CancelledError):
@ -271,7 +271,7 @@ class TestPeer(ElectrumTestCase):
# wait so that pending messages are processed # wait so that pending messages are processed
#await asyncio.sleep(1) #await asyncio.sleep(1)
gath.cancel() gath.cancel()
gath = asyncio.gather(reestablish(), p1._message_loop(), p2._message_loop()) gath = asyncio.gather(reestablish(), p1._message_loop(), p2._message_loop(), LNWallet.htlc_switch(w1), LNWallet.htlc_switch(w2))
async def f(): async def f():
await gath await gath
with self.assertRaises(concurrent.futures.CancelledError): with self.assertRaises(concurrent.futures.CancelledError):
@ -285,7 +285,7 @@ class TestPeer(ElectrumTestCase):
result = await LNWallet._pay(w1, pay_req) result = await LNWallet._pay(w1, pay_req)
self.assertTrue(result) self.assertTrue(result)
gath.cancel() gath.cancel()
gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop()) gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop(), LNWallet.htlc_switch(w1), LNWallet.htlc_switch(w2))
async def f(): async def f():
await gath await gath
with self.assertRaises(concurrent.futures.CancelledError): with self.assertRaises(concurrent.futures.CancelledError):
@ -313,7 +313,7 @@ class TestPeer(ElectrumTestCase):
async def set_settle(): async def set_settle():
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
w2.enable_htlc_settle.set() w2.enable_htlc_settle.set()
gath = asyncio.gather(pay(), set_settle(), p1._message_loop(), p2._message_loop()) gath = asyncio.gather(pay(), set_settle(), p1._message_loop(), p2._message_loop(), LNWallet.htlc_switch(w1), LNWallet.htlc_switch(w2))
async def f(): async def f():
await gath await gath
with self.assertRaises(concurrent.futures.CancelledError): with self.assertRaises(concurrent.futures.CancelledError):
@ -338,7 +338,7 @@ class TestPeer(ElectrumTestCase):
# AssertionError is ok since we shouldn't use old routes, and the # AssertionError is ok since we shouldn't use old routes, and the
# route finding should fail when channel is closed # route finding should fail when channel is closed
async def f(): async def f():
await asyncio.gather(w1._pay_to_route(route, addr), p1._message_loop(), p2._message_loop()) await asyncio.gather(w1._pay_to_route(route, addr), p1._message_loop(), p2._message_loop(), LNWallet.htlc_switch(w1), LNWallet.htlc_switch(w2))
with self.assertRaises(PaymentFailure): with self.assertRaises(PaymentFailure):
run(f()) run(f())

Loading…
Cancel
Save