
lnpeer: fix timing issue in reest_chan, for replaying unacked updates

We must not process incoming updates for a given channel until we have (more or less) finished reestablishing it.

Consider the case where both parties have some unacked updates they want to replay during reestablish.
If Bob reacts to Alice's replayed updates before he has replayed his own, madness ensues.

I think this should fix the remaining part of https://github.com/spesmilo/electrum/pull/8778
(timing issues when running the unit tests with py3.12)
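
Conceptually, the fix gates a connection's incoming-message processing on a per-channel asyncio.Event that is set only once our side has replayed its own unacked updates. A minimal sketch of the pattern, with simplified, hypothetical names (not the actual Electrum code):

    import asyncio
    from collections import defaultdict

    # chan_id -> Event; set once our side has finished its part of reestablish
    reest_finished = defaultdict(asyncio.Event)

    async def on_channel_reestablish(chan_id: bytes) -> None:
        # ... validate the peer's channel_reestablish message ...
        # Block this connection's message loop until we have replayed our own
        # unacked updates, so we never react to the peer's replayed updates first.
        await reest_finished[chan_id].wait()

    async def reestablish_channel(chan_id: bytes) -> None:
        # ... replay our unacked updates (update_add_htlc, commitment_signed, ...) ...
        reest_finished[chan_id].set()  # unblock the message loop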
SomberNight · 2 years ago · commit 79d88dcb5f (master)
Changed files:
  electrum/lnpeer.py (31 lines)
  electrum/tests/__init__.py (7 lines)
  electrum/tests/test_lnpeer.py (8 lines)

--- a/electrum/lnpeer.py
+++ b/electrum/lnpeer.py

@@ -109,7 +109,8 @@ class Peer(Logger):
         self.funding_created_sent = set()  # for channels in PREOPENING
         self.funding_signed_sent = set()  # for channels in PREOPENING
         self.shutdown_received = {}  # chan_id -> asyncio.Future()
-        self.channel_reestablish_msg = defaultdict(self.asyncio_loop.create_future)
+        self.channel_reestablish_msg = defaultdict(self.asyncio_loop.create_future)  # type: Dict[bytes, asyncio.Future]
+        self._chan_reest_finished = defaultdict(asyncio.Event)  # type: Dict[bytes, asyncio.Event]
         self.orphan_channel_updates = OrderedDict()  # type: OrderedDict[ShortChannelID, dict]
         Logger.__init__(self)
         self.taskgroup = OldTaskGroup()
@@ -1161,7 +1162,7 @@ class Peer(Logger):
             self.logger.info(f"tried to force-close channel {chan.get_id_for_log()} "
                              f"but close option is not allowed. {chan.get_state()=!r}")

-    def on_channel_reestablish(self, chan: Channel, msg):
+    async def on_channel_reestablish(self, chan: Channel, msg):
         # Note: it is critical for this message handler to block processing of further messages,
         # until this msg is processed. If we are behind (lost state), and send chan_reest to the remote,
         # when the remote realizes we are behind, they might send an "error" message - but the spec mandates
@@ -1259,6 +1260,11 @@ class Peer(Logger):
         else:
             # all good
             fut.set_result((we_must_resend_revoke_and_ack, their_next_local_ctn))
+        # Block processing of further incoming messages until we finished our part of chan-reest.
+        # This is needed for the replaying of our local unacked updates to be sane (if the peer
+        # also replays some messages we must not react to them until we finished replaying our own).
+        # (it would be sufficient to only block messages related to this channel, but this is easier)
+        await self._chan_reest_finished[chan.channel_id].wait()

     def _send_channel_reestablish(self, chan: Channel):
         assert self.is_initialized()
@@ -1376,6 +1382,7 @@ class Peer(Logger):
             resend_revoke_and_ack()

         chan.peer_state = PeerState.GOOD
+        self._chan_reest_finished[chan.channel_id].set()
         if chan.is_funded():
             chan_just_became_ready = (their_next_local_ctn == next_local_ctn == 1)
             if chan_just_became_ready or self.features.supports(LnFeatures.OPTION_SCID_ALIAS_OPT):
@@ -1386,7 +1393,7 @@ class Peer(Logger):
         util.trigger_callback('channel', self.lnworker.wallet, chan)
         # if we have sent a previous shutdown, it must be retransmitted (Bolt2)
         if chan.get_state() == ChannelState.SHUTDOWN:
-            await self.send_shutdown(chan)
+            await self.taskgroup.spawn(self.send_shutdown(chan))

     def send_channel_ready(self, chan: Channel):
         assert chan.is_funded()
@@ -1413,6 +1420,8 @@ class Peer(Logger):

     def on_channel_ready(self, chan: Channel, payload):
         self.logger.info(f"on_channel_ready. channel: {chan.channel_id.hex()}")
+        if chan.peer_state != PeerState.GOOD:  # should never happen
+            raise Exception(f"received channel_ready in unexpected {chan.peer_state=!r}")
         # save remote alias for use in invoices
         scid_alias = payload.get('channel_ready_tlvs', {}).get('short_channel_id', {}).get('alias')
         if scid_alias:
@@ -1526,6 +1535,8 @@ class Peer(Logger):
         htlc_id = payload["id"]
         reason = payload["reason"]
         self.logger.info(f"on_update_fail_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}")
+        if chan.peer_state != PeerState.GOOD:  # should never happen
+            raise Exception(f"received update_fail_htlc in unexpected {chan.peer_state=!r}")
         chan.receive_fail_htlc(htlc_id, error_bytes=reason)  # TODO handle exc and maybe fail channel (e.g. bad htlc_id)
         self.maybe_send_commitment(chan)
@@ -1671,6 +1682,8 @@ class Peer(Logger):
         if chan.peer_state == PeerState.BAD:
             return
         self.logger.info(f'on_commitment_signed. chan {chan.short_channel_id}. ctn: {chan.get_next_ctn(LOCAL)}.')
+        if chan.peer_state != PeerState.GOOD:  # should never happen
+            raise Exception(f"received commitment_signed in unexpected {chan.peer_state=!r}")
         # make sure there were changes to the ctx, otherwise the remote peer is misbehaving
         if not chan.has_pending_changes(LOCAL):
             # TODO if feerate changed A->B->A; so there were updates but the value is identical,
@@ -1692,6 +1705,8 @@ class Peer(Logger):
         payment_hash = sha256(preimage)
         htlc_id = payload["id"]
         self.logger.info(f"on_update_fulfill_htlc. chan {chan.short_channel_id}. htlc_id {htlc_id}")
+        if chan.peer_state != PeerState.GOOD:  # should never happen
+            raise Exception(f"received update_fulfill_htlc in unexpected {chan.peer_state=!r}")
         chan.receive_htlc_settle(preimage, htlc_id)  # TODO handle exc and maybe fail channel (e.g. bad htlc_id)
         self.lnworker.save_preimage(payment_hash, preimage)
         self.maybe_send_commitment(chan)
@@ -1701,6 +1716,8 @@ class Peer(Logger):
         failure_code = payload["failure_code"]
         self.logger.info(f"on_update_fail_malformed_htlc. chan {chan.get_id_for_log()}. "
                          f"htlc_id {htlc_id}. failure_code={failure_code}")
+        if chan.peer_state != PeerState.GOOD:  # should never happen
+            raise Exception(f"received update_fail_malformed_htlc in unexpected {chan.peer_state=!r}")
         if failure_code & OnionFailureCodeMetaFlag.BADONION == 0:
             self.schedule_force_closing(chan.channel_id)
             raise RemoteMisbehaving(f"received update_fail_malformed_htlc with unexpected failure code: {failure_code}")
@@ -1723,6 +1740,8 @@ class Peer(Logger):
         self.logger.info(f"on_update_add_htlc. chan {chan.short_channel_id}. htlc={str(htlc)}")
         if chan.get_state() != ChannelState.OPEN:
             raise RemoteMisbehaving(f"received update_add_htlc while chan.get_state() != OPEN. state was {chan.get_state()!r}")
+        if chan.peer_state != PeerState.GOOD:  # should never happen
+            raise Exception(f"received update_add_htlc in unexpected {chan.peer_state=!r}")
         if cltv_abs > bitcoin.NLOCKTIME_BLOCKHEIGHT_MAX:
             self.schedule_force_closing(chan.channel_id)
             raise RemoteMisbehaving(f"received update_add_htlc with {cltv_abs=} > BLOCKHEIGHT_MAX")
@@ -2216,6 +2235,8 @@ class Peer(Logger):
         if chan.peer_state == PeerState.BAD:
             return
         self.logger.info(f'on_revoke_and_ack. chan {chan.short_channel_id}. ctn: {chan.get_oldest_unrevoked_ctn(REMOTE)}')
+        if chan.peer_state != PeerState.GOOD:  # should never happen
+            raise Exception(f"received revoke_and_ack in unexpected {chan.peer_state=!r}")
         rev = RevokeAndAck(payload["per_commitment_secret"], payload["next_per_commitment_point"])
         chan.receive_revocation(rev)
         self.lnworker.save_channel(chan)
@@ -2224,6 +2245,8 @@ class Peer(Logger):
         self._received_revack_event.clear()

     def on_update_fee(self, chan: Channel, payload):
+        if chan.peer_state != PeerState.GOOD:  # should never happen
+            raise Exception(f"received update_fee in unexpected {chan.peer_state=!r}")
         feerate = payload["feerate_per_kw"]
         chan.update_fee(feerate, False)
@@ -2297,6 +2320,8 @@ class Peer(Logger):
         # TODO: A receiving node: if it hasn't received a funding_signed (if it is a
         #       funder) or a funding_created (if it is a fundee):
         #       SHOULD send an error and fail the channel.
+        if chan.peer_state != PeerState.GOOD:  # should never happen
+            raise Exception(f"received shutdown in unexpected {chan.peer_state=!r}")
         their_scriptpubkey = payload['scriptpubkey']
         their_upfront_scriptpubkey = chan.config[REMOTE].upfront_shutdown_script
         # BOLT-02 check if they use the upfront shutdown script they advertized
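
Note the subtle change in the reestablish path above: the shutdown retransmission is now spawned onto the peer's taskgroup instead of being awaited inline. Presumably this is because send_shutdown can only make progress once further incoming messages are processed, so awaiting it inline could stall reestablish. A toy illustration of the spawn-instead-of-await pattern, using asyncio.TaskGroup (Python 3.11+) as a stand-in for Electrum's OldTaskGroup (hypothetical names, not Electrum code):

    import asyncio

    async def main() -> None:
        reply_received = asyncio.Event()

        async def send_shutdown() -> None:
            # stands in for a coroutine that can only finish after the
            # message loop has processed the peer's reply
            await reply_received.wait()

        async def message_loop() -> None:
            await asyncio.sleep(0.01)  # the peer's reply arrives later
            reply_received.set()

        async with asyncio.TaskGroup() as tg:
            # spawn both; if send_shutdown() were awaited inline before the
            # message loop was running, it would wait forever
            tg.create_task(send_shutdown())
            tg.create_task(message_loop())

    asyncio.run(main())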

--- a/electrum/tests/__init__.py
+++ b/electrum/tests/__init__.py

@@ -9,6 +9,7 @@ import electrum
 import electrum.logging
 from electrum import constants
 from electrum import util
+from electrum.logging import Logger


 # Set this locally to make the test suite run faster.
@@ -23,7 +24,7 @@ electrum.logging._configure_stderr_logging()
 electrum.util.AS_LIB_USER_I_WANT_TO_MANAGE_MY_OWN_ASYNCIO_LOOP = True


-class ElectrumTestCase(unittest.IsolatedAsyncioTestCase):
+class ElectrumTestCase(unittest.IsolatedAsyncioTestCase, Logger):
     """Base class for our unit tests."""

     TESTNET = False
@@ -32,6 +33,10 @@ class ElectrumTestCase(unittest.IsolatedAsyncioTestCase):
     # some unit tests are modifying globals... so we run sequentially:
     _test_lock = threading.Lock()

+    def __init__(self, *args, **kwargs):
+        Logger.__init__(self)
+        unittest.IsolatedAsyncioTestCase.__init__(self, *args, **kwargs)
+
     @classmethod
     def setUpClass(cls):
         super().setUpClass()
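
With Logger mixed into ElectrumTestCase, every test case gets a self.logger wired into electrum's logging setup; the test_lnpeer.py changes below rely on this. A hypothetical usage example (test name is illustrative):

    from electrum.tests import ElectrumTestCase

    class MyReconnectTest(ElectrumTestCase):
        async def test_reconnect(self):
            # self.logger is provided by the Logger mixin
            self.logger.info("simulating disconnection. recreating transports.")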

--- a/electrum/tests/test_lnpeer.py
+++ b/electrum/tests/test_lnpeer.py

@@ -660,6 +660,7 @@ class TestPeerDirect(TestPeer):
                 await p1.received_commitsig_event.wait()
                 await group.cancel_remaining()
             # simulating disconnection. recreate transports.
+            self.logger.info("simulating disconnection. recreating transports.")
             p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA, k1=k1, k2=k2)
             for chan in (chan_AB, chan_BA):
                 chan.peer_state = PeerState.DISCONNECTED
@@ -676,7 +677,8 @@ class TestPeerDirect(TestPeer):
                     "replayed 2 unacked messages. ['update_add_htlc', 'commitment_signed']" in msg) for msg in logs.output))
                 self.assertEqual(chan_AB.peer_state, PeerState.GOOD)
                 self.assertEqual(chan_BA.peer_state, PeerState.GOOD)
+                await group.cancel_remaining()
             raise SuccessfulTest()

         with self.assertRaises(SuccessfulTest):
             await f()
@@ -715,6 +717,7 @@ class TestPeerDirect(TestPeer):
                 await p1.received_commitsig_event.wait()
                 await group.cancel_remaining()
             # simulating disconnection. recreate transports.
+            self.logger.info("simulating disconnection. recreating transports.")
             p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(chan_AB, chan_BA, k1=k1, k2=k2)
             for chan in (chan_AB, chan_BA):
                 chan.peer_state = PeerState.DISCONNECTED
@@ -731,7 +734,8 @@ class TestPeerDirect(TestPeer):
                     "replayed 2 unacked messages. ['update_add_htlc', 'commitment_signed']" in msg) for msg in logs.output))
                 self.assertEqual(chan_AB.peer_state, PeerState.GOOD)
                 self.assertEqual(chan_BA.peer_state, PeerState.GOOD)
+                await group.cancel_remaining()
             raise SuccessfulTest()

         with self.assertRaises(SuccessfulTest):
             await f()
