diff --git a/electrum/lnmsg.py b/electrum/lnmsg.py index 766a05046..06a0eec46 100644 --- a/electrum/lnmsg.py +++ b/electrum/lnmsg.py @@ -9,9 +9,11 @@ from .lnutil import OnionFailureCodeMetaFlag class FailedToParseMsg(Exception): pass -class MalformedMsg(FailedToParseMsg): pass class UnknownMsgType(FailedToParseMsg): pass +class UnknownOptionalMsgType(UnknownMsgType): pass +class UnknownMandatoryMsgType(UnknownMsgType): pass +class MalformedMsg(FailedToParseMsg): pass class UnknownMsgFieldType(MalformedMsg): pass class UnexpectedEndOfStream(MalformedMsg): pass class FieldEncodingNotMinimal(MalformedMsg): pass @@ -479,7 +481,10 @@ class LNSerializer: try: scheme = self.msg_scheme_from_type[msg_type_bytes] except KeyError: - raise UnknownMsgType(f"msg_type={msg_type_int}") # TODO even/odd type? + if msg_type_int % 2 == 0: # even types must be understood: "mandatory" + raise UnknownMandatoryMsgType(f"msg_type={msg_type_int}") + else: # odd types are ok not to understand: "optional" + raise UnknownOptionalMsgType(f"msg_type={msg_type_int}") assert scheme[0][2] == msg_type_int msg_type_name = scheme[0][1] parsed = {} diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index f9a3be142..e7f889a30 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -46,7 +46,7 @@ from .lnutil import (Outpoint, LocalConfig, RECEIVED, UpdateAddHtlc, UpfrontShutdownScriptViolation) from .lnutil import FeeUpdate, channel_id_from_funding_tx from .lntransport import LNTransport, LNTransportBase -from .lnmsg import encode_msg, decode_msg +from .lnmsg import encode_msg, decode_msg, UnknownOptionalMsgType from .interface import GracefulDisconnect from .lnrouter import fee_for_edge_msat from .lnutil import ln_dummy_address @@ -179,7 +179,11 @@ class Peer(Logger): self.ping_time = time.time() def process_message(self, message): - message_type, payload = decode_msg(message) + try: + message_type, payload = decode_msg(message) + except UnknownOptionalMsgType as e: + self.logger.info(f"received unknown message from peer. ignoring: {e!r}") + return # only process INIT if we are a backup if self.is_channel_backup is True and message_type != 'init': return diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index 49f71386b..031e64c06 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -29,6 +29,7 @@ from electrum.lnrouter import LNPathFinder, PathEdge, LNPathInconsistent from electrum.channel_db import ChannelDB from electrum.lnworker import LNWallet, NoPathFound from electrum.lnmsg import encode_msg, decode_msg +from electrum import lnmsg from electrum.logging import console_stderr_handler, Logger from electrum.lnworker import PaymentInfo, RECEIVED from electrum.lnonion import OnionFailureCode @@ -1086,6 +1087,95 @@ class TestPeer(TestCaseForTestnet): with self.assertRaises(PaymentFailure): run(f()) + @needs_test_with_all_chacha20_implementations + def test_sending_weird_messages_that_should_be_ignored(self): + alice_channel, bob_channel = create_test_channels() + p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + + async def send_weird_messages(): + await asyncio.wait_for(p1.initialized, 1) + await asyncio.wait_for(p2.initialized, 1) + # peer1 sends known message with trailing garbage + # BOLT-01 says peer2 should ignore trailing garbage + raw_msg1 = encode_msg('ping', num_pong_bytes=4, byteslen=4) + bytes(range(55)) + p1.transport.send_bytes(raw_msg1) + await asyncio.sleep(0.05) + # peer1 sends unknown 'odd-type' message + # BOLT-01 says peer2 should ignore whole message + raw_msg2 = (43333).to_bytes(length=2, byteorder="big") + bytes(range(55)) + p1.transport.send_bytes(raw_msg2) + await asyncio.sleep(0.05) + raise TestSuccess() + + async def f(): + async with TaskGroup() as group: + for peer in [p1, p2]: + await group.spawn(peer._message_loop()) + await group.spawn(peer.htlc_switch()) + await asyncio.sleep(0.2) + await group.spawn(send_weird_messages()) + + with self.assertRaises(TestSuccess): + run(f()) + + @needs_test_with_all_chacha20_implementations + def test_sending_weird_messages__unknown_even_type(self): + alice_channel, bob_channel = create_test_channels() + p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + + async def send_weird_messages(): + await asyncio.wait_for(p1.initialized, 1) + await asyncio.wait_for(p2.initialized, 1) + # peer1 sends unknown 'even-type' message + # BOLT-01 says peer2 should close the connection + raw_msg2 = (43334).to_bytes(length=2, byteorder="big") + bytes(range(55)) + p1.transport.send_bytes(raw_msg2) + await asyncio.sleep(0.05) + + failing_task = None + async def f(): + nonlocal failing_task + async with TaskGroup() as group: + await group.spawn(p1._message_loop()) + await group.spawn(p1.htlc_switch()) + failing_task = await group.spawn(p2._message_loop()) + await group.spawn(p2.htlc_switch()) + await asyncio.sleep(0.2) + await group.spawn(send_weird_messages()) + + with self.assertRaises(lnmsg.UnknownMandatoryMsgType): + run(f()) + self.assertTrue(isinstance(failing_task.exception(), lnmsg.UnknownMandatoryMsgType)) + + @needs_test_with_all_chacha20_implementations + def test_sending_weird_messages__known_msg_with_insufficient_length(self): + alice_channel, bob_channel = create_test_channels() + p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) + + async def send_weird_messages(): + await asyncio.wait_for(p1.initialized, 1) + await asyncio.wait_for(p2.initialized, 1) + # peer1 sends known message with insufficient length for the contents + # BOLT-01 says peer2 should fail the connection + raw_msg1 = encode_msg('ping', num_pong_bytes=4, byteslen=4)[:-1] + p1.transport.send_bytes(raw_msg1) + await asyncio.sleep(0.05) + + failing_task = None + async def f(): + nonlocal failing_task + async with TaskGroup() as group: + await group.spawn(p1._message_loop()) + await group.spawn(p1.htlc_switch()) + failing_task = await group.spawn(p2._message_loop()) + await group.spawn(p2.htlc_switch()) + await asyncio.sleep(0.2) + await group.spawn(send_weird_messages()) + + with self.assertRaises(lnmsg.UnexpectedEndOfStream): + run(f()) + self.assertTrue(isinstance(failing_task.exception(), lnmsg.UnexpectedEndOfStream)) + def run(coro): return asyncio.run_coroutine_threadsafe(coro, loop=asyncio.get_event_loop()).result()