Browse Source

json_db: do not overload load_data.

instead, pass an upgrader function to the constructor.
master
ThomasV 2 years ago
parent
commit
019be008d6
  1. 38
      electrum/json_db.py
  2. 46
      electrum/wallet_db.py

38
electrum/json_db.py

@ -166,31 +166,59 @@ class StoredDict(dict):
class JsonDB(Logger): class JsonDB(Logger):
def __init__(self, s: str, storage=None, encoder=None): def __init__(self, s: str, storage=None, encoder=None, upgrader=None):
Logger.__init__(self) Logger.__init__(self)
self.lock = threading.RLock() self.lock = threading.RLock()
self.storage = storage self.storage = storage
self.encoder = encoder self.encoder = encoder
self._modified = False self._modified = False
# load data # load data
data, was_upgraded = self.load_data(s) data = self.load_data(s)
if upgrader:
data, was_upgraded = upgrader(data)
else:
was_upgraded = False
# convert to StoredDict # convert to StoredDict
self.data = StoredDict(data, self, []) self.data = StoredDict(data, self, [])
# note: self._modified may have been affected by StoredDict
self._modified = was_upgraded
# write file in case there was a db upgrade # write file in case there was a db upgrade
if was_upgraded and self.storage and self.storage.file_exists(): if self.storage and self.storage.file_exists():
self.write() self.write()
def load_data(self, s:str) -> dict: def load_data(self, s:str) -> dict:
""" overloaded in wallet_db """ """ overloaded in wallet_db """
if s == '': if s == '':
return {}, False return {}
try: try:
data = json.loads(s) data = json.loads(s)
except Exception: except Exception:
if r := self.maybe_load_ast_data(s):
data = r
else:
raise WalletFileException("Cannot read wallet file. (parsing failed)") raise WalletFileException("Cannot read wallet file. (parsing failed)")
if not isinstance(data, dict): if not isinstance(data, dict):
raise WalletFileException("Malformed wallet file (not dict)") raise WalletFileException("Malformed wallet file (not dict)")
return data, False return data
def maybe_load_ast_data(self, s):
""" for old wallets """
try:
import ast
d = ast.literal_eval(s)
labels = d.get('labels', {})
except Exception as e:
return
data = {}
for key, value in d.items():
try:
json.dumps(key)
json.dumps(value)
except Exception:
self.logger.info(f'Failed to convert label to json format: {key}')
continue
data[key] = value
return data
def set_modified(self, b): def set_modified(self, b):
with self.lock: with self.lock:

46
electrum/wallet_db.py

@ -32,6 +32,7 @@ from collections import defaultdict
from typing import Dict, Optional, List, Tuple, Set, Iterable, NamedTuple, Sequence, TYPE_CHECKING, Union from typing import Dict, Optional, List, Tuple, Set, Iterable, NamedTuple, Sequence, TYPE_CHECKING, Union
import binascii import binascii
import time import time
from functools import partial
import attr import attr
@ -1172,36 +1173,8 @@ class WalletDBUpgrader(Logger):
raise WalletFileException(msg) raise WalletFileException(msg)
class WalletDB(JsonDB): def upgrade_wallet_db(data: dict, do_upgrade) -> Tuple[dict, bool]:
was_upgraded = False
def __init__(self, s, *, storage=None, upgrade=False):
self._upgrade = upgrade
JsonDB.__init__(self, s, storage, encoder=MyEncoder)
# create pointers
self.load_transactions()
# load plugins that are conditional on wallet type
self.load_plugins()
def load_data(self, s):
try:
data, _ = JsonDB.load_data(self, s)
except Exception:
try:
d = ast.literal_eval(s)
labels = d.get('labels', {})
except Exception as e:
raise WalletFileException("Cannot read wallet file. (parsing failed)")
data = {}
for key, value in d.items():
try:
json.dumps(key)
json.dumps(value)
except Exception:
self.logger.info(f'Failed to convert label to json format: {key}')
continue
data[key] = value
if not isinstance(data, dict):
raise WalletFileException("Malformed wallet file (not dict)")
if len(data) == 0: if len(data) == 0:
# create new DB # create new DB
@ -1213,12 +1186,12 @@ class WalletDB(JsonDB):
) )
assert data.get("db_metadata", None) is None assert data.get("db_metadata", None) is None
data["db_metadata"] = v data["db_metadata"] = v
was_upgraded = True
dbu = WalletDBUpgrader(data) dbu = WalletDBUpgrader(data)
was_upgraded = False
if dbu.requires_split(): if dbu.requires_split():
raise WalletRequiresSplit(dbu.get_split_accounts()) raise WalletRequiresSplit(dbu.get_split_accounts())
if dbu.requires_upgrade() and self._upgrade: if dbu.requires_upgrade() and do_upgrade:
dbu.upgrade() dbu.upgrade()
was_upgraded = True was_upgraded = True
if dbu.requires_upgrade(): if dbu.requires_upgrade():
@ -1226,6 +1199,15 @@ class WalletDB(JsonDB):
return dbu.data, was_upgraded return dbu.data, was_upgraded
class WalletDB(JsonDB):
def __init__(self, s, *, storage=None, upgrade=False):
JsonDB.__init__(self, s, storage, encoder=MyEncoder, upgrader=partial(upgrade_wallet_db, do_upgrade=upgrade))
# create pointers
self.load_transactions()
# load plugins that are conditional on wallet type
self.load_plugins()
@locked @locked
def get_seed_version(self): def get_seed_version(self):
return self.get('seed_version') return self.get('seed_version')

Loading…
Cancel
Save