diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index 3f90bfcbf..36f84a2be 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -109,7 +109,8 @@ class Peer(Logger): self.funding_created_sent = set() # for channels in PREOPENING self.funding_signed_sent = set() # for channels in PREOPENING self.shutdown_received = {} # chan_id -> asyncio.Future() - self.channel_reestablish_msg = defaultdict(self.asyncio_loop.create_future) + self.channel_reestablish_msg = defaultdict(self.asyncio_loop.create_future) # type: Dict[bytes, asyncio.Future] + self._chan_reest_finished = defaultdict(asyncio.Event) # type: Dict[bytes, asyncio.Event] self.orphan_channel_updates = OrderedDict() # type: OrderedDict[ShortChannelID, dict] Logger.__init__(self) self.taskgroup = OldTaskGroup() @@ -1161,7 +1162,7 @@ class Peer(Logger): self.logger.info(f"tried to force-close channel {chan.get_id_for_log()} " f"but close option is not allowed. {chan.get_state()=!r}") - def on_channel_reestablish(self, chan: Channel, msg): + async def on_channel_reestablish(self, chan: Channel, msg): # Note: it is critical for this message handler to block processing of further messages, # until this msg is processed. If we are behind (lost state), and send chan_reest to the remote, # when the remote realizes we are behind, they might send an "error" message - but the spec mandates @@ -1259,6 +1260,11 @@ class Peer(Logger): else: # all good fut.set_result((we_must_resend_revoke_and_ack, their_next_local_ctn)) + # Block processing of further incoming messages until we finished our part of chan-reest. + # This is needed for the replaying of our local unacked updates to be sane (if the peer + # also replays some messages we must not react to them until we finished replaying our own). + # (it would be sufficient to only block messages related to this channel, but this is easier) + await self._chan_reest_finished[chan.channel_id].wait() def _send_channel_reestablish(self, chan: Channel): assert self.is_initialized() @@ -1376,6 +1382,7 @@ class Peer(Logger): resend_revoke_and_ack() chan.peer_state = PeerState.GOOD + self._chan_reest_finished[chan.channel_id].set() if chan.is_funded(): chan_just_became_ready = (their_next_local_ctn == next_local_ctn == 1) if chan_just_became_ready or self.features.supports(LnFeatures.OPTION_SCID_ALIAS_OPT): @@ -1386,7 +1393,7 @@ class Peer(Logger): util.trigger_callback('channel', self.lnworker.wallet, chan) # if we have sent a previous shutdown, it must be retransmitted (Bolt2) if chan.get_state() == ChannelState.SHUTDOWN: - await self.send_shutdown(chan) + await self.taskgroup.spawn(self.send_shutdown(chan)) def send_channel_ready(self, chan: Channel): assert chan.is_funded() @@ -1413,6 +1420,8 @@ class Peer(Logger): def on_channel_ready(self, chan: Channel, payload): self.logger.info(f"on_channel_ready. channel: {chan.channel_id.hex()}") + if chan.peer_state != PeerState.GOOD: # should never happen + raise Exception(f"received channel_ready in unexpected {chan.peer_state=!r}") # save remote alias for use in invoices scid_alias = payload.get('channel_ready_tlvs', {}).get('short_channel_id', {}).get('alias') if scid_alias: @@ -1526,6 +1535,8 @@ class Peer(Logger): htlc_id = payload["id"] reason = payload["reason"] self.logger.info(f"on_update_fail_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}") + if chan.peer_state != PeerState.GOOD: # should never happen + raise Exception(f"received update_fail_htlc in unexpected {chan.peer_state=!r}") chan.receive_fail_htlc(htlc_id, error_bytes=reason) # TODO handle exc and maybe fail channel (e.g. bad htlc_id) self.maybe_send_commitment(chan) @@ -1671,6 +1682,8 @@ class Peer(Logger): if chan.peer_state == PeerState.BAD: return self.logger.info(f'on_commitment_signed. chan {chan.short_channel_id}. ctn: {chan.get_next_ctn(LOCAL)}.') + if chan.peer_state != PeerState.GOOD: # should never happen + raise Exception(f"received commitment_signed in unexpected {chan.peer_state=!r}") # make sure there were changes to the ctx, otherwise the remote peer is misbehaving if not chan.has_pending_changes(LOCAL): # TODO if feerate changed A->B->A; so there were updates but the value is identical, @@ -1692,6 +1705,8 @@ class Peer(Logger): payment_hash = sha256(preimage) htlc_id = payload["id"] self.logger.info(f"on_update_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}") + if chan.peer_state != PeerState.GOOD: # should never happen + raise Exception(f"received update_fulfill_htlc in unexpected {chan.peer_state=!r}") chan.receive_htlc_settle(preimage, htlc_id) # TODO handle exc and maybe fail channel (e.g. bad htlc_id) self.lnworker.save_preimage(payment_hash, preimage) self.maybe_send_commitment(chan) @@ -1701,6 +1716,8 @@ class Peer(Logger): failure_code = payload["failure_code"] self.logger.info(f"on_update_fail_malformed_htlc. chan {chan.get_id_for_log()}. " f"htlc_id {htlc_id}. failure_code={failure_code}") + if chan.peer_state != PeerState.GOOD: # should never happen + raise Exception(f"received update_fail_malformed_htlc in unexpected {chan.peer_state=!r}") if failure_code & OnionFailureCodeMetaFlag.BADONION == 0: self.schedule_force_closing(chan.channel_id) raise RemoteMisbehaving(f"received update_fail_malformed_htlc with unexpected failure code: {failure_code}") @@ -1723,6 +1740,8 @@ class Peer(Logger): self.logger.info(f"on_update_add_htlc. chan {chan.short_channel_id}. htlc={str(htlc)}") if chan.get_state() != ChannelState.OPEN: raise RemoteMisbehaving(f"received update_add_htlc while chan.get_state() != OPEN. state was {chan.get_state()!r}") + if chan.peer_state != PeerState.GOOD: # should never happen + raise Exception(f"received update_add_htlc in unexpected {chan.peer_state=!r}") if cltv_abs > bitcoin.NLOCKTIME_BLOCKHEIGHT_MAX: self.schedule_force_closing(chan.channel_id) raise RemoteMisbehaving(f"received update_add_htlc with {cltv_abs=} > BLOCKHEIGHT_MAX") @@ -2216,6 +2235,8 @@ class Peer(Logger): if chan.peer_state == PeerState.BAD: return self.logger.info(f'on_revoke_and_ack. chan {chan.short_channel_id}. ctn: {chan.get_oldest_unrevoked_ctn(REMOTE)}') + if chan.peer_state != PeerState.GOOD: # should never happen + raise Exception(f"received revoke_and_ack in unexpected {chan.peer_state=!r}") rev = RevokeAndAck(payload["per_commitment_secret"], payload["next_per_commitment_point"]) chan.receive_revocation(rev) self.lnworker.save_channel(chan) @@ -2224,6 +2245,8 @@ class Peer(Logger): self._received_revack_event.clear() def on_update_fee(self, chan: Channel, payload): + if chan.peer_state != PeerState.GOOD: # should never happen + raise Exception(f"received update_fee in unexpected {chan.peer_state=!r}") feerate = payload["feerate_per_kw"] chan.update_fee(feerate, False) @@ -2297,6 +2320,8 @@ class Peer(Logger): # TODO: A receiving node: if it hasn't received a funding_signed (if it is a # funder) or a funding_created (if it is a fundee): # SHOULD send an error and fail the channel. + if chan.peer_state != PeerState.GOOD: # should never happen + raise Exception(f"received shutdown in unexpected {chan.peer_state=!r}") their_scriptpubkey = payload['scriptpubkey'] their_upfront_scriptpubkey = chan.config[REMOTE].upfront_shutdown_script # BOLT-02 check if they use the upfront shutdown script they advertized diff --git a/electrum/tests/__init__.py b/electrum/tests/__init__.py index 64210a976..8737da940 100644 --- a/electrum/tests/__init__.py +++ b/electrum/tests/__init__.py @@ -9,6 +9,7 @@ import electrum import electrum.logging from electrum import constants from electrum import util +from electrum.logging import Logger # Set this locally to make the test suite run faster. @@ -23,7 +24,7 @@ electrum.logging._configure_stderr_logging() electrum.util.AS_LIB_USER_I_WANT_TO_MANAGE_MY_OWN_ASYNCIO_LOOP = True -class ElectrumTestCase(unittest.IsolatedAsyncioTestCase): +class ElectrumTestCase(unittest.IsolatedAsyncioTestCase, Logger): """Base class for our unit tests.""" TESTNET = False @@ -32,6 +33,10 @@ class ElectrumTestCase(unittest.IsolatedAsyncioTestCase): # some unit tests are modifying globals... so we run sequentially: _test_lock = threading.Lock() + def __init__(self, *args, **kwargs): + Logger.__init__(self) + unittest.IsolatedAsyncioTestCase.__init__(self, *args, **kwargs) + @classmethod def setUpClass(cls): super().setUpClass() diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index 8d04474a1..b2b77bb4b 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -660,6 +660,7 @@ class TestPeerDirect(TestPeer): await p1.received_commitsig_event.wait() await group.cancel_remaining() # simulating disconnection. recreate transports. + self.logger.info("simulating disconnection. recreating transports.") p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA, k1=k1, k2=k2) for chan in (chan_AB, chan_BA): chan.peer_state = PeerState.DISCONNECTED @@ -676,7 +677,8 @@ class TestPeerDirect(TestPeer): "replayed 2 unacked messages. ['update_add_htlc', 'commitment_signed']" in msg) for msg in logs.output)) self.assertEqual(chan_AB.peer_state, PeerState.GOOD) self.assertEqual(chan_BA.peer_state, PeerState.GOOD) - raise SuccessfulTest() + await group.cancel_remaining() + raise SuccessfulTest() with self.assertRaises(SuccessfulTest): await f() @@ -715,6 +717,7 @@ class TestPeerDirect(TestPeer): await p1.received_commitsig_event.wait() await group.cancel_remaining() # simulating disconnection. recreate transports. + self.logger.info("simulating disconnection. recreating transports.") p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA, k1=k1, k2=k2) for chan in (chan_AB, chan_BA): chan.peer_state = PeerState.DISCONNECTED @@ -731,7 +734,8 @@ class TestPeerDirect(TestPeer): "replayed 2 unacked messages. ['update_add_htlc', 'commitment_signed']" in msg) for msg in logs.output)) self.assertEqual(chan_AB.peer_state, PeerState.GOOD) self.assertEqual(chan_BA.peer_state, PeerState.GOOD) - raise SuccessfulTest() + await group.cancel_remaining() + raise SuccessfulTest() with self.assertRaises(SuccessfulTest): await f()