diff --git a/electrum/interface.py b/electrum/interface.py index f1d6f4d9f..1b4e6e229 100644 --- a/electrum/interface.py +++ b/electrum/interface.py @@ -166,7 +166,7 @@ class NotificationSession(RPCSession): try: # note: RPCSession.send_request raises TaskTimeout in case of a timeout. # TaskTimeout is a subclass of CancelledError, which is *suppressed* in TaskGroups - response = await asyncio.wait_for( + response = await util.wait_for2( super().send_request(*args, **kwargs), timeout) except (TaskTimeout, asyncio.TimeoutError) as e: diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index 2ba240e97..1fc8ad0df 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -15,6 +15,7 @@ import functools import aiorpcx from aiorpcx import ignore_after +from async_timeout import timeout from .crypto import sha256, sha256d from . import bitcoin, util @@ -331,7 +332,7 @@ class Peer(Logger): async def wait_for_message(self, expected_name: str, channel_id: bytes): q = self.ordered_message_queues[channel_id] - name, payload = await asyncio.wait_for(q.get(), LN_P2P_NETWORK_TIMEOUT) + name, payload = await util.wait_for2(q.get(), LN_P2P_NETWORK_TIMEOUT) # raise exceptions for errors, so that the caller sees them if (err_bytes := payload.get("error")) is not None: err_text = error_text_bytes_to_safe_str(err_bytes) @@ -460,12 +461,12 @@ class Peer(Logger): async def query_gossip(self): try: - await asyncio.wait_for(self.initialized, LN_P2P_NETWORK_TIMEOUT) + await util.wait_for2(self.initialized, LN_P2P_NETWORK_TIMEOUT) except Exception as e: raise GracefulDisconnect(f"Failed to initialize: {e!r}") from e if self.lnworker == self.lnworker.network.lngossip: try: - ids, complete = await asyncio.wait_for(self.get_channel_range(), LN_P2P_NETWORK_TIMEOUT) + ids, complete = await util.wait_for2(self.get_channel_range(), LN_P2P_NETWORK_TIMEOUT) except asyncio.TimeoutError as e: raise GracefulDisconnect("query_channel_range timed out") from e self.logger.info('Received {} channel ids. (complete: {})'.format(len(ids), complete)) @@ -575,7 +576,7 @@ class Peer(Logger): async def _message_loop(self): try: - await asyncio.wait_for(self.initialize(), LN_P2P_NETWORK_TIMEOUT) + await util.wait_for2(self.initialize(), LN_P2P_NETWORK_TIMEOUT) except (OSError, asyncio.TimeoutError, HandshakeFailed) as e: raise GracefulDisconnect(f'initialize failed: {repr(e)}') from e async for msg in self.transport.read_messages(): @@ -699,7 +700,7 @@ class Peer(Logger): Channel configurations are initialized in this method. """ # will raise if init fails - await asyncio.wait_for(self.initialized, LN_P2P_NETWORK_TIMEOUT) + await util.wait_for2(self.initialized, LN_P2P_NETWORK_TIMEOUT) # trampoline is not yet in features if self.lnworker.uses_trampoline() and not self.lnworker.is_trampoline_peer(self.pubkey): raise Exception('Not a trampoline node: ' + str(self.their_features)) diff --git a/electrum/lnworker.py b/electrum/lnworker.py index cafb920d8..d2f907b6e 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -1071,7 +1071,7 @@ class LNWallet(LNWorker): funding_sat=funding_sat, push_msat=push_sat * 1000, temp_channel_id=os.urandom(32)) - chan, funding_tx = await asyncio.wait_for(coro, LN_P2P_NETWORK_TIMEOUT) + chan, funding_tx = await util.wait_for2(coro, LN_P2P_NETWORK_TIMEOUT) util.trigger_callback('channels_updated', self.wallet) self.wallet.adb.add_transaction(funding_tx) # save tx as local into the wallet self.wallet.sign_transaction(funding_tx, password) diff --git a/electrum/network.py b/electrum/network.py index 47da82b69..8a6aad319 100644 --- a/electrum/network.py +++ b/electrum/network.py @@ -811,7 +811,7 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): # note: using longer timeouts here as DNS can sometimes be slow! timeout = self.get_network_timeout_seconds(NetworkTimeout.Generic) try: - await asyncio.wait_for(interface.ready, timeout) + await util.wait_for2(interface.ready, timeout) except BaseException as e: self.logger.info(f"couldn't launch iface {server} -- {repr(e)}") await interface.close() @@ -1401,7 +1401,7 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): async def get_response(server: ServerAddr): interface = Interface(network=self, server=server, proxy=self.proxy) try: - await asyncio.wait_for(interface.ready, timeout) + await util.wait_for2(interface.ready, timeout) except BaseException as e: await interface.close() return diff --git a/electrum/plugins/payserver/payserver.py b/electrum/plugins/payserver/payserver.py index b220c3cf5..cfdd6ce42 100644 --- a/electrum/plugins/payserver/payserver.py +++ b/electrum/plugins/payserver/payserver.py @@ -31,6 +31,7 @@ from typing import TYPE_CHECKING, Optional from aiohttp import web from aiorpcx import NetAddress +from electrum import util from electrum.util import log_exceptions, ignore_exceptions from electrum.plugin import BasePlugin, hook from electrum.logging import Logger @@ -173,7 +174,7 @@ class PayServer(Logger, EventListener): return ws while True: try: - await asyncio.wait_for(self.pending[key].wait(), 1) + await util.wait_for2(self.pending[key].wait(), 1) break except asyncio.TimeoutError: # send data on the websocket, to keep it alive diff --git a/electrum/scripts/ln_features.py b/electrum/scripts/ln_features.py index 8b6303fae..7b396534d 100644 --- a/electrum/scripts/ln_features.py +++ b/electrum/scripts/ln_features.py @@ -11,7 +11,7 @@ import time from electrum.logging import get_logger, configure_logging from electrum.simple_config import SimpleConfig -from electrum import constants +from electrum import constants, util from electrum.daemon import Daemon from electrum.wallet import create_new_wallet from electrum.util import create_and_start_event_loop, log_exceptions, bfh @@ -84,7 +84,7 @@ async def worker(work_queue: asyncio.Queue, results_queue: asyncio.Queue, flag): print(f"worker connecting to {connect_str}") try: peer = await wallet.lnworker.add_peer(connect_str) - res = await asyncio.wait_for(peer.initialized, TIMEOUT) + res = await util.wait_for2(peer.initialized, TIMEOUT) if res: if peer.features & flag == work['features'] & flag: await results_queue.put(True) diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index d2eafa6f2..4a6cd785a 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -824,8 +824,8 @@ class TestPeer(ElectrumTestCase): alice_channel, bob_channel = create_test_channels() p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) async def pay(): - await asyncio.wait_for(p1.initialized, 1) - await asyncio.wait_for(p2.initialized, 1) + await util.wait_for2(p1.initialized, 1) + await util.wait_for2(p2.initialized, 1) # prep _maybe_send_commitment1 = p1.maybe_send_commitment _maybe_send_commitment2 = p2.maybe_send_commitment @@ -1374,8 +1374,8 @@ class TestPeer(ElectrumTestCase): w2.enable_htlc_settle = False lnaddr, pay_req = self.prepare_invoice(w2) async def pay(): - await asyncio.wait_for(p1.initialized, 1) - await asyncio.wait_for(p2.initialized, 1) + await util.wait_for2(p1.initialized, 1) + await util.wait_for2(p2.initialized, 1) # alice sends htlc route, amount_msat = (await w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr))[0][0:2] p1.pay(route=route, @@ -1401,8 +1401,8 @@ class TestPeer(ElectrumTestCase): p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) async def action(): - await asyncio.wait_for(p1.initialized, 1) - await asyncio.wait_for(p2.initialized, 1) + await util.wait_for2(p1.initialized, 1) + await util.wait_for2(p2.initialized, 1) await p1.send_warning(alice_channel.channel_id, 'be warned!', close_connection=True) gath = asyncio.gather(action(), p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch()) with self.assertRaises(GracefulDisconnect): @@ -1414,8 +1414,8 @@ class TestPeer(ElectrumTestCase): p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) async def action(): - await asyncio.wait_for(p1.initialized, 1) - await asyncio.wait_for(p2.initialized, 1) + await util.wait_for2(p1.initialized, 1) + await util.wait_for2(p2.initialized, 1) await p1.send_error(alice_channel.channel_id, 'some error happened!', force_close_channel=True) assert alice_channel.is_closed() gath.cancel() @@ -1447,8 +1447,8 @@ class TestPeer(ElectrumTestCase): async def test(): async def close(): - await asyncio.wait_for(p1.initialized, 1) - await asyncio.wait_for(p2.initialized, 1) + await util.wait_for2(p1.initialized, 1) + await util.wait_for2(p2.initialized, 1) # bob closes channel with different shutdown script await p1.close_channel(alice_channel.channel_id) gath.cancel() @@ -1477,8 +1477,8 @@ class TestPeer(ElectrumTestCase): async def test(): async def close(): - await asyncio.wait_for(p1.initialized, 1) - await asyncio.wait_for(p2.initialized, 1) + await util.wait_for2(p1.initialized, 1) + await util.wait_for2(p2.initialized, 1) await p1.close_channel(alice_channel.channel_id) gath.cancel() @@ -1538,8 +1538,8 @@ class TestPeer(ElectrumTestCase): p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) async def send_weird_messages(): - await asyncio.wait_for(p1.initialized, 1) - await asyncio.wait_for(p2.initialized, 1) + await util.wait_for2(p1.initialized, 1) + await util.wait_for2(p2.initialized, 1) # peer1 sends known message with trailing garbage # BOLT-01 says peer2 should ignore trailing garbage raw_msg1 = encode_msg('ping', num_pong_bytes=4, byteslen=4) + bytes(range(55)) @@ -1570,8 +1570,8 @@ class TestPeer(ElectrumTestCase): p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) async def send_weird_messages(): - await asyncio.wait_for(p1.initialized, 1) - await asyncio.wait_for(p2.initialized, 1) + await util.wait_for2(p1.initialized, 1) + await util.wait_for2(p2.initialized, 1) # peer1 sends unknown 'even-type' message # BOLT-01 says peer2 should close the connection raw_msg2 = (43334).to_bytes(length=2, byteorder="big") + bytes(range(55)) @@ -1600,8 +1600,8 @@ class TestPeer(ElectrumTestCase): p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) async def send_weird_messages(): - await asyncio.wait_for(p1.initialized, 1) - await asyncio.wait_for(p2.initialized, 1) + await util.wait_for2(p1.initialized, 1) + await util.wait_for2(p2.initialized, 1) # peer1 sends known message with insufficient length for the contents # BOLT-01 says peer2 should fail the connection raw_msg1 = encode_msg('ping', num_pong_bytes=4, byteslen=4)[:-1] diff --git a/electrum/util.py b/electrum/util.py index 53fe4fc7c..933b71698 100644 --- a/electrum/util.py +++ b/electrum/util.py @@ -24,7 +24,7 @@ import binascii import os, sys, re, json from collections import defaultdict, OrderedDict from typing import (NamedTuple, Union, TYPE_CHECKING, Tuple, Optional, Callable, Any, - Sequence, Dict, Generic, TypeVar, List, Iterable, Set) + Sequence, Dict, Generic, TypeVar, List, Iterable, Set, Awaitable) from datetime import datetime import decimal from decimal import Decimal @@ -1371,6 +1371,36 @@ aiorpcx.curio._set_task_deadline = _aiorpcx_monkeypatched_set_task_deadline aiorpcx.curio._unset_task_deadline = _aiorpcx_monkeypatched_unset_task_deadline +async def wait_for2(fut: Awaitable, timeout: Union[int, float, None]): + """Replacement for asyncio.wait_for, + due to bugs: https://bugs.python.org/issue42130 and https://github.com/python/cpython/issues/86296 , + which are only fixed in python 3.12+. + """ + if sys.version_info[:3] >= (3, 12): + return await asyncio.wait_for(fut, timeout) + else: + async with async_timeout(timeout): + return await asyncio.ensure_future(fut, loop=get_running_loop()) + + +if hasattr(asyncio, 'timeout'): # python 3.11+ + async_timeout = asyncio.timeout +else: + class TimeoutAfterAsynciolike(aiorpcx.curio.TimeoutAfter): + async def __aexit__(self, exc_type, exc_value, traceback): + try: + await super().__aexit__(exc_type, exc_value, traceback) + except (aiorpcx.TaskTimeout, aiorpcx.UncaughtTimeoutError): + raise asyncio.TimeoutError from None + except aiorpcx.TimeoutCancellationError: + raise asyncio.CancelledError from None + + def async_timeout(delay: Union[int, float, None]): + if delay is None: + return nullcontext() + return TimeoutAfterAsynciolike(delay) + + class NetworkJobOnDefaultServer(Logger, ABC): """An abstract base class for a job that runs on the main network interface. Every time the main interface changes, the job is