Browse Source

json_db: do not overload load_data.

instead, pass an upgrader function to the constructor.
master
ThomasV 2 years ago
parent
commit
019be008d6
  1. 40
      electrum/json_db.py
  2. 74
      electrum/wallet_db.py

40
electrum/json_db.py

@ -166,31 +166,59 @@ class StoredDict(dict):
class JsonDB(Logger):
def __init__(self, s: str, storage=None, encoder=None):
def __init__(self, s: str, storage=None, encoder=None, upgrader=None):
Logger.__init__(self)
self.lock = threading.RLock()
self.storage = storage
self.encoder = encoder
self._modified = False
# load data
data, was_upgraded = self.load_data(s)
data = self.load_data(s)
if upgrader:
data, was_upgraded = upgrader(data)
else:
was_upgraded = False
# convert to StoredDict
self.data = StoredDict(data, self, [])
# note: self._modified may have been affected by StoredDict
self._modified = was_upgraded
# write file in case there was a db upgrade
if was_upgraded and self.storage and self.storage.file_exists():
if self.storage and self.storage.file_exists():
self.write()
def load_data(self, s:str) -> dict:
""" overloaded in wallet_db """
if s == '':
return {}, False
return {}
try:
data = json.loads(s)
except Exception:
raise WalletFileException("Cannot read wallet file. (parsing failed)")
if r := self.maybe_load_ast_data(s):
data = r
else:
raise WalletFileException("Cannot read wallet file. (parsing failed)")
if not isinstance(data, dict):
raise WalletFileException("Malformed wallet file (not dict)")
return data, False
return data
def maybe_load_ast_data(self, s):
""" for old wallets """
try:
import ast
d = ast.literal_eval(s)
labels = d.get('labels', {})
except Exception as e:
return
data = {}
for key, value in d.items():
try:
json.dumps(key)
json.dumps(value)
except Exception:
self.logger.info(f'Failed to convert label to json format: {key}')
continue
data[key] = value
return data
def set_modified(self, b):
with self.lock:

74
electrum/wallet_db.py

@ -32,6 +32,7 @@ from collections import defaultdict
from typing import Dict, Optional, List, Tuple, Set, Iterable, NamedTuple, Sequence, TYPE_CHECKING, Union
import binascii
import time
from functools import partial
import attr
@ -1172,60 +1173,41 @@ class WalletDBUpgrader(Logger):
raise WalletFileException(msg)
def upgrade_wallet_db(data: dict, do_upgrade) -> Tuple[dict, bool]:
was_upgraded = False
if len(data) == 0:
# create new DB
data['seed_version'] = FINAL_SEED_VERSION
# store this for debugging purposes
v = DBMetadata(
creation_timestamp=int(time.time()),
first_electrum_version_used=ELECTRUM_VERSION,
)
assert data.get("db_metadata", None) is None
data["db_metadata"] = v
was_upgraded = True
dbu = WalletDBUpgrader(data)
if dbu.requires_split():
raise WalletRequiresSplit(dbu.get_split_accounts())
if dbu.requires_upgrade() and do_upgrade:
dbu.upgrade()
was_upgraded = True
if dbu.requires_upgrade():
raise WalletRequiresUpgrade()
return dbu.data, was_upgraded
class WalletDB(JsonDB):
def __init__(self, s, *, storage=None, upgrade=False):
self._upgrade = upgrade
JsonDB.__init__(self, s, storage, encoder=MyEncoder)
JsonDB.__init__(self, s, storage, encoder=MyEncoder, upgrader=partial(upgrade_wallet_db, do_upgrade=upgrade))
# create pointers
self.load_transactions()
# load plugins that are conditional on wallet type
self.load_plugins()
def load_data(self, s):
try:
data, _ = JsonDB.load_data(self, s)
except Exception:
try:
d = ast.literal_eval(s)
labels = d.get('labels', {})
except Exception as e:
raise WalletFileException("Cannot read wallet file. (parsing failed)")
data = {}
for key, value in d.items():
try:
json.dumps(key)
json.dumps(value)
except Exception:
self.logger.info(f'Failed to convert label to json format: {key}')
continue
data[key] = value
if not isinstance(data, dict):
raise WalletFileException("Malformed wallet file (not dict)")
if len(data) == 0:
# create new DB
data['seed_version'] = FINAL_SEED_VERSION
# store this for debugging purposes
v = DBMetadata(
creation_timestamp=int(time.time()),
first_electrum_version_used=ELECTRUM_VERSION,
)
assert data.get("db_metadata", None) is None
data["db_metadata"] = v
dbu = WalletDBUpgrader(data)
was_upgraded = False
if dbu.requires_split():
raise WalletRequiresSplit(dbu.get_split_accounts())
if dbu.requires_upgrade() and self._upgrade:
dbu.upgrade()
was_upgraded = True
if dbu.requires_upgrade():
raise WalletRequiresUpgrade()
return dbu.data, was_upgraded
@locked
def get_seed_version(self):
return self.get('seed_version')

Loading…
Cancel
Save