From 6557a21c45a8915edc74bb729a627b3505989d0d Mon Sep 17 00:00:00 2001 From: SomberNight Date: Wed, 30 Aug 2023 11:49:42 +0000 Subject: [PATCH] channel_db: don't wait for load_data to finish if stopping ChannelDB.load_data() takes ~15 seconds. Previously if the user tried to close the program while load_data is running, we would block until load_data() finished. (e.g. consider starting and immediately stopping Electrum) Now instead we can abort load_data early. --- electrum/channel_db.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/electrum/channel_db.py b/electrum/channel_db.py index 9eb9abd61..f4579b071 100644 --- a/electrum/channel_db.py +++ b/electrum/channel_db.py @@ -33,6 +33,7 @@ import base64 import asyncio import threading from enum import IntEnum +import functools from aiorpcx import NetAddress @@ -273,6 +274,9 @@ def get_mychannel_policy(short_channel_id: bytes, node_id: bytes, return Policy.from_msg(local_update_decoded) +class _LoadDataAborted(Exception): pass + + create_channel_info = """ CREATE TABLE IF NOT EXISTS channel_info ( short_channel_id BLOB(8), @@ -733,15 +737,30 @@ class ChannelDB(SqlDB): return [(str(net_addr.host), net_addr.port, ts) for net_addr, ts in addr_to_ts.items()] + def handle_abort(func): + @functools.wraps(func) + def wrapper(self: 'ChannelDB', *args, **kwargs): + try: + return func(self, *args, **kwargs) + except _LoadDataAborted: + return + return wrapper + @sql @profiler + @handle_abort def load_data(self): if self.data_loaded.is_set(): return # Note: this method takes several seconds... mostly due to lnmsg.decode_msg being slow. + def maybe_abort(): + if self.stopping: + self.logger.info("load_data() was asked to stop. exiting early.") + raise _LoadDataAborted() c = self.conn.cursor() c.execute("""SELECT * FROM address""") for x in c: + maybe_abort() node_id, host, port, timestamp = x try: net_addr = NetAddress(host, port) @@ -757,6 +776,7 @@ class ChannelDB(SqlDB): self._recent_peers = sorted_node_ids[:self.NUM_MAX_RECENT_PEERS] c.execute("""SELECT * FROM channel_info""") for short_channel_id, msg in c: + maybe_abort() try: ci = ChannelInfo.from_raw_msg(msg) except IncompatibleOrInsaneFeatures: @@ -766,6 +786,7 @@ class ChannelDB(SqlDB): self._channels[ShortChannelID.normalize(short_channel_id)] = ci c.execute("""SELECT * FROM node_info""") for node_id, msg in c: + maybe_abort() try: node_info, node_addresses = NodeInfo.from_raw_msg(msg) except IncompatibleOrInsaneFeatures: @@ -776,6 +797,7 @@ class ChannelDB(SqlDB): self._nodes[node_id] = node_info c.execute("""SELECT * FROM policy""") for key, msg in c: + maybe_abort() try: p = Policy.from_raw_msg(key, msg) except FailedToParseMsg: