From 01ec2a41817c8e9c6f6af4bc6e6ad86a687e63ff Mon Sep 17 00:00:00 2001 From: Matt Whitlock Date: Sat, 4 Nov 2023 18:42:03 -0400 Subject: [PATCH] wallet: add _addr_map, paralleling _script_map Hoist _populate_script_map from BIP32Wallet into BaseWallet, rename it to _populate_maps, and have it populate the new _addr_map in addition to the existing _script_map. Have the constructor of each concrete wallet subclass pass to _populate_maps the paths it contributes. Additionally, do not implement yield_known_paths by iterating over _script_map, but rather have each wallet subclass contribute its own paths to the generator returned by yield_known_paths. --- src/jmclient/wallet.py | 117 ++++++++++++++++++----------------- test/jmclient/test_wallet.py | 12 ++-- 2 files changed, 65 insertions(+), 64 deletions(-) diff --git a/src/jmclient/wallet.py b/src/jmclient/wallet.py index cbaceb6..27c1d55 100644 --- a/src/jmclient/wallet.py +++ b/src/jmclient/wallet.py @@ -382,6 +382,8 @@ class BaseWallet(object): # {script: path}, should always hold mappings for all "known" keys self._script_map = {} + # {address: path}, should always hold mappings for all "known" keys + self._addr_map = {} self._load_storage() @@ -535,8 +537,7 @@ class BaseWallet(object): """ There should be no reason for code outside the wallet to need a privkey. """ - script = self._ENGINE.address_to_script(addr) - path = self.script_to_path(script) + path = self.addr_to_path(addr) privkey = self._get_key_from_path(path)[0] return privkey @@ -1046,8 +1047,8 @@ class BaseWallet(object): returns: bool """ - script = self.addr_to_script(addr) - return script in self._script_map + assert isinstance(addr, str) + return addr in self._addr_map def is_known_script(self, script): """ @@ -1062,8 +1063,8 @@ class BaseWallet(object): return script in self._script_map def get_addr_mixdepth(self, addr): - script = self.addr_to_script(addr) - return self.get_script_mixdepth(script) + path = self.addr_to_path(addr) + return self._get_mixdepth_from_path(path) def get_script_mixdepth(self, script): path = self.script_to_path(script) @@ -1076,16 +1077,27 @@ class BaseWallet(object): returns: path generator """ - for s in self._script_map.values(): - yield s + for md in range(self.max_mixdepth + 1): + for path in self.yield_imported_paths(md): + yield path + + def _populate_maps(self, paths): + for path in paths: + script = self.get_script_from_path(path) + self._script_map[script] = path + self._addr_map[self.script_to_addr(script)] = path def addr_to_path(self, addr): - script = self.addr_to_script(addr) - return self.script_to_path(script) + assert isinstance(addr, str) + path = self._addr_map.get(addr) + assert path is not None + return path def script_to_path(self, script): - assert script in self._script_map - return self._script_map[script] + assert isinstance(script, bytes) + path = self._script_map.get(script) + assert path is not None + return path def set_next_index(self, mixdepth, address_type, index, force=False): """ @@ -1775,12 +1787,7 @@ class ImportWalletMixin(object): for md, keys in self._storage.data[self._IMPORTED_STORAGE_KEY].items(): md = int(md) self._imported[md] = keys - for index, (key, key_type) in enumerate(keys): - if not key: - # imported key was removed - continue - assert key_type in self._ENGINES - self._cache_imported_key(md, key, key_type, index) + self._populate_maps(self.yield_imported_paths(md)) def save(self): import_data = {} @@ -1849,8 +1856,8 @@ class ImportWalletMixin(object): raise Exception("Only one of script|address|path may be given.") if address: - script = self.addr_to_script(address) - if script: + path = self.addr_to_path(address) + elif script: path = self.script_to_path(script) if not path: @@ -1863,18 +1870,18 @@ class ImportWalletMixin(object): if not script: script = self.get_script_from_path(path) + if not address: + address = self.script_to_addr(script) # we need to retain indices self._imported[path[1]][path[2]] = (b'', -1) del self._script_map[script] + del self._addr_map[address] def _cache_imported_key(self, mixdepth, privkey, key_type, index): - engine = self._ENGINES[key_type] path = (self._IMPORTED_ROOT_PATH, mixdepth, index) - - self._script_map[engine.key_to_script(privkey)] = path - + self._populate_maps((path,)) return path def _get_mixdepth_from_path(self, path): @@ -2010,6 +2017,7 @@ class BIP32Wallet(BaseWallet): def __init__(self, storage, **kwargs): self._entropy = None + self._key_ident = None # {mixdepth: {type: index}} with type being 0/1 corresponding # to external/internal addresses self._index_cache = None @@ -2028,7 +2036,7 @@ class BIP32Wallet(BaseWallet): # used to verify paths for sanity checking and for wallet id creation self._key_ident = b'' # otherwise get_bip32_* won't work self._key_ident = self._get_key_ident() - self._populate_script_map() + self._populate_maps(self.yield_known_bip32_paths()) self.disable_new_scripts = False @classmethod @@ -2074,13 +2082,14 @@ class BIP32Wallet(BaseWallet): self.get_bip32_priv_export(0, self.BIP32_EXT_ID).encode('ascii')).digest())\ .digest()[:3] - def _populate_script_map(self): + def yield_known_paths(self): + return chain(super().yield_known_paths(), self.yield_known_bip32_paths()) + + def yield_known_bip32_paths(self): for md in self._index_cache: for address_type in (self.BIP32_EXT_ID, self.BIP32_INT_ID): for i in range(self._index_cache[md][address_type]): - path = self.get_path(md, address_type, i) - script = self.get_script_from_path(path) - self._script_map[script] = path + yield self.get_path(md, address_type, i) def save(self): for md, data in self._index_cache.items(): @@ -2115,10 +2124,7 @@ class BIP32Wallet(BaseWallet): def _get_supported_address_types(cls): return (cls.BIP32_EXT_ID, cls.BIP32_INT_ID) - def get_script_from_path(self, path): - if not self._is_my_bip32_path(path): - return super().get_script_from_path(path) - + def _check_path(self, path): md, address_type, index = self.get_details(path) if not 0 <= md <= self.max_mixdepth: @@ -2131,10 +2137,19 @@ class BIP32Wallet(BaseWallet): and address_type != FidelityBondMixin.BIP32_TIMELOCK_ID: #special case for timelocked addresses because for them the #concept of a "next address" cant be used - return self.get_new_script_override_disable(md, address_type) + self._set_index_cache(md, address_type, current_index + 1) + self._populate_maps((path,)) + def get_script_from_path(self, path): + if self._is_my_bip32_path(path): + self._check_path(path) return super().get_script_from_path(path) + def get_address_from_path(self, path): + if self._is_my_bip32_path(path): + self._check_path(path) + return super().get_address_from_path(path) + def get_path(self, mixdepth=None, address_type=None, index=None): if mixdepth is not None: assert isinstance(mixdepth, Integral) @@ -2215,13 +2230,8 @@ class BIP32Wallet(BaseWallet): if self.disable_new_scripts: raise RuntimeError("Obtaining new wallet addresses " + "disabled, due to nohistory mode") - return self.get_new_script_override_disable(mixdepth, address_type) - - def get_new_script_override_disable(self, mixdepth, address_type): - # This is called by get_script_from_path and calls back there. We need to - # ensure all conditions match to avoid endless recursion. - index = self.get_index_cache_and_increment(mixdepth, address_type) - return self.get_script_and_update_map(mixdepth, address_type, index) + index = self._index_cache[mixdepth][address_type] + return self.get_script(mixdepth, address_type, index) def _set_index_cache(self, mixdepth, address_type, index): """ Ensures that any update to index_cache dict only applies @@ -2230,17 +2240,6 @@ class BIP32Wallet(BaseWallet): assert address_type in self._get_supported_address_types() self._index_cache[mixdepth][address_type] = index - def get_index_cache_and_increment(self, mixdepth, address_type): - cur_index = self._index_cache[mixdepth][address_type] - self._set_index_cache(mixdepth, address_type, cur_index + 1) - return cur_index - - def get_script_and_update_map(self, *args): - path = self.get_path(*args) - script = self.get_script_from_path(path) - self._script_map[script] = path - return script - @deprecated def get_key(self, mixdepth, address_type, index): path = self.get_path(mixdepth, address_type, index) @@ -2385,6 +2384,10 @@ class FidelityBondMixin(object): _BIP32_PUBKEY_PREFIX = "fbonds-mpk-" + def __init__(self, storage, **kwargs): + super().__init__(storage, **kwargs) + self._populate_maps(self.yield_fidelity_bond_paths()) + @classmethod def _time_number_to_timestamp(cls, timenumber): """ @@ -2444,14 +2447,14 @@ class FidelityBondMixin(object): else: return False - def _populate_script_map(self): - super()._populate_script_map() + def yield_known_paths(self): + return chain(super().yield_known_paths(), self.yield_fidelity_bond_paths()) + + def yield_fidelity_bond_paths(self): md = self.FIDELITY_BOND_MIXDEPTH address_type = self.BIP32_TIMELOCK_ID for timenumber in range(self.TIMENUMBER_COUNT): - path = self.get_path(md, address_type, timenumber) - script = self.get_script_from_path(path) - self._script_map[script] = path + yield self.get_path(md, address_type, timenumber) def add_utxo(self, txid, index, script, value, height=None): super().add_utxo(txid, index, script, value, height) diff --git a/test/jmclient/test_wallet.py b/test/jmclient/test_wallet.py index ab68e72..45b23fa 100644 --- a/test/jmclient/test_wallet.py +++ b/test/jmclient/test_wallet.py @@ -17,7 +17,6 @@ from jmclient import load_test_config, jm_single, BaseWallet, \ wallet_gettimelockaddress, UnknownAddressForLabel from test_blockchaininterface import sync_test_wallet from freezegun import freeze_time -from bitcointx.wallet import CCoinAddressError pytestmark = pytest.mark.usefixtures("setup_regtest_bitcoind") @@ -264,9 +263,6 @@ def test_bip32_timelocked_addresses(setup_wallet, timenumber, address, wif): mixdepth = FidelityBondMixin.FIDELITY_BOND_MIXDEPTH address_type = FidelityBondMixin.BIP32_TIMELOCK_ID - #wallet needs to know about the script beforehand - wallet.get_script_and_update_map(mixdepth, address_type, timenumber) - assert address == wallet.get_addr(mixdepth, address_type, timenumber) assert wif == wallet.get_wif_path(wallet.get_path(mixdepth, address_type, timenumber)) @@ -287,7 +283,7 @@ def test_gettimelockaddress_method(setup_wallet, timenumber, locktime_string): m = FidelityBondMixin.FIDELITY_BOND_MIXDEPTH address_type = FidelityBondMixin.BIP32_TIMELOCK_ID - script = wallet.get_script_and_update_map(m, address_type, timenumber) + script = wallet.get_script(m, address_type, timenumber) addr = wallet.script_to_addr(script) addr_from_method = wallet_gettimelockaddress(wallet, locktime_string) @@ -456,7 +452,7 @@ def test_timelocked_output_signing(setup_wallet): wallet = SegwitWalletFidelityBonds(storage) timenumber = 0 - script = wallet.get_script_and_update_map( + script = wallet.get_script( FidelityBondMixin.FIDELITY_BOND_MIXDEPTH, FidelityBondMixin.BIP32_TIMELOCK_ID, timenumber) utxo = fund_wallet_addr(wallet, wallet.script_to_addr(script)) @@ -610,7 +606,9 @@ def test_address_labels(setup_wallet): wallet.get_address_label("2MzY5yyonUY7zpHspg7jB7WQs1uJxKafQe4") wallet.set_address_label("2MzY5yyonUY7zpHspg7jB7WQs1uJxKafQe4", "test") - with pytest.raises(CCoinAddressError): + # we no longer decode addresses just to see if we know about them, + # so we won't get a CCoinAddressError for invalid addresses + #with pytest.raises(CCoinAddressError): wallet.get_address_label("badaddress") wallet.set_address_label("badaddress", "test")