Browse Source

Merge JoinMarket-Org/joinmarket-clientserver#1616: Allow cache purging at wallet initialization

ef1d76e9db Allow cache purging at wallet initialization (Adam Gibson)

Pull request description:

  Fixes #1614.
  Prior to this commit, if data in the persisted cache in the wallet file were wrong (should be a very extraordinary case), then the joinmarket code would have to crash with a cache invalid warning. After this commit, in such an extraordinary case, the option exists to invalidate or remove the cache on startup, so that it can be rebuilt from scratch. This is done with a config var wallet_caching_disabled in the POLICY section.

ACKs for top commit:
  kristapsk:
    utACK ef1d76e9db

Tree-SHA512: dea2256c7246417ab17be0c89d6b64f49243f35204a8d316ab520f94451d149b3131c928be020f7152f99e0eb5f07843341bb48fca5aaef4c6517701a122b355
master
Kristaps Kaupe 2 years ago
parent
commit
98466ea655
No known key found for this signature in database
GPG Key ID: 33E472FE870C7E5D
  1. 5
      src/jmclient/configure.py
  2. 23
      src/jmclient/wallet.py
  3. 5
      src/jmclient/wallet_utils.py
  4. 50
      test/jmclient/test_wallet.py

5
src/jmclient/configure.py

@ -272,6 +272,11 @@ merge_algorithm = default
# scripts can use the command line flag `-g` instead. # scripts can use the command line flag `-g` instead.
gaplimit = 6 gaplimit = 6
# Disable the caching of addresses and scripts when
# syncing the wallet. You DO NOT need to set this to 'true',
# unless there is an issue of file corruption or a code bug.
wallet_caching_disabled = false
# The fee estimate is based on a projection of how many sats/kilo-vbyte # The fee estimate is based on a projection of how many sats/kilo-vbyte
# are needed to get in one of the next N blocks. N is set here as # are needed to get in one of the next N blocks. N is set here as
# the value of 'tx_fees'. This cost estimate is high if you set # the value of 'tx_fees'. This cost estimate is high if you set

23
src/jmclient/wallet.py

