Browse Source

channel_db: don't wait for load_data to finish if stopping

ChannelDB.load_data() takes ~15 seconds. Previously if the user tried
to close the program while load_data is running, we would block until
load_data() finished. (e.g. consider starting and immediately stopping
Electrum)
Now instead we can abort load_data early.
master
SomberNight 2 years ago
parent
commit
6557a21c45
No known key found for this signature in database
GPG Key ID: B33B5F232C6271E9
  1. 22
      electrum/channel_db.py

22
electrum/channel_db.py

@ -33,6 +33,7 @@ import base64
import asyncio
import threading
from enum import IntEnum
import functools
from aiorpcx import NetAddress
@ -273,6 +274,9 @@ def get_mychannel_policy(short_channel_id: bytes, node_id: bytes,
return Policy.from_msg(local_update_decoded)
class _LoadDataAborted(Exception): pass
create_channel_info = """
CREATE TABLE IF NOT EXISTS channel_info (
short_channel_id BLOB(8),
@ -733,15 +737,30 @@ class ChannelDB(SqlDB):
return [(str(net_addr.host), net_addr.port, ts)
for net_addr, ts in addr_to_ts.items()]
def handle_abort(func):
@functools.wraps(func)
def wrapper(self: 'ChannelDB', *args, **kwargs):
try:
return func(self, *args, **kwargs)
except _LoadDataAborted:
return
return wrapper
@sql
@profiler
@handle_abort
def load_data(self):
if self.data_loaded.is_set():
return
# Note: this method takes several seconds... mostly due to lnmsg.decode_msg being slow.
def maybe_abort():
if self.stopping:
self.logger.info("load_data() was asked to stop. exiting early.")
raise _LoadDataAborted()
c = self.conn.cursor()
c.execute("""SELECT * FROM address""")
for x in c:
maybe_abort()
node_id, host, port, timestamp = x
try:
net_addr = NetAddress(host, port)
@ -757,6 +776,7 @@ class ChannelDB(SqlDB):
self._recent_peers = sorted_node_ids[:self.NUM_MAX_RECENT_PEERS]
c.execute("""SELECT * FROM channel_info""")
for short_channel_id, msg in c:
maybe_abort()
try:
ci = ChannelInfo.from_raw_msg(msg)
except IncompatibleOrInsaneFeatures:
@ -766,6 +786,7 @@ class ChannelDB(SqlDB):
self._channels[ShortChannelID.normalize(short_channel_id)] = ci
c.execute("""SELECT * FROM node_info""")
for node_id, msg in c:
maybe_abort()
try:
node_info, node_addresses = NodeInfo.from_raw_msg(msg)
except IncompatibleOrInsaneFeatures:
@ -776,6 +797,7 @@ class ChannelDB(SqlDB):
self._nodes[node_id] = node_info
c.execute("""SELECT * FROM policy""")
for key, msg in c:
maybe_abort()
try:
p = Policy.from_raw_msg(key, msg)
except FailedToParseMsg:

Loading…
Cancel
Save