Browse Source

Merge pull request #8493 from spesmilo/jsonpatch_new

partial-writes using jsonpatch
master
ThomasV 2 years ago committed by GitHub
parent
commit
1af6972d03
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 1
      contrib/requirements/requirements.txt
  2. 2
      electrum/gui/qt/console.py
  3. 3
      electrum/gui/qt/main_window.py
  4. 147
      electrum/json_db.py
  5. 3
      electrum/lnpeer.py
  6. 22
      electrum/storage.py
  7. 4
      electrum/wallet.py

1
contrib/requirements/requirements.txt

@ -7,6 +7,7 @@ aiohttp_socks>=0.3
certifi certifi
bitstring bitstring
attrs>=20.1.0 attrs>=20.1.0
jsonpatch
# Note that we also need the dnspython[DNSSEC] extra which pulls in cryptography, # Note that we also need the dnspython[DNSSEC] extra which pulls in cryptography,
# but as that is not pure-python it cannot be listed in this file! # but as that is not pure-python it cannot be listed in this file!

2
electrum/gui/qt/console.py

@ -187,6 +187,8 @@ class Console(QtWidgets.QPlainTextEdit):
return return
if command and (not self.history or self.history[-1] != command): if command and (not self.history or self.history[-1] != command):
while len(self.history) >= 50:
self.history.remove(self.history[0])
self.history.append(command) self.history.append(command)
self.history_index = len(self.history) self.history_index = len(self.history)

3
electrum/gui/qt/main_window.py