@ -388,7 +388,7 @@ class BaseWallet(object):
ADDRESS_TYPE_INTERNAL = 1 ADDRESS_TYPE_INTERNAL = 1
def __init__(self, storage, gap_limit=6, merge_algorithm_name=None, def __init__(self, storage, gap_limit=6, merge_algorithm_name=None,
mixdepth=None): mixdepth=None, load_cache=True):
# to be defined by inheriting classes # to be defined by inheriting classes
assert self.TYPE is not None assert self.TYPE is not None
assert self._ENGINE is not None assert self._ENGINE is not None
@ -410,7 +410,7 @@ class BaseWallet(object):
# {address: path}, should always hold mappings for all "known" keys # {address: path}, should always hold mappings for all "known" keys
self._addr_map = {} self._addr_map = {}
self._load_storage() self._load_storage(load_cache=load_cache)
assert self._utxos is not None assert self._utxos is not None
assert self._cache is not None assert self._cache is not None
@ -440,7 +440,7 @@ class BaseWallet(object):
def gaplimit(self): def gaplimit(self):
return self.gap_limit return self.gap_limit
def _load_storage(self): def _load_storage(self, load_cache: bool = True) -> None:
""" """
load data from storage load data from storage
""" """
@ -450,7 +450,10 @@ class BaseWallet(object):
self.network = self._storage.data[b'network'].decode('ascii') self.network = self._storage.data[b'network'].decode('ascii')
self._utxos = UTXOManager(self._storage, self.merge_algorithm) self._utxos = UTXOManager(self._storage, self.merge_algorithm)
self._addr_labels = AddressLabelsManager(self._storage) self._addr_labels = AddressLabelsManager(self._storage)
self._cache = self._storage.data.setdefault(b'cache', {}) if load_cache:
self._cache = self._storage.data.setdefault(b'cache', {})
else:
self._cache = {}
def get_storage_location(self): def get_storage_location(self):
""" Return the location of the """ Return the location of the
@ -1893,8 +1896,8 @@ class ImportWalletMixin(object):
# path is (_IMPORTED_ROOT_PATH, mixdepth, key_index) # path is (_IMPORTED_ROOT_PATH, mixdepth, key_index)
super().__init__(storage, **kwargs) super().__init__(storage, **kwargs)
def _load_storage(self): def _load_storage(self, load_cache: bool = True) -> None:
super()._load_storage() super()._load_storage(load_cache=load_cache)
self._imported = collections.defaultdict(list) self._imported = collections.defaultdict(list)
for md, keys in self._storage.data[self._IMPORTED_STORAGE_KEY].items(): for md, keys in self._storage.data[self._IMPORTED_STORAGE_KEY].items():
md = int(md) md = int(md)
@ -2070,8 +2073,8 @@ class BIP39WalletMixin(object):
_BIP39_EXTENSION_KEY = b'seed_extension' _BIP39_EXTENSION_KEY = b'seed_extension'
MNEMONIC_LANG = 'english' MNEMONIC_LANG = 'english'
def _load_storage(self): def _load_storage(self, load_cache: bool = True) -> None:
super()._load_storage() super()._load_storage(load_cache=load_cache)
self._entropy_extension = self._storage.data.get(self._BIP39_EXTENSION_KEY) self._entropy_extension = self._storage.data.get(self._BIP39_EXTENSION_KEY)
@classmethod @classmethod
@ -2177,8 +2180,8 @@ class BIP32Wallet(BaseWallet):
if write: if write:
storage.save() storage.save()
def _load_storage(self): def _load_storage(self, load_cache: bool = True) -> None:
super()._load_storage() super()._load_storage(load_cache=load_cache)
self._entropy = self._storage.data[self._STORAGE_ENTROPY_KEY] self._entropy = self._storage.data[self._STORAGE_ENTROPY_KEY]
self._index_cache = collections.defaultdict( self._index_cache = collections.defaultdict(

5
src/jmclient/wallet_utils.py

@ -1547,8 +1547,11 @@ def open_wallet(path, ask_for_password=True, password=None, read_only=False,
else: else:
storage = Storage(path, password, read_only=read_only) storage = Storage(path, password, read_only=read_only)
load_cache = True
if jm_single().config.get("POLICY", "wallet_caching_disabled") == "true":
load_cache = False
wallet_cls = get_wallet_cls_from_storage(storage) wallet_cls = get_wallet_cls_from_storage(storage)
wallet = wallet_cls(storage, **kwargs) wallet = wallet_cls(storage, load_cache=load_cache, **kwargs)
wallet_sanity_check(wallet) wallet_sanity_check(wallet)
return wallet return wallet

50
test/jmclient/test_wallet.py

@ -12,7 +12,7 @@ from jmclient import load_test_config, jm_single, BaseWallet, \
SegwitLegacyWallet,BIP32Wallet, BIP49Wallet, LegacyWallet,\ SegwitLegacyWallet,BIP32Wallet, BIP49Wallet, LegacyWallet,\
VolatileStorage, get_network, cryptoengine, WalletError,\ VolatileStorage, get_network, cryptoengine, WalletError,\
SegwitWallet, WalletService, SegwitWalletFidelityBonds,\ SegwitWallet, WalletService, SegwitWalletFidelityBonds,\
create_wallet, open_test_wallet_maybe, \ create_wallet, open_test_wallet_maybe, open_wallet, \
FidelityBondMixin, FidelityBondWatchonlyWallet,\ FidelityBondMixin, FidelityBondWatchonlyWallet,\
wallet_gettimelockaddress, UnknownAddressForLabel wallet_gettimelockaddress, UnknownAddressForLabel
from test_blockchaininterface import sync_test_wallet from test_blockchaininterface import sync_test_wallet
@ -23,7 +23,7 @@ pytestmark = pytest.mark.usefixtures("setup_regtest_bitcoind")
testdir = os.path.dirname(os.path.realpath(__file__)) testdir = os.path.dirname(os.path.realpath(__file__))
test_create_wallet_filename = "testwallet_for_create_wallet_test" test_create_wallet_filename = "testwallet_for_create_wallet_test"
test_cache_cleared_filename = "testwallet_for_cache_clear_test"
log = get_log() log = get_log()
@ -764,6 +764,50 @@ def test_wallet_id(setup_wallet):
assert wallet1.get_wallet_id() == wallet2.get_wallet_id() assert wallet1.get_wallet_id() == wallet2.get_wallet_id()
def test_cache_cleared(setup_wallet):
# test plan:
# 1. create a new wallet and sync from scratch
# 2. read its cache as an object
# 3. close the wallet, reopen it, sync it.
# 4. corrupt its cache and save.
# 5. Re open the wallet with recoversync
# and check that the corrupted data is not present.
if os.path.exists(test_cache_cleared_filename):
os.remove(test_cache_cleared_filename)
wallet = create_wallet(test_cache_cleared_filename,
b"hunter2", 2, SegwitWallet)
# note: we use the WalletService as an encapsulation
# of the wallet here because we want to be able to sync,
# but we do not actually start the service and go into
# the monitoring loop.
wallet_service = WalletService(wallet)
# default fast sync, no coins, so no loop
wallet_service.sync_wallet()
wallet_service.update_blockheight()
# to get the cache to save, we need to
# use an address:
addr = wallet_service.get_new_addr(0,0)
jm_single().bc_interface.grab_coins(addr, 1.0)
wallet_service.transaction_monitor()
path_to_corrupt = list(wallet._cache.keys())[0]
# we'll just corrupt the first address and script:
entry_to_corrupt = wallet._cache[path_to_corrupt][b"84'"][b"1'"][b"0'"][b'0'][b'0']
entry_to_corrupt[b'A'] = "notanaddress"
entry_to_corrupt[b'S'] = "notascript"
wallet_service.wallet.save()
wallet_service.wallet.close()
jm_single().config.set("POLICY", "wallet_caching_disabled", "true")
wallet2 = open_wallet(test_cache_cleared_filename,
ask_for_password=False,
password=b"hunter2")
jm_single().config.set("POLICY", "wallet_caching_disabled", "false")
wallet_service2 = WalletService(wallet2)
while not wallet_service2.synced:
wallet_service2.sync_wallet(fast=False)
wallet_service.transaction_monitor()
# we ignored the corrupt cache?
assert wallet_service2.get_balance_at_mixdepth(0) == 10 ** 8
def test_addr_script_conversion(setup_wallet): def test_addr_script_conversion(setup_wallet):
wallet = get_populated_wallet(num=1) wallet = get_populated_wallet(num=1)
@ -1016,4 +1060,6 @@ def setup_wallet(request):
def teardown(): def teardown():
if os.path.exists(test_create_wallet_filename): if os.path.exists(test_create_wallet_filename):
os.remove(test_create_wallet_filename) os.remove(test_create_wallet_filename)
if os.path.exists(test_cache_cleared_filename):
os.remove(test_cache_cleared_filename)
request.addfinalizer(teardown) request.addfinalizer(teardown)

Loading…
Cancel
Save