diff --git a/electrum/json_db.py b/electrum/json_db.py index 9c3fd6be8..1c989da09 100644 --- a/electrum/json_db.py +++ b/electrum/json_db.py @@ -166,31 +166,59 @@ class StoredDict(dict): class JsonDB(Logger): - def __init__(self, s: str, storage=None, encoder=None): + def __init__(self, s: str, storage=None, encoder=None, upgrader=None): Logger.__init__(self) self.lock = threading.RLock() self.storage = storage self.encoder = encoder self._modified = False # load data - data, was_upgraded = self.load_data(s) + data = self.load_data(s) + if upgrader: + data, was_upgraded = upgrader(data) + else: + was_upgraded = False # convert to StoredDict self.data = StoredDict(data, self, []) + # note: self._modified may have been affected by StoredDict + self._modified = was_upgraded # write file in case there was a db upgrade - if was_upgraded and self.storage and self.storage.file_exists(): + if self.storage and self.storage.file_exists(): self.write() def load_data(self, s:str) -> dict: """ overloaded in wallet_db """ if s == '': - return {}, False + return {} try: data = json.loads(s) except Exception: - raise WalletFileException("Cannot read wallet file. (parsing failed)") + if r := self.maybe_load_ast_data(s): + data = r + else: + raise WalletFileException("Cannot read wallet file. (parsing failed)") if not isinstance(data, dict): raise WalletFileException("Malformed wallet file (not dict)") - return data, False + return data + + def maybe_load_ast_data(self, s): + """ for old wallets """ + try: + import ast + d = ast.literal_eval(s) + labels = d.get('labels', {}) + except Exception as e: + return + data = {} + for key, value in d.items(): + try: + json.dumps(key) + json.dumps(value) + except Exception: + self.logger.info(f'Failed to convert label to json format: {key}') + continue + data[key] = value + return data def set_modified(self, b): with self.lock: diff --git a/electrum/wallet_db.py b/electrum/wallet_db.py index 1ed47434c..f71935732 100644 --- a/electrum/wallet_db.py +++ b/electrum/wallet_db.py @@ -32,6 +32,7 @@ from collections import defaultdict from typing import Dict, Optional, List, Tuple, Set, Iterable, NamedTuple, Sequence, TYPE_CHECKING, Union import binascii import time +from functools import partial import attr @@ -1172,60 +1173,41 @@ class WalletDBUpgrader(Logger): raise WalletFileException(msg) +def upgrade_wallet_db(data: dict, do_upgrade) -> Tuple[dict, bool]: + was_upgraded = False + + if len(data) == 0: + # create new DB + data['seed_version'] = FINAL_SEED_VERSION + # store this for debugging purposes + v = DBMetadata( + creation_timestamp=int(time.time()), + first_electrum_version_used=ELECTRUM_VERSION, + ) + assert data.get("db_metadata", None) is None + data["db_metadata"] = v + was_upgraded = True + + dbu = WalletDBUpgrader(data) + if dbu.requires_split(): + raise WalletRequiresSplit(dbu.get_split_accounts()) + if dbu.requires_upgrade() and do_upgrade: + dbu.upgrade() + was_upgraded = True + if dbu.requires_upgrade(): + raise WalletRequiresUpgrade() + return dbu.data, was_upgraded + + class WalletDB(JsonDB): def __init__(self, s, *, storage=None, upgrade=False): - self._upgrade = upgrade - JsonDB.__init__(self, s, storage, encoder=MyEncoder) + JsonDB.__init__(self, s, storage, encoder=MyEncoder, upgrader=partial(upgrade_wallet_db, do_upgrade=upgrade)) # create pointers self.load_transactions() # load plugins that are conditional on wallet type self.load_plugins() - def load_data(self, s): - try: - data, _ = JsonDB.load_data(self, s) - except Exception: - try: - d = ast.literal_eval(s) - labels = d.get('labels', {}) - except Exception as e: - raise WalletFileException("Cannot read wallet file. (parsing failed)") - data = {} - for key, value in d.items(): - try: - json.dumps(key) - json.dumps(value) - except Exception: - self.logger.info(f'Failed to convert label to json format: {key}') - continue - data[key] = value - if not isinstance(data, dict): - raise WalletFileException("Malformed wallet file (not dict)") - - if len(data) == 0: - # create new DB - data['seed_version'] = FINAL_SEED_VERSION - # store this for debugging purposes - v = DBMetadata( - creation_timestamp=int(time.time()), - first_electrum_version_used=ELECTRUM_VERSION, - ) - assert data.get("db_metadata", None) is None - data["db_metadata"] = v - - dbu = WalletDBUpgrader(data) - was_upgraded = False - if dbu.requires_split(): - raise WalletRequiresSplit(dbu.get_split_accounts()) - if dbu.requires_upgrade() and self._upgrade: - dbu.upgrade() - was_upgraded = True - if dbu.requires_upgrade(): - raise WalletRequiresUpgrade() - return dbu.data, was_upgraded - - @locked def get_seed_version(self): return self.get('seed_version')