From da67fda92aff829c1c99f6505f7802b023d061ba Mon Sep 17 00:00:00 2001 From: ThomasV Date: Sat, 29 Feb 2020 09:12:33 +0100 Subject: [PATCH] HTLC switch: - fulfills/fails/forwards HTLCs - onion_packets are saved when update_add_htlc is received --- electrum/lnchannel.py | 7 +++++- electrum/lnhtlc.py | 1 + electrum/lnpeer.py | 44 +++++++++-------------------------- electrum/lnworker.py | 42 ++++++++++++++++++++++++++++++++- electrum/tests/test_lnpeer.py | 12 +++++----- 5 files changed, 65 insertions(+), 41 deletions(-) diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py index a83dcb028..bc72a4442 100644 --- a/electrum/lnchannel.py +++ b/electrum/lnchannel.py @@ -408,7 +408,7 @@ class Channel(Logger): self.logger.info("add_htlc") return htlc - def receive_htlc(self, htlc: UpdateAddHtlc) -> UpdateAddHtlc: + def receive_htlc(self, htlc: UpdateAddHtlc, onion_packet:bytes = None) -> UpdateAddHtlc: """ ReceiveHTLC adds an HTLC to the state machine's remote update log. This method should be called in response to receiving a new HTLC from the remote @@ -427,6 +427,11 @@ class Channel(Logger): f' HTLC amount: {htlc.amount_msat}') with self.db_lock: self.hm.recv_htlc(htlc) + local_ctn = self.get_latest_ctn(LOCAL) + remote_ctn = self.get_latest_ctn(REMOTE) + if onion_packet: + self.hm.log['unfulfilled_htlcs'][htlc.htlc_id] = local_ctn, remote_ctn, onion_packet.hex() + self.logger.info("receive_htlc") return htlc diff --git a/electrum/lnhtlc.py b/electrum/lnhtlc.py index 01e502f5a..4dc20c508 100644 --- a/electrum/lnhtlc.py +++ b/electrum/lnhtlc.py @@ -25,6 +25,7 @@ class HTLCManager: log[LOCAL] = deepcopy(initial) log[REMOTE] = deepcopy(initial) log['unacked_local_updates2'] = {} + log['unfulfilled_htlcs'] = {} # htlc_id -> onion_packet # maybe bootstrap fee_updates if initial_feerate was provided if initial_feerate is not None: diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index bd6032ad3..f216b919d 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -1131,45 +1131,25 @@ class Peer(Logger): self.logger.info(f"on_update_add_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}") cltv_expiry = int.from_bytes(payload["cltv_expiry"], 'big') amount_msat_htlc = int.from_bytes(payload["amount_msat"], 'big') - onion_packet = OnionPacket.from_bytes(payload["onion_routing_packet"]) - processed_onion = process_onion_packet(onion_packet, associated_data=payment_hash, our_onion_private_key=self.privkey) + onion_packet = payload["onion_routing_packet"] if chan.get_state() != channel_states.OPEN: raise RemoteMisbehaving(f"received update_add_htlc while chan.get_state() != OPEN. state was {chan.get_state()}") if cltv_expiry > bitcoin.NLOCKTIME_BLOCKHEIGHT_MAX: asyncio.ensure_future(self.lnworker.force_close_channel(channel_id)) raise RemoteMisbehaving(f"received update_add_htlc with cltv_expiry > BLOCKHEIGHT_MAX. value was {cltv_expiry}") # add htlc - htlc = UpdateAddHtlc(amount_msat=amount_msat_htlc, - payment_hash=payment_hash, - cltv_expiry=cltv_expiry, - timestamp=int(time.time()), - htlc_id=htlc_id) - htlc = chan.receive_htlc(htlc) - # TODO: fulfilling/failing/forwarding of htlcs should be robust to going offline. - # instead of storing state implicitly in coroutines, we could decouple it from receiving the htlc. - # maybe persist the required details, and have a long-running task that makes these decisions. - local_ctn = chan.get_latest_ctn(LOCAL) - remote_ctn = chan.get_latest_ctn(REMOTE) - if processed_onion.are_we_final: - asyncio.ensure_future(self._maybe_fulfill_htlc(chan=chan, - htlc=htlc, - local_ctn=local_ctn, - remote_ctn=remote_ctn, - onion_packet=onion_packet, - processed_onion=processed_onion)) - else: - asyncio.ensure_future(self._maybe_forward_htlc(chan=chan, - htlc=htlc, - local_ctn=local_ctn, - remote_ctn=remote_ctn, - onion_packet=onion_packet, - processed_onion=processed_onion)) + htlc = UpdateAddHtlc( + amount_msat=amount_msat_htlc, + payment_hash=payment_hash, + cltv_expiry=cltv_expiry, + timestamp=int(time.time()), + htlc_id=htlc_id) + chan.receive_htlc(htlc, onion_packet) + @log_exceptions - async def _maybe_forward_htlc(self, chan: Channel, htlc: UpdateAddHtlc, *, local_ctn: int, remote_ctn: int, + async def _maybe_forward_htlc(self, chan: Channel, htlc: UpdateAddHtlc, *, onion_packet: OnionPacket, processed_onion: ProcessedOnionPacket): - await self.await_local(chan, local_ctn) - await self.await_remote(chan, remote_ctn) # Forward HTLC # FIXME: this is not robust to us going offline before payment is fulfilled # FIXME: there are critical safety checks MISSING here @@ -1250,10 +1230,8 @@ class Peer(Logger): await self.fail_htlc(chan, htlc.htlc_id, onion_packet, reason) @log_exceptions - async def _maybe_fulfill_htlc(self, chan: Channel, htlc: UpdateAddHtlc, *, local_ctn: int, remote_ctn: int, + async def _maybe_fulfill_htlc(self, chan: Channel, htlc: UpdateAddHtlc, *, onion_packet: OnionPacket, processed_onion: ProcessedOnionPacket): - await self.await_local(chan, local_ctn) - await self.await_remote(chan, remote_ctn) try: info = self.lnworker.get_payment_info(htlc.payment_hash) preimage = self.lnworker.get_preimage(htlc.payment_hash) diff --git a/electrum/lnworker.py b/electrum/lnworker.py index ac33aa6b6..53a0ef220 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -55,7 +55,7 @@ from .lnutil import (Outpoint, LNPeerAddr, ShortChannelID, PaymentAttemptLog, PaymentAttemptFailureDetails) from .lnutil import ln_dummy_address, ln_compare_features from .transaction import PartialTxOutput, PartialTransaction, PartialTxInput -from .lnonion import OnionFailureCode +from .lnonion import OnionFailureCode, process_onion_packet, OnionPacket from .lnmsg import decode_msg from .i18n import _ from .lnrouter import RouteEdge, LNPaymentRoute, is_route_sane_to_use @@ -470,6 +470,7 @@ class LNWallet(LNWorker): self.reestablish_peers_and_channels(), self.sync_with_local_watchtower(), self.sync_with_remote_watchtower(), + self.htlc_switch(), ]: tg_coro = self.taskgroup.spawn(coro) asyncio.run_coroutine_threadsafe(tg_coro, self.network.asyncio_loop) @@ -1325,3 +1326,42 @@ class LNWallet(LNWorker): if feerate_per_kvbyte is None: feerate_per_kvbyte = FEERATE_FALLBACK_STATIC_FEE return max(253, feerate_per_kvbyte // 4) + + @log_exceptions + async def htlc_switch(self): + while True: + await asyncio.sleep(0.1) + for chan_id, chan in self.channels.items(): + if chan.peer_state != peer_states.GOOD: + continue + peer = self.peers[chan.node_id] + done = set() + unfulfilled = chan.hm.log['unfulfilled_htlcs'] + for htlc_id, (local_ctn, remote_ctn, onion_packet_hex) in unfulfilled.items(): + # todo: decouple this from processing. + await peer.await_local(chan, local_ctn) + await peer.await_remote(chan, remote_ctn) + # + chan.logger.info(f'found unfulfilled htlc: {htlc_id}') + onion_packet = OnionPacket.from_bytes(bytes.fromhex(onion_packet_hex)) + htlc = chan.hm.log[REMOTE]['adds'][htlc_id] + payment_hash = htlc.payment_hash + processed_onion = process_onion_packet(onion_packet, associated_data=payment_hash, our_onion_private_key=peer.privkey) + if processed_onion.are_we_final: + await peer._maybe_fulfill_htlc( + chan=chan, + htlc=htlc, + onion_packet=onion_packet, + processed_onion=processed_onion) + else: + # todo: if we are forwarding we need to test next peer's state + # we should dissociate forwarding and fulfillment + await peer._maybe_forward_htlc( + chan=chan, + htlc=htlc, + onion_packet=onion_packet, + processed_onion=processed_onion) + done.add(htlc_id) + # cleanup + for htlc_id in done: + unfulfilled.pop(htlc_id) diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index 2e3d65ba9..6acba2c25 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -238,7 +238,7 @@ class TestPeer(ElectrumTestCase): self.assertEqual(alice_channel.peer_state, peer_states.GOOD) self.assertEqual(bob_channel.peer_state, peer_states.GOOD) gath.cancel() - gath = asyncio.gather(reestablish(), p1._message_loop(), p2._message_loop()) + gath = asyncio.gather(reestablish(), p1._message_loop(), p2._message_loop(), LNWallet.htlc_switch(w1), LNWallet.htlc_switch(w2)) async def f(): await gath with self.assertRaises(concurrent.futures.CancelledError): @@ -253,7 +253,7 @@ class TestPeer(ElectrumTestCase): result = await LNWallet._pay(w1, pay_req) self.assertEqual(result, True) gath.cancel() - gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop()) + gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop(), LNWallet.htlc_switch(w1), LNWallet.htlc_switch(w2)) async def f(): await gath with self.assertRaises(concurrent.futures.CancelledError): @@ -271,7 +271,7 @@ class TestPeer(ElectrumTestCase): # wait so that pending messages are processed #await asyncio.sleep(1) gath.cancel() - gath = asyncio.gather(reestablish(), p1._message_loop(), p2._message_loop()) + gath = asyncio.gather(reestablish(), p1._message_loop(), p2._message_loop(), LNWallet.htlc_switch(w1), LNWallet.htlc_switch(w2)) async def f(): await gath with self.assertRaises(concurrent.futures.CancelledError): @@ -285,7 +285,7 @@ class TestPeer(ElectrumTestCase): result = await LNWallet._pay(w1, pay_req) self.assertTrue(result) gath.cancel() - gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop()) + gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop(), LNWallet.htlc_switch(w1), LNWallet.htlc_switch(w2)) async def f(): await gath with self.assertRaises(concurrent.futures.CancelledError): @@ -313,7 +313,7 @@ class TestPeer(ElectrumTestCase): async def set_settle(): await asyncio.sleep(0.1) w2.enable_htlc_settle.set() - gath = asyncio.gather(pay(), set_settle(), p1._message_loop(), p2._message_loop()) + gath = asyncio.gather(pay(), set_settle(), p1._message_loop(), p2._message_loop(), LNWallet.htlc_switch(w1), LNWallet.htlc_switch(w2)) async def f(): await gath with self.assertRaises(concurrent.futures.CancelledError): @@ -338,7 +338,7 @@ class TestPeer(ElectrumTestCase): # AssertionError is ok since we shouldn't use old routes, and the # route finding should fail when channel is closed async def f(): - await asyncio.gather(w1._pay_to_route(route, addr), p1._message_loop(), p2._message_loop()) + await asyncio.gather(w1._pay_to_route(route, addr), p1._message_loop(), p2._message_loop(), LNWallet.htlc_switch(w1), LNWallet.htlc_switch(w2)) with self.assertRaises(PaymentFailure): run(f())