Browse Source

wallet: add _addr_map, paralleling _script_map

Hoist _populate_script_map from BIP32Wallet into BaseWallet, rename it
to _populate_maps, and have it populate the new _addr_map in addition to
the existing _script_map. Have the constructor of each concrete wallet
subclass pass to _populate_maps the paths it contributes. Additionally,
do not implement yield_known_paths by iterating over _script_map, but
rather have each wallet subclass contribute its own paths to the
generator returned by yield_known_paths.
master
Matt Whitlock 2 years ago
parent
commit
01ec2a4181
  1. 117
      src/jmclient/wallet.py
  2. 12
      test/jmclient/test_wallet.py

117
src/jmclient/wallet.py

@ -382,6 +382,8 @@ class BaseWallet(object):
# {script: path}, should always hold mappings for all "known" keys
self._script_map = {}
# {address: path}, should always hold mappings for all "known" keys
self._addr_map = {}
self._load_storage()
@ -535,8 +537,7 @@ class BaseWallet(object):
"""
There should be no reason for code outside the wallet to need a privkey.
"""
script = self._ENGINE.address_to_script(addr)
path = self.script_to_path(script)
path = self.addr_to_path(addr)
privkey = self._get_key_from_path(path)[0]
return privkey
@ -1046,8 +1047,8 @@ class BaseWallet(object):
returns:
bool
"""
script = self.addr_to_script(addr)
return script in self._script_map
assert isinstance(addr, str)
return addr in self._addr_map
def is_known_script(self, script):
"""
@ -1062,8 +1063,8 @@ class BaseWallet(object):
return script in self._script_map
def get_addr_mixdepth(self, addr):
script = self.addr_to_script(addr)
return self.get_script_mixdepth(script)
path = self.addr_to_path(addr)
return self._get_mixdepth_from_path(path)
def get_script_mixdepth(self, script):
path = self.script_to_path(script)
@ -1076,16 +1077,27 @@ class BaseWallet(object):
returns:
path generator
"""
for s in self._script_map.values():
yield s
for md in range(self.max_mixdepth + 1):
for path in self.yield_imported_paths(md):
yield path
def _populate_maps(self, paths):
for path in paths:
script = self.get_script_from_path(path)
self._script_map[script] = path
self._addr_map[self.script_to_addr(script)] = path
def addr_to_path(self, addr):
script = self.addr_to_script(addr)
return self.script_to_path(script)
assert isinstance(addr, str)
path = self._addr_map.get(addr)
assert path is not None
return path
def script_to_path(self, script):
assert script in self._script_map
return self._script_map[script]
assert isinstance(script, bytes)
path = self._script_map.get(script)
assert path is not None
return path
def set_next_index(self, mixdepth, address_type, index, force=False):
"""
@ -1775,12 +1787,7 @@ class ImportWalletMixin(object):
for md, keys in self._storage.data[self._IMPORTED_STORAGE_KEY].items():
md = int(md)
self._imported[md] = keys
for index, (key, key_type) in enumerate(keys):
if not key:
# imported key was removed
continue
assert key_type in self._ENGINES
self._cache_imported_key(md, key, key_type, index)
self._populate_maps(self.yield_imported_paths(md))
def save(self):
import_data = {}
@ -1849,8 +1856,8 @@ class ImportWalletMixin(object):
raise Exception("Only one of script|address|path may be given.")
if address:
script = self.addr_to_script(address)
if script:
path = self.addr_to_path(address)
elif script:
path = self.script_to_path(script)
if not path:
@ -1863,18 +1870,18 @@ class ImportWalletMixin(object):
if not script:
script = self.get_script_from_path(path)
if not address:
address = self.script_to_addr(script)
# we need to retain indices
self._imported[path[1]][path[2]] = (b'', -1)
del self._script_map[script]
del self._addr_map[address]
def _cache_imported_key(self, mixdepth, privkey, key_type, index):
engine = self._ENGINES[key_type]
path = (self._IMPORTED_ROOT_PATH, mixdepth, index)
self._script_map[engine.key_to_script(privkey)] = path
self._populate_maps((path,))
return path
def _get_mixdepth_from_path(self, path):
@ -2010,6 +2017,7 @@ class BIP32Wallet(BaseWallet):
def __init__(self, storage, **kwargs):
self._entropy = None
self._key_ident = None
# {mixdepth: {type: index}} with type being 0/1 corresponding
# to external/internal addresses
self._index_cache = None
@ -2028,7 +2036,7 @@ class BIP32Wallet(BaseWallet):
# used to verify paths for sanity checking and for wallet id creation
self._key_ident = b'' # otherwise get_bip32_* won't work
self._key_ident = self._get_key_ident()
self._populate_script_map()
self._populate_maps(self.yield_known_bip32_paths())
self.disable_new_scripts = False
@classmethod
@ -2074,13 +2082,14 @@ class BIP32Wallet(BaseWallet):
self.get_bip32_priv_export(0, self.BIP32_EXT_ID).encode('ascii')).digest())\
.digest()[:3]
def _populate_script_map(self):
def yield_known_paths(self):
return chain(super().yield_known_paths(), self.yield_known_bip32_paths())
def yield_known_bip32_paths(self):
for md in self._index_cache:
for address_type in (self.BIP32_EXT_ID, self.BIP32_INT_ID):
for i in range(self._index_cache[md][address_type]):
path = self.get_path(md, address_type, i)
script = self.get_script_from_path(path)
self._script_map[script] = path
yield self.get_path(md, address_type, i)
def save(self):
for md, data in self._index_cache.items():
@ -2115,10 +2124,7 @@ class BIP32Wallet(BaseWallet):
def _get_supported_address_types(cls):
return (cls.BIP32_EXT_ID, cls.BIP32_INT_ID)
def get_script_from_path(self, path):
if not self._is_my_bip32_path(path):
return super().get_script_from_path(path)
def _check_path(self, path):
md, address_type, index = self.get_details(path)
if not 0 <= md <= self.max_mixdepth:
@ -2131,10 +2137,19 @@ class BIP32Wallet(BaseWallet):
and address_type != FidelityBondMixin.BIP32_TIMELOCK_ID:
#special case for timelocked addresses because for them the
#concept of a "next address" cant be used
return self.get_new_script_override_disable(md, address_type)
self._set_index_cache(md, address_type, current_index + 1)
self._populate_maps((path,))
def get_script_from_path(self, path):
if self._is_my_bip32_path(path):
self._check_path(path)
return super().get_script_from_path(path)
def get_address_from_path(self, path):
if self._is_my_bip32_path(path):
self._check_path(path)
return super().get_address_from_path(path)
def get_path(self, mixdepth=None, address_type=None, index=None):
if mixdepth is not None:
assert isinstance(mixdepth, Integral)
@ -2215,13 +2230,8 @@ class BIP32Wallet(BaseWallet):
if self.disable_new_scripts:
raise RuntimeError("Obtaining new wallet addresses "
+ "disabled, due to nohistory mode")
return self.get_new_script_override_disable(mixdepth, address_type)
def get_new_script_override_disable(self, mixdepth, address_type):
# This is called by get_script_from_path and calls back there. We need to
# ensure all conditions match to avoid endless recursion.
index = self.get_index_cache_and_increment(mixdepth, address_type)
return self.get_script_and_update_map(mixdepth, address_type, index)
index = self._index_cache[mixdepth][address_type]
return self.get_script(mixdepth, address_type, index)
def _set_index_cache(self, mixdepth, address_type, index):
""" Ensures that any update to index_cache dict only applies
@ -2230,17 +2240,6 @@ class BIP32Wallet(BaseWallet):
assert address_type in self._get_supported_address_types()
self._index_cache[mixdepth][address_type] = index
def get_index_cache_and_increment(self, mixdepth, address_type):
cur_index = self._index_cache[mixdepth][address_type]
self._set_index_cache(mixdepth, address_type, cur_index + 1)
return cur_index
def get_script_and_update_map(self, *args):
path = self.get_path(*args)
script = self.get_script_from_path(path)
self._script_map[script] = path
return script
@deprecated
def get_key(self, mixdepth, address_type, index):
path = self.get_path(mixdepth, address_type, index)
@ -2385,6 +2384,10 @@ class FidelityBondMixin(object):
_BIP32_PUBKEY_PREFIX = "fbonds-mpk-"
def __init__(self, storage, **kwargs):
super().__init__(storage, **kwargs)
self._populate_maps(self.yield_fidelity_bond_paths())
@classmethod
def _time_number_to_timestamp(cls, timenumber):
"""
@ -2444,14 +2447,14 @@ class FidelityBondMixin(object):
else:
return False
def _populate_script_map(self):
super()._populate_script_map()
def yield_known_paths(self):
return chain(super().yield_known_paths(), self.yield_fidelity_bond_paths())
def yield_fidelity_bond_paths(self):
md = self.FIDELITY_BOND_MIXDEPTH
address_type = self.BIP32_TIMELOCK_ID
for timenumber in range(self.TIMENUMBER_COUNT):
path = self.get_path(md, address_type, timenumber)
script = self.get_script_from_path(path)
self._script_map[script] = path
yield self.get_path(md, address_type, timenumber)
def add_utxo(self, txid, index, script, value, height=None):
super().add_utxo(txid, index, script, value, height)