@ -1550,7 +1550,7 @@ class ElectrumWindow(QMainWindow, MessageBoxMixin, Logger, QtEventListener):
def update_console(self): def update_console(self):
console = self.console console = self.console
console.history = self.wallet.db.get("qt-console-history", []) console.history = self.wallet.db.get_stored_item("qt-console-history", [])
console.history_index = len(console.history) console.history_index = len(console.history)
console.updateNamespace({ console.updateNamespace({
@ -2620,7 +2620,6 @@ class ElectrumWindow(QMainWindow, MessageBoxMixin, Logger, QtEventListener):
g = self.geometry() g = self.geometry()
self.wallet.db.put("winpos-qt", [g.left(),g.top(), self.wallet.db.put("winpos-qt", [g.left(),g.top(),
g.width(),g.height()]) g.width(),g.height()])
self.wallet.db.put("qt-console-history", self.console.history[-50:])
if self.qr_window: if self.qr_window:
self.qr_window.close() self.qr_window.close()
self.close_wallet() self.close_wallet()

147
electrum/json_db.py

@ -26,6 +26,7 @@ import threading
import copy import copy
import json import json
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import jsonpatch
from . import util from . import util
from .util import WalletFileException, profiler from .util import WalletFileException, profiler
@ -80,22 +81,35 @@ def stored_in(name, _type=dict):
return decorator return decorator
def key_path(path, key):
def to_str(x):
if isinstance(x, int):
return str(int(x))
else:
assert isinstance(x, str)
return x
return '/' + '/'.join([to_str(x) for x in path + [to_str(key)]])
class StoredObject: class StoredObject:
db = None db = None
path = None
def __setattr__(self, key, value): def __setattr__(self, key, value):
if self.db: if self.db and key not in ['path', 'db'] and not key.startswith('_'):
self.db.set_modified(True) if value != getattr(self, key):
self.db.add_patch({'op': 'replace', 'path': key_path(self.path, key), 'value': value})
object.__setattr__(self, key, value) object.__setattr__(self, key, value)
def set_db(self, db): def set_db(self, db, path):
self.db = db self.db = db
self.path = path
def to_json(self): def to_json(self):
d = dict(vars(self)) d = dict(vars(self))
d.pop('db', None) d.pop('db', None)
d.pop('path', None)
# don't expose/store private stuff # don't expose/store private stuff
d = {k: v for k, v in d.items() d = {k: v for k, v in d.items()
if not k.startswith('_')} if not k.startswith('_')}
@ -112,20 +126,22 @@ class StoredDict(dict):
self.path = path self.path = path
# recursively convert dicts to StoredDict # recursively convert dicts to StoredDict
for k, v in list(data.items()): for k, v in list(data.items()):
self.__setitem__(k, v) self.__setitem__(k, v, patch=False)
@locked @locked
def __setitem__(self, key, v): def __setitem__(self, key, v, patch=True):
is_new = key not in self is_new = key not in self
# early return to prevent unnecessary disk writes # early return to prevent unnecessary disk writes
if not is_new and self[key] == v: if not is_new and patch:
return if self.db and json.dumps(v, cls=self.db.encoder) == json.dumps(self[key], cls=self.db.encoder):
return
# recursively set db and path # recursively set db and path
if isinstance(v, StoredDict): if isinstance(v, StoredDict):
#assert v.db is None
v.db = self.db v.db = self.db
v.path = self.path + [key] v.path = self.path + [key]
for k, vv in v.items(): for k, vv in v.items():
v[k] = vv v.__setitem__(k, vv, patch=False)
# recursively convert dict to StoredDict. # recursively convert dict to StoredDict.
# _convert_dict is called breadth-first # _convert_dict is called breadth-first
elif isinstance(v, dict): elif isinstance(v, dict):
@ -139,29 +155,57 @@ class StoredDict(dict):
v = self.db._convert_value(self.path, key, v) v = self.db._convert_value(self.path, key, v)
# set parent of StoredObject # set parent of StoredObject
if isinstance(v, StoredObject): if isinstance(v, StoredObject):
v.set_db(self.db) v.set_db(self.db, self.path + [key])
# convert lists
if isinstance(v, list):
v = StoredList(v, self.db, self.path + [key])
# set item # set item
dict.__setitem__(self, key, v) dict.__setitem__(self, key, v)
if self.db: if self.db and patch:
self.db.set_modified(True) op = 'add' if is_new else 'replace'
self.db.add_patch({'op': op, 'path': key_path(self.path, key), 'value': v})
@locked @locked
def __delitem__(self, key): def __delitem__(self, key):
dict.__delitem__(self, key) dict.__delitem__(self, key)
if self.db: if self.db:
self.db.set_modified(True) self.db.add_patch({'op': 'remove', 'path': key_path(self.path, key)})
@locked @locked
def pop(self, key, v=_RaiseKeyError): def pop(self, key, v=_RaiseKeyError):
if v is _RaiseKeyError: if key not in self:
r = dict.pop(self, key) if v is _RaiseKeyError:
else: raise KeyError(key)
r = dict.pop(self, key, v) else:
return v
r = dict.pop(self, key)
if self.db: if self.db:
self.db.set_modified(True) self.db.add_patch({'op': 'remove', 'path': key_path(self.path, key)})
return r return r
class StoredList(list):
def __init__(self, data, db, path):
list.__init__(self, data)
self.db = db
self.lock = self.db.lock if self.db else threading.RLock()
self.path = path
@locked
def append(self, item):
n = len(self)
list.append(self, item)
if self.db:
self.db.add_patch({'op': 'add', 'path': key_path(self.path, '%d'%n), 'value':item})
@locked
def remove(self, item):
n = self.index(item)
list.remove(self, item)
if self.db:
self.db.add_patch({'op': 'remove', 'path': key_path(self.path, '%d'%n)})
class JsonDB(Logger): class JsonDB(Logger):
@ -171,34 +215,41 @@ class JsonDB(Logger):
self.lock = threading.RLock() self.lock = threading.RLock()
self.storage = storage self.storage = storage
self.encoder = encoder self.encoder = encoder
self.pending_changes = []
self._modified = False self._modified = False
# load data # load data
data = self.load_data(s) data = self.load_data(s)
if upgrader: if upgrader:
data, was_upgraded = upgrader(data) data, was_upgraded = upgrader(data)
else: self._modified |= was_upgraded
was_upgraded = False
# convert to StoredDict # convert to StoredDict
self.data = StoredDict(data, self, []) self.data = StoredDict(data, self, [])
# note: self._modified may have been affected by StoredDict
self._modified = was_upgraded
# write file in case there was a db upgrade # write file in case there was a db upgrade
if self.storage and self.storage.file_exists(): if self.storage and self.storage.file_exists():
self.write() self._write()
def load_data(self, s:str) -> dict: def load_data(self, s:str) -> dict:
""" overloaded in wallet_db """ """ overloaded in wallet_db """
if s == '': if s == '':
return {} return {}
try: try:
data = json.loads(s) data = json.loads('[' + s + ']')
data, patches = data[0], data[1:]
except Exception: except Exception:
if r := self.maybe_load_ast_data(s): if r := self.maybe_load_ast_data(s):
data = r data, patches = r, []
elif r := self.maybe_load_incomplete_data(s):
data, patches = r, []
else: else:
raise WalletFileException("Cannot read wallet file. (parsing failed)") raise WalletFileException("Cannot read wallet file. (parsing failed)")
if not isinstance(data, dict): if not isinstance(data, dict):
raise WalletFileException("Malformed wallet file (not dict)") raise WalletFileException("Malformed wallet file (not dict)")
if patches:
# apply patches
self.logger.info('found %d patches'%len(patches))
patch = jsonpatch.JsonPatch(patches)
data = patch.apply(data)
self.set_modified(True)
return data return data
def maybe_load_ast_data(self, s): def maybe_load_ast_data(self, s):
@ -220,6 +271,21 @@ class JsonDB(Logger):
data[key] = value data[key] = value
return data return data
def maybe_load_incomplete_data(self, s):
n = s.count('{') - s.count('}')
i = len(s)
while n > 0 and i > 0:
i = i - 1
if s[i] == '{':
n = n - 1
if s[i] == '}':
n = n + 1
if n == 0:
s = s[0:i]
assert s[-2:] == ',\n'
self.logger.info('found incomplete data {s[i:]}')
return self.load_data(s[0:-2])
def set_modified(self, b): def set_modified(self, b):
with self.lock: with self.lock:
self._modified = b self._modified = b
@ -227,6 +293,11 @@ class JsonDB(Logger):
def modified(self): def modified(self):
return self._modified return self._modified
@locked
def add_patch(self, patch):
self.pending_changes.append(json.dumps(patch, cls=self.encoder))
self.set_modified(True)
@locked @locked
def get(self, key, default=None): def get(self, key, default=None):
v = self.data.get(key) v = self.data.get(key)
@ -259,6 +330,12 @@ class JsonDB(Logger):
self.data[name] = {} self.data[name] = {}
return self.data[name] return self.data[name]
@locked
def get_stored_item(self, key, default) -> dict:
if key not in self.data:
self.data[key] = default
return self.data[key]
@locked @locked
def dump(self, *, human_readable: bool = True) -> str: def dump(self, *, human_readable: bool = True) -> str:
"""Serializes the DB as a string. """Serializes the DB as a string.
@ -302,10 +379,29 @@ class JsonDB(Logger):
v = constructor(v) v = constructor(v)
return v return v
@locked
def write(self): def write(self):
with self.lock: if not self.storage.file_exists()\
or self.storage.is_encrypted()\
or self.storage.needs_consolidation():
self._write() self._write()
else:
self._append_pending_changes()
@locked
def _append_pending_changes(self):
if threading.current_thread().daemon:
self.logger.warning('daemon thread cannot write db')
return
if not self.pending_changes:
self.logger.info('no pending changes')
return
self.logger.info(f'appending {len(self.pending_changes)} pending changes')
s = ''.join([',\n' + x for x in self.pending_changes])
self.storage.append(s)
self.pending_changes = []
@locked
@profiler @profiler
def _write(self): def _write(self):
if threading.current_thread().daemon: if threading.current_thread().daemon:
@ -315,4 +411,5 @@ class JsonDB(Logger):
return return
json_str = self.dump(human_readable=not self.storage.is_encrypted()) json_str = self.dump(human_readable=not self.storage.is_encrypted())
self.storage.write(json_str) self.storage.write(json_str)
self.pending_changes = []
self.set_modified(False) self.set_modified(False)

3
electrum/lnpeer.py

@ -895,7 +895,8 @@ class Peer(Logger):
"revocation_store": {}, "revocation_store": {},
"channel_type": channel_type, "channel_type": channel_type,
} }
return StoredDict(chan_dict, self.lnworker.db if self.lnworker else None, []) # set db to None, because we do not want to write updates until channel is saved
return StoredDict(chan_dict, None, [])
async def on_open_channel(self, payload): async def on_open_channel(self, payload):
"""Implements the channel acceptance flow. """Implements the channel acceptance flow.

22
electrum/storage.py

@ -71,10 +71,14 @@ class WalletStorage(Logger):
if self.file_exists(): if self.file_exists():
with open(self.path, "r", encoding='utf-8') as f: with open(self.path, "r", encoding='utf-8') as f:
self.raw = f.read() self.raw = f.read()
self.pos = f.seek(0, os.SEEK_END)
self.init_pos = self.pos
self._encryption_version = self._init_encryption_version() self._encryption_version = self._init_encryption_version()
else: else:
self.raw = '' self.raw = ''
self._encryption_version = StorageEncryptionVersion.PLAINTEXT self._encryption_version = StorageEncryptionVersion.PLAINTEXT
self.pos = 0
self.init_pos = 0
def read(self): def read(self):
return self.decrypted if self.is_encrypted() else self.raw return self.decrypted if self.is_encrypted() else self.raw
@ -83,15 +87,13 @@ class WalletStorage(Logger):
s = self.encrypt_before_writing(data) s = self.encrypt_before_writing(data)
temp_path = "%s.tmp.%s" % (self.path, os.getpid()) temp_path = "%s.tmp.%s" % (self.path, os.getpid())
with open(temp_path, "w", encoding='utf-8') as f: with open(temp_path, "w", encoding='utf-8') as f:
f.write(s) self.pos = f.write(s)
f.flush() f.flush()
os.fsync(f.fileno()) os.fsync(f.fileno())
try: try:
mode = os.stat(self.path).st_mode mode = os.stat(self.path).st_mode
except FileNotFoundError: except FileNotFoundError:
mode = stat.S_IREAD | stat.S_IWRITE mode = stat.S_IREAD | stat.S_IWRITE
# assert that wallet file does not exist, to prevent wallet corruption (see issue #5082) # assert that wallet file does not exist, to prevent wallet corruption (see issue #5082)
if not self.file_exists(): if not self.file_exists():
assert not os.path.exists(self.path) assert not os.path.exists(self.path)
@ -100,6 +102,19 @@ class WalletStorage(Logger):
self._file_exists = True self._file_exists = True
self.logger.info(f"saved {self.path}") self.logger.info(f"saved {self.path}")
def append(self, data: str) -> None:
""" append data to file. for the moment, only non-encrypted file"""
assert not self.is_encrypted()
with open(self.path, "r+", encoding='utf-8') as f:
pos = f.seek(0, os.SEEK_END)
assert pos == self.pos, (self.pos, pos)
self.pos += f.write(data)
f.flush()
os.fsync(f.fileno())
def needs_consolidation(self):
return self.pos > 2 * self.init_pos
def file_exists(self) -> bool: def file_exists(self) -> bool:
return self._file_exists return self._file_exists
@ -179,6 +194,7 @@ class WalletStorage(Logger):
def encrypt_before_writing(self, plaintext: str) -> str: def encrypt_before_writing(self, plaintext: str) -> str:
s = plaintext s = plaintext
if self.pubkey: if self.pubkey:
self.decrypted = plaintext
s = bytes(s, 'utf8') s = bytes(s, 'utf8')
c = zlib.compress(s, level=zlib.Z_BEST_SPEED) c = zlib.compress(s, level=zlib.Z_BEST_SPEED)
enc_magic = self._get_encryption_magic() enc_magic = self._get_encryption_magic()

4
electrum/wallet.py

@ -2828,7 +2828,9 @@ class Abstract_Wallet(ABC, Logger, EventListener):
self._update_password_for_keystore(old_pw, new_pw) self._update_password_for_keystore(old_pw, new_pw)
encrypt_keystore = self.can_have_keystore_encryption() encrypt_keystore = self.can_have_keystore_encryption()
self.db.set_keystore_encryption(bool(new_pw) and encrypt_keystore) self.db.set_keystore_encryption(bool(new_pw) and encrypt_keystore)
self.save_db() ## save changes
if self.storage and self.storage.file_exists():
self.db._write()
self.unlock(None) self.unlock(None)
@abstractmethod @abstractmethod

Loading…
Cancel
Save