From ef1d76e9db09170db94ac0c10efe7b73866b64e3 Mon Sep 17 00:00:00 2001 From: Adam Gibson Date: Tue, 5 Dec 2023 14:38:34 -0600 Subject: [PATCH] Allow cache purging at wallet initialization Fixes #1614. Prior to this commit, if data in the persisted cache in the wallet file were wrong (should be a very extraordinary case), then the joinmarket code would have to crash with a cache invalid warning. After this commit, in such an extraordinary case, the option exists to invalidate or remove the cache on startup, so that it can be rebuilt from scratch. This is done with a config var wallet_caching_disabled in the POLICY section. --- src/jmclient/configure.py | 5 ++++ src/jmclient/wallet.py | 23 +++++++++-------- src/jmclient/wallet_utils.py | 5 +++- test/jmclient/test_wallet.py | 50 ++++++++++++++++++++++++++++++++++-- 4 files changed, 70 insertions(+), 13 deletions(-) diff --git a/src/jmclient/configure.py b/src/jmclient/configure.py index ac0b29f..14482ff 100644 --- a/src/jmclient/configure.py +++ b/src/jmclient/configure.py @@ -272,6 +272,11 @@ merge_algorithm = default # scripts can use the command line flag `-g` instead. gaplimit = 6 +# Disable the caching of addresses and scripts when +# syncing the wallet. You DO NOT need to set this to 'true', +# unless there is an issue of file corruption or a code bug. +wallet_caching_disabled = false + # The fee estimate is based on a projection of how many sats/kilo-vbyte # are needed to get in one of the next N blocks. N is set here as # the value of 'tx_fees'. This cost estimate is high if you set diff --git a/src/jmclient/wallet.py b/src/jmclient/wallet.py index b417f60..5b59ac2 100644 --- a/src/jmclient/wallet.py +++ b/src/jmclient/wallet.py @@ -364,7 +364,7 @@ class BaseWallet(object): ADDRESS_TYPE_INTERNAL = 1 def __init__(self, storage, gap_limit=6, merge_algorithm_name=None, - mixdepth=None): + mixdepth=None, load_cache=True): # to be defined by inheriting classes assert self.TYPE is not None assert self._ENGINE is not None @@ -386,7 +386,7 @@ class BaseWallet(object): # {address: path}, should always hold mappings for all "known" keys self._addr_map = {} - self._load_storage() + self._load_storage(load_cache=load_cache) assert self._utxos is not None assert self._cache is not None @@ -416,7 +416,7 @@ class BaseWallet(object): def gaplimit(self): return self.gap_limit - def _load_storage(self): + def _load_storage(self, load_cache: bool = True) -> None: """ load data from storage """ @@ -426,7 +426,10 @@ class BaseWallet(object): self.network = self._storage.data[b'network'].decode('ascii') self._utxos = UTXOManager(self._storage, self.merge_algorithm) self._addr_labels = AddressLabelsManager(self._storage) - self._cache = self._storage.data.setdefault(b'cache', {}) + if load_cache: + self._cache = self._storage.data.setdefault(b'cache', {}) + else: + self._cache = {} def get_storage_location(self): """ Return the location of the @@ -1869,8 +1872,8 @@ class ImportWalletMixin(object): # path is (_IMPORTED_ROOT_PATH, mixdepth, key_index) super().__init__(storage, **kwargs) - def _load_storage(self): - super()._load_storage() + def _load_storage(self, load_cache: bool = True) -> None: + super()._load_storage(load_cache=load_cache) self._imported = collections.defaultdict(list) for md, keys in self._storage.data[self._IMPORTED_STORAGE_KEY].items(): md = int(md) @@ -2046,8 +2049,8 @@ class BIP39WalletMixin(object): _BIP39_EXTENSION_KEY = b'seed_extension' MNEMONIC_LANG = 'english' - def _load_storage(self): - super()._load_storage() + def _load_storage(self, load_cache: bool = True) -> None: + super()._load_storage(load_cache=load_cache) self._entropy_extension = self._storage.data.get(self._BIP39_EXTENSION_KEY) @classmethod @@ -2153,8 +2156,8 @@ class BIP32Wallet(BaseWallet): if write: storage.save() - def _load_storage(self): - super()._load_storage() + def _load_storage(self, load_cache: bool = True) -> None: + super()._load_storage(load_cache=load_cache) self._entropy = self._storage.data[self._STORAGE_ENTROPY_KEY] self._index_cache = collections.defaultdict( diff --git a/src/jmclient/wallet_utils.py b/src/jmclient/wallet_utils.py index 786637f..43210e1 100644 --- a/src/jmclient/wallet_utils.py +++ b/src/jmclient/wallet_utils.py @@ -1547,8 +1547,11 @@ def open_wallet(path, ask_for_password=True, password=None, read_only=False, else: storage = Storage(path, password, read_only=read_only) + load_cache = True + if jm_single().config.get("POLICY", "wallet_caching_disabled") == "true": + load_cache = False wallet_cls = get_wallet_cls_from_storage(storage) - wallet = wallet_cls(storage, **kwargs) + wallet = wallet_cls(storage, load_cache=load_cache, **kwargs) wallet_sanity_check(wallet) return wallet diff --git a/test/jmclient/test_wallet.py b/test/jmclient/test_wallet.py index 45b23fa..5dffa71 100644 --- a/test/jmclient/test_wallet.py +++ b/test/jmclient/test_wallet.py @@ -12,7 +12,7 @@ from jmclient import load_test_config, jm_single, BaseWallet, \ SegwitLegacyWallet,BIP32Wallet, BIP49Wallet, LegacyWallet,\ VolatileStorage, get_network, cryptoengine, WalletError,\ SegwitWallet, WalletService, SegwitWalletFidelityBonds,\ - create_wallet, open_test_wallet_maybe, \ + create_wallet, open_test_wallet_maybe, open_wallet, \ FidelityBondMixin, FidelityBondWatchonlyWallet,\ wallet_gettimelockaddress, UnknownAddressForLabel from test_blockchaininterface import sync_test_wallet @@ -23,7 +23,7 @@ pytestmark = pytest.mark.usefixtures("setup_regtest_bitcoind") testdir = os.path.dirname(os.path.realpath(__file__)) test_create_wallet_filename = "testwallet_for_create_wallet_test" - +test_cache_cleared_filename = "testwallet_for_cache_clear_test" log = get_log() @@ -764,6 +764,50 @@ def test_wallet_id(setup_wallet): assert wallet1.get_wallet_id() == wallet2.get_wallet_id() +def test_cache_cleared(setup_wallet): + # test plan: + # 1. create a new wallet and sync from scratch + # 2. read its cache as an object + # 3. close the wallet, reopen it, sync it. + # 4. corrupt its cache and save. + # 5. Re open the wallet with recoversync + # and check that the corrupted data is not present. + if os.path.exists(test_cache_cleared_filename): + os.remove(test_cache_cleared_filename) + wallet = create_wallet(test_cache_cleared_filename, + b"hunter2", 2, SegwitWallet) + # note: we use the WalletService as an encapsulation + # of the wallet here because we want to be able to sync, + # but we do not actually start the service and go into + # the monitoring loop. + wallet_service = WalletService(wallet) + # default fast sync, no coins, so no loop + wallet_service.sync_wallet() + wallet_service.update_blockheight() + # to get the cache to save, we need to + # use an address: + addr = wallet_service.get_new_addr(0,0) + jm_single().bc_interface.grab_coins(addr, 1.0) + wallet_service.transaction_monitor() + path_to_corrupt = list(wallet._cache.keys())[0] + # we'll just corrupt the first address and script: + entry_to_corrupt = wallet._cache[path_to_corrupt][b"84'"][b"1'"][b"0'"][b'0'][b'0'] + entry_to_corrupt[b'A'] = "notanaddress" + entry_to_corrupt[b'S'] = "notascript" + wallet_service.wallet.save() + wallet_service.wallet.close() + jm_single().config.set("POLICY", "wallet_caching_disabled", "true") + wallet2 = open_wallet(test_cache_cleared_filename, + ask_for_password=False, + password=b"hunter2") + jm_single().config.set("POLICY", "wallet_caching_disabled", "false") + wallet_service2 = WalletService(wallet2) + while not wallet_service2.synced: + wallet_service2.sync_wallet(fast=False) + wallet_service.transaction_monitor() + # we ignored the corrupt cache? + assert wallet_service2.get_balance_at_mixdepth(0) == 10 ** 8 + def test_addr_script_conversion(setup_wallet): wallet = get_populated_wallet(num=1) @@ -1016,4 +1060,6 @@ def setup_wallet(request): def teardown(): if os.path.exists(test_create_wallet_filename): os.remove(test_create_wallet_filename) + if os.path.exists(test_cache_cleared_filename): + os.remove(test_cache_cleared_filename) request.addfinalizer(teardown)