12
test/jmclient/test_wallet.py

@ -17,7 +17,6 @@ from jmclient import load_test_config, jm_single, BaseWallet, \
wallet_gettimelockaddress, UnknownAddressForLabel
from test_blockchaininterface import sync_test_wallet
from freezegun import freeze_time
from bitcointx.wallet import CCoinAddressError
pytestmark = pytest.mark.usefixtures("setup_regtest_bitcoind")
@ -264,9 +263,6 @@ def test_bip32_timelocked_addresses(setup_wallet, timenumber, address, wif):
mixdepth = FidelityBondMixin.FIDELITY_BOND_MIXDEPTH
address_type = FidelityBondMixin.BIP32_TIMELOCK_ID
#wallet needs to know about the script beforehand
wallet.get_script_and_update_map(mixdepth, address_type, timenumber)
assert address == wallet.get_addr(mixdepth, address_type, timenumber)
assert wif == wallet.get_wif_path(wallet.get_path(mixdepth, address_type, timenumber))
@ -287,7 +283,7 @@ def test_gettimelockaddress_method(setup_wallet, timenumber, locktime_string):
m = FidelityBondMixin.FIDELITY_BOND_MIXDEPTH
address_type = FidelityBondMixin.BIP32_TIMELOCK_ID
script = wallet.get_script_and_update_map(m, address_type, timenumber)
script = wallet.get_script(m, address_type, timenumber)
addr = wallet.script_to_addr(script)
addr_from_method = wallet_gettimelockaddress(wallet, locktime_string)
@ -456,7 +452,7 @@ def test_timelocked_output_signing(setup_wallet):
wallet = SegwitWalletFidelityBonds(storage)
timenumber = 0
script = wallet.get_script_and_update_map(
script = wallet.get_script(
FidelityBondMixin.FIDELITY_BOND_MIXDEPTH,
FidelityBondMixin.BIP32_TIMELOCK_ID, timenumber)
utxo = fund_wallet_addr(wallet, wallet.script_to_addr(script))
@ -610,7 +606,9 @@ def test_address_labels(setup_wallet):
wallet.get_address_label("2MzY5yyonUY7zpHspg7jB7WQs1uJxKafQe4")
wallet.set_address_label("2MzY5yyonUY7zpHspg7jB7WQs1uJxKafQe4",
"test")
with pytest.raises(CCoinAddressError):
# we no longer decode addresses just to see if we know about them,
# so we won't get a CCoinAddressError for invalid addresses
#with pytest.raises(CCoinAddressError):
wallet.get_address_label("badaddress")
wallet.set_address_label("badaddress", "test")

Loading…
Cancel
Save