diff --git a/jmclient/jmclient/wallet.py b/jmclient/jmclient/wallet.py index 7385aa1..3778c4a 100644 --- a/jmclient/jmclient/wallet.py +++ b/jmclient/jmclient/wallet.py @@ -299,6 +299,9 @@ class BaseWallet(object): _ENGINE = None + ADDRESS_TYPE_EXTERNAL = 0 + ADDRESS_TYPE_INTERNAL = 1 + def __init__(self, storage, gap_limit=6, merge_algorithm_name=None, mixdepth=None): # to be defined by inheriting classes @@ -446,9 +449,13 @@ class BaseWallet(object): privkey = self._get_priv_from_path(path)[0] return hexlify(privkey).decode('ascii') - def _get_addr_int_ext(self, internal, mixdepth): - script = self.get_internal_script(mixdepth) if internal else \ - self.get_external_script(mixdepth) + def _get_addr_int_ext(self, address_type, mixdepth): + if address_type == self.ADDRESS_TYPE_EXTERNAL: + script = self.get_external_script(mixdepth) + elif address_type == self.ADDRESS_TYPE_INTERNAL: + script = self.get_internal_script(mixdepth) + else: + assert 0 return self.script_to_addr(script) def get_external_addr(self, mixdepth): @@ -457,20 +464,20 @@ class BaseWallet(object): the wallet from other sources, or receiving payments or donations. JoinMarket will never generate these addresses for internal use. """ - return self._get_addr_int_ext(False, mixdepth) + return self._get_addr_int_ext(self.ADDRESS_TYPE_EXTERNAL, mixdepth) def get_internal_addr(self, mixdepth): """ Return an address for internal usage, as change addresses and when participating in transactions initiated by other parties. """ - return self._get_addr_int_ext(True, mixdepth) + return self._get_addr_int_ext(self.ADDRESS_TYPE_INTERNAL, mixdepth) def get_external_script(self, mixdepth): - return self.get_new_script(mixdepth, False) + return self.get_new_script(mixdepth, self.ADDRESS_TYPE_EXTERNAL) def get_internal_script(self, mixdepth): - return self.get_new_script(mixdepth, True) + return self.get_new_script(mixdepth, self.ADDRESS_TYPE_INTERNAL) @classmethod def addr_to_script(cls, addr): @@ -512,40 +519,40 @@ class BaseWallet(object): return cls._ENGINE.pubkey_has_script(pubkey, script) @deprecated - def get_key(self, mixdepth, internal, index): + def get_key(self, mixdepth, address_type, index): raise NotImplementedError() - def get_addr(self, mixdepth, internal, index): - script = self.get_script(mixdepth, internal, index) + def get_addr(self, mixdepth, address_type, index): + script = self.get_script(mixdepth, address_type, index) return self.script_to_addr(script) def get_address_from_path(self, path): script = self.get_script_from_path(path) return self.script_to_addr(script) - def get_new_addr(self, mixdepth, internal): + def get_new_addr(self, mixdepth, address_type): """ use get_external_addr/get_internal_addr """ - script = self.get_new_script(mixdepth, internal) + script = self.get_new_script(mixdepth, address_type) return self.script_to_addr(script) - def get_new_script(self, mixdepth, internal): + def get_new_script(self, mixdepth, address_type): raise NotImplementedError() - def get_wif(self, mixdepth, internal, index): - return self.get_wif_path(self.get_path(mixdepth, internal, index)) + def get_wif(self, mixdepth, address_type, index): + return self.get_wif_path(self.get_path(mixdepth, address_type, index)) def get_wif_path(self, path): priv, engine = self._get_priv_from_path(path) return engine.privkey_to_wif(priv) - def get_path(self, mixdepth=None, internal=None, index=None): + def get_path(self, mixdepth=None, address_type=None, index=None): raise NotImplementedError() def get_details(self, path): """ - Return mixdepth, internal, index for a given path + Return mixdepth, address_type, index for a given path args: path: wallet path @@ -814,8 +821,8 @@ class BaseWallet(object): """ raise NotImplementedError() - def get_script(self, mixdepth, internal, index): - path = self.get_path(mixdepth, internal, index) + def get_script(self, mixdepth, address_type, index): + path = self.get_path(mixdepth, address_type, index) return self.get_script_from_path(path) def _get_priv_from_path(self, path): @@ -843,7 +850,7 @@ class BaseWallet(object): """ raise NotImplementedError() - def get_next_unused_index(self, mixdepth, internal): + def get_next_unused_index(self, mixdepth, address_type): """ Get the next index for public scripts/addresses not yet handed out. @@ -952,13 +959,13 @@ class BaseWallet(object): assert script in self._script_map return self._script_map[script] - def set_next_index(self, mixdepth, internal, index, force=False): + def set_next_index(self, mixdepth, address_type, index, force=False): """ Set the next index to use when generating a new key pair. params: mixdepth: int - internal: 0/False or 1/True + address_type: 0 (external) or 1 (internal) index: int force: True if you know the wallet already knows all scripts up to (excluding) the given index @@ -969,10 +976,11 @@ class BaseWallet(object): def rewind_wallet_indices(self, used_indices, saved_indices): for md in used_indices: - for int_type in (0, 1): - index = max(used_indices[md][int_type], - saved_indices[md][int_type]) - self.set_next_index(md, int_type, index, force=True) + for address_type in (self.ADDRESS_TYPE_EXTERNAL, + self.ADDRESS_TYPE_INTERNAL): + index = max(used_indices[md][address_type], + saved_indices[md][address_type]) + self.set_next_index(md, address_type, index, force=True) def get_used_indices(self, addr_gen): """ Returns a dict of max used indices for each branch in @@ -985,12 +993,13 @@ class BaseWallet(object): for addr in addr_gen: if not self.is_known_addr(addr): continue - md, internal, index = self.get_details( + md, address_type, index = self.get_details( self.addr_to_path(addr)) - if internal not in (0, 1): - assert internal == 'imported' + if address_type not in (self.ADDRESS_TYPE_EXTERNAL, + self.ADDRESS_TYPE_INTERNAL): + assert address_type == 'imported' continue - indices[md][internal] = max(indices[md][internal], index + 1) + indices[md][address_type] = max(indices[md][address_type], index + 1) return indices @@ -1001,9 +1010,10 @@ class BaseWallet(object): cache.""" for md in used_indices: - for internal in (0, 1): - if used_indices[md][internal] >\ - max(self.get_next_unused_index(md, internal), 0): + for address_type in (self.ADDRESS_TYPE_EXTERNAL, + self.ADDRESS_TYPE_INTERNAL): + if used_indices[md][address_type] >\ + max(self.get_next_unused_index(md, address_type), 0): return False return True @@ -1265,13 +1275,14 @@ class BIP32Wallet(BaseWallet): _STORAGE_ENTROPY_KEY = b'entropy' _STORAGE_INDEX_CACHE = b'index_cache' BIP32_MAX_PATH_LEVEL = 2**31 - BIP32_EXT_ID = 0 - BIP32_INT_ID = 1 + BIP32_EXT_ID = BaseWallet.ADDRESS_TYPE_EXTERNAL + BIP32_INT_ID = BaseWallet.ADDRESS_TYPE_INTERNAL ENTROPY_BYTES = 16 def __init__(self, storage, **kwargs): self._entropy = None - # {mixdepth: {type: index}} with type being 0/1 for [non]-internal + # {mixdepth: {type: index}} with type being 0/1 corresponding + # to external/internal addresses self._index_cache = None # path is a tuple of BIP32 levels, # m is the master key's fingerprint @@ -1333,9 +1344,9 @@ class BIP32Wallet(BaseWallet): def _populate_script_map(self): for md in self._index_cache: - for int_type in (self.BIP32_EXT_ID, self.BIP32_INT_ID): - for i in range(self._index_cache[md][int_type]): - path = self.get_path(md, int_type, i) + for address_type in (self.BIP32_EXT_ID, self.BIP32_INT_ID): + for i in range(self._index_cache[md][address_type]): + path = self.get_path(md, address_type, i) script = self.get_script_from_path(path) self._script_map[script] = path @@ -1372,43 +1383,42 @@ class BIP32Wallet(BaseWallet): if not self._is_my_bip32_path(path): raise WalletError("unable to get script for unknown key path") - md, int_type, index = self.get_details(path) + md, address_type, index = self.get_details(path) if not 0 <= md <= self.max_mixdepth: raise WalletError("Mixdepth outside of wallet's range.") - assert int_type in (self.BIP32_EXT_ID, self.BIP32_INT_ID) + assert address_type in (self.BIP32_EXT_ID, self.BIP32_INT_ID) - current_index = self._index_cache[md][int_type] + current_index = self._index_cache[md][address_type] if index == current_index: - return self.get_new_script_override_disable(md, int_type) + return self.get_new_script_override_disable(md, address_type) priv, engine = self._get_priv_from_path(path) script = engine.privkey_to_script(priv) return script - def get_path(self, mixdepth=None, internal=None, index=None): + def get_path(self, mixdepth=None, address_type=None, index=None): if mixdepth is not None: assert isinstance(mixdepth, Integral) if not 0 <= mixdepth <= self.max_mixdepth: raise WalletError("Mixdepth outside of wallet's range.") - if internal is not None: + if address_type is not None: if mixdepth is None: - raise Exception("mixdepth must be set if internal is set") - int_type = self._get_internal_type(internal) + raise Exception("mixdepth must be set if address_type is set") if index is not None: assert isinstance(index, Integral) - if internal is None: - raise Exception("internal must be set if index is set") - assert index <= self._index_cache[mixdepth][int_type] + if address_type is None: + raise Exception("address_type must be set if index is set") + assert index <= self._index_cache[mixdepth][address_type] assert index < self.BIP32_MAX_PATH_LEVEL - return tuple(chain(self._get_bip32_export_path(mixdepth, internal), + return tuple(chain(self._get_bip32_export_path(mixdepth, address_type), (index,))) - return tuple(self._get_bip32_export_path(mixdepth, internal)) + return tuple(self._get_bip32_export_path(mixdepth, address_type)) def get_path_repr(self, path): path = list(path) @@ -1462,53 +1472,50 @@ class BIP32Wallet(BaseWallet): def _is_my_bip32_path(self, path): return path[0] == self._key_ident - def get_new_script(self, mixdepth, internal): + def get_new_script(self, mixdepth, address_type): if self.disable_new_scripts: raise RuntimeError("Obtaining new wallet addresses " + "disabled, due to nohistory mode") - return self.get_new_script_override_disable(mixdepth, internal) + return self.get_new_script_override_disable(mixdepth, address_type) - def get_new_script_override_disable(self, mixdepth, internal): + def get_new_script_override_disable(self, mixdepth, address_type): # This is called by get_script_from_path and calls back there. We need to # ensure all conditions match to avoid endless recursion. - int_type = self._get_internal_type(internal) - index = self._index_cache[mixdepth][int_type] - self._index_cache[mixdepth][int_type] += 1 - path = self.get_path(mixdepth, int_type, index) + index = self._index_cache[mixdepth][address_type] + self._index_cache[mixdepth][address_type] += 1 + path = self.get_path(mixdepth, address_type, index) script = self.get_script_from_path(path) self._script_map[script] = path return script - def get_script(self, mixdepth, internal, index): - path = self.get_path(mixdepth, internal, index) + def get_script(self, mixdepth, address_type, index): + path = self.get_path(mixdepth, address_type, index) return self.get_script_from_path(path) @deprecated - def get_key(self, mixdepth, internal, index): - int_type = self._get_internal_type(internal) - path = self.get_path(mixdepth, int_type, index) + def get_key(self, mixdepth, address_type, index): + path = self.get_path(mixdepth, address_type, index) priv = self._ENGINE.derive_bip32_privkey(self._master_key, path) return hexlify(priv).decode('ascii') - def get_bip32_priv_export(self, mixdepth=None, internal=None): - path = self._get_bip32_export_path(mixdepth, internal) + def get_bip32_priv_export(self, mixdepth=None, address_type=None): + path = self._get_bip32_export_path(mixdepth, address_type) return self._ENGINE.derive_bip32_priv_export(self._master_key, path) - def get_bip32_pub_export(self, mixdepth=None, internal=None): - path = self._get_bip32_export_path(mixdepth, internal) + def get_bip32_pub_export(self, mixdepth=None, address_type=None): + path = self._get_bip32_export_path(mixdepth, address_type) return self._ENGINE.derive_bip32_pub_export(self._master_key, path) - def _get_bip32_export_path(self, mixdepth=None, internal=None): + def _get_bip32_export_path(self, mixdepth=None, address_type=None): if mixdepth is None: - assert internal is None + assert address_type is None path = tuple() else: assert 0 <= mixdepth <= self.max_mixdepth - if internal is None: + if address_type is None: path = (self._get_bip32_mixdepth_path_level(mixdepth),) else: - int_type = self._get_internal_type(internal) - path = (self._get_bip32_mixdepth_path_level(mixdepth), int_type) + path = (self._get_bip32_mixdepth_path_level(mixdepth), address_type) return tuple(chain(self._get_bip32_base_path(), path)) @@ -1519,19 +1526,15 @@ class BIP32Wallet(BaseWallet): def _get_bip32_mixdepth_path_level(cls, mixdepth): return mixdepth - def _get_internal_type(self, is_internal): - return self.BIP32_INT_ID if is_internal else self.BIP32_EXT_ID - - def get_next_unused_index(self, mixdepth, internal): + def get_next_unused_index(self, mixdepth, address_type): assert 0 <= mixdepth <= self.max_mixdepth - int_type = self._get_internal_type(internal) - if self._index_cache[mixdepth][int_type] >= self.BIP32_MAX_PATH_LEVEL: + if self._index_cache[mixdepth][address_type] >= self.BIP32_MAX_PATH_LEVEL: # FIXME: theoretically this should work for up to # self.BIP32_MAX_PATH_LEVEL * 2, no? raise WalletError("All addresses used up, cannot generate new ones.") - return self._index_cache[mixdepth][int_type] + return self._index_cache[mixdepth][address_type] def get_mnemonic_words(self): return ' '.join(mn_encode(hexlify(self._entropy).decode('ascii'))), None @@ -1547,11 +1550,10 @@ class BIP32Wallet(BaseWallet): def get_wallet_id(self): return hexlify(self._key_ident).decode('ascii') - def set_next_index(self, mixdepth, internal, index, force=False): - int_type = self._get_internal_type(internal) - if not (force or index <= self._index_cache[mixdepth][int_type]): + def set_next_index(self, mixdepth, address_type, index, force=False): + if not (force or index <= self._index_cache[mixdepth][address_type]): raise Exception("cannot advance index without force=True") - self._index_cache[mixdepth][int_type] = index + self._index_cache[mixdepth][address_type] = index def get_details(self, path): if not self._is_my_bip32_path(path): diff --git a/jmclient/jmclient/wallet_service.py b/jmclient/jmclient/wallet_service.py index 459bd90..1d75ae9 100644 --- a/jmclient/jmclient/wallet_service.py +++ b/jmclient/jmclient/wallet_service.py @@ -721,16 +721,16 @@ class WalletService(Service): for md in range(self.max_mixdepth + 1): saved_indices[md] = [0, 0] - for internal in (0, 1): - next_unused = self.get_next_unused_index(md, internal) + for address_type in (0, 1): + next_unused = self.get_next_unused_index(md, address_type) for index in range(next_unused): - addresses.add(self.get_addr(md, internal, index)) + addresses.add(self.get_addr(md, address_type, index)) for index in range(self.gap_limit): - addresses.add(self.get_new_addr(md, internal)) + addresses.add(self.get_new_addr(md, address_type)) # reset the indices to the value we had before the # new address calls: - self.set_next_index(md, internal, next_unused) - saved_indices[md][internal] = next_unused + self.set_next_index(md, address_type, next_unused) + saved_indices[md][address_type] = next_unused # include any imported addresses for path in self.yield_imported_paths(md): addresses.add(self.get_address_from_path(path)) @@ -742,11 +742,11 @@ class WalletService(Service): addresses = set() for md in range(self.max_mixdepth + 1): - for internal in (True, False): - old_next = self.get_next_unused_index(md, internal) + for address_type in (1, 0): + old_next = self.get_next_unused_index(md, address_type) for index in range(gap_limit): - addresses.add(self.get_new_addr(md, internal)) - self.set_next_index(md, internal, old_next) + addresses.add(self.get_new_addr(md, address_type)) + self.set_next_index(md, address_type, old_next) return addresses diff --git a/jmclient/jmclient/wallet_utils.py b/jmclient/jmclient/wallet_utils.py index f2b2f55..6a1d4a7 100644 --- a/jmclient/jmclient/wallet_utils.py +++ b/jmclient/jmclient/wallet_utils.py @@ -162,13 +162,13 @@ class WalletViewBase(object): return "{0:.08f}".format(self.get_balance(include_unconf)) class WalletViewEntry(WalletViewBase): - def __init__(self, wallet_path_repr, account, forchange, aindex, addr, amounts, + def __init__(self, wallet_path_repr, account, address_type, aindex, addr, amounts, used = 'new', serclass=str, priv=None, custom_separator=None): super(WalletViewEntry, self).__init__(wallet_path_repr, serclass=serclass, custom_separator=custom_separator) self.account = account - assert forchange in [0, 1, -1] - self.forchange =forchange + assert address_type in [0, 1, -1] + self.address_type = address_type assert isinstance(aindex, Integral) assert aindex >= 0 self.aindex = aindex @@ -213,14 +213,14 @@ class WalletViewEntry(WalletViewBase): return self.serclass(ed) class WalletViewBranch(WalletViewBase): - def __init__(self, wallet_path_repr, account, forchange, branchentries=None, + def __init__(self, wallet_path_repr, account, address_type, branchentries=None, xpub=None, serclass=str, custom_separator=None): super(WalletViewBranch, self).__init__(wallet_path_repr, children=branchentries, serclass=serclass, custom_separator=custom_separator) self.account = account - assert forchange in [0, 1, -1] - self.forchange = forchange + assert address_type in [0, 1, -1] + self.address_type = address_type if xpub: assert xpub.startswith('xpub') or xpub.startswith('tpub') self.xpub = xpub if xpub else "" @@ -238,8 +238,8 @@ class WalletViewBranch(WalletViewBase): return self.serclass(entryseparator.join(lines)) def serialize_branch_header(self): - start = "external addresses" if self.forchange == 0 else "internal addresses" - if self.forchange == -1: + start = "external addresses" if self.address_type == 0 else "internal addresses" + if self.address_type == -1: start = "Imported keys" return self.serclass(self.separator.join([start, self.wallet_path_repr, self.xpub])) @@ -418,32 +418,33 @@ def wallet_display(wallet_service, showprivkey, displayall=False, utxos = wallet_service.get_utxos_by_mixdepth(include_disabled=True, hexfmt=False) for m in range(wallet_service.mixdepth + 1): branchlist = [] - for forchange in [0, 1]: + for address_type in [0, 1]: entrylist = [] - if forchange == 0: + if address_type == 0: # users would only want to hand out the xpub for externals - xpub_key = wallet_service.get_bip32_pub_export(m, forchange) + xpub_key = wallet_service.get_bip32_pub_export(m, address_type) else: xpub_key = "" - unused_index = wallet_service.get_next_unused_index(m, forchange) + unused_index = wallet_service.get_next_unused_index(m, address_type) for k in range(unused_index + wallet_service.gap_limit): - path = wallet_service.get_path(m, forchange, k) + path = wallet_service.get_path(m, address_type, k) addr = wallet_service.get_address_from_path(path) balance, used = get_addr_status( - path, utxos[m], k >= unused_index, forchange) + path, utxos[m], k >= unused_index, address_type) if showprivkey: privkey = wallet_service.get_wif_path(path) else: privkey = '' if (displayall or balance > 0 or - (used == 'new' and forchange == 0)): + (used == 'new' and address_type == 0)): entrylist.append(WalletViewEntry( - wallet_service.get_path_repr(path), m, forchange, k, addr, + wallet_service.get_path_repr(path), m, address_type, k, addr, [balance, balance], priv=privkey, used=used)) - wallet_service.set_next_index(m, forchange, unused_index) - path = wallet_service.get_path_repr(wallet_service.get_path(m, forchange)) - branchlist.append(WalletViewBranch(path, m, forchange, entrylist, + + wallet_service.set_next_index(m, address_type, unused_index) + path = wallet_service.get_path_repr(wallet_service.get_path(m, address_type)) + branchlist.append(WalletViewBranch(path, m, address_type, entrylist, xpub=xpub_key)) ipb = get_imported_privkey_branch(wallet_service, m, showprivkey) if ipb: @@ -1290,15 +1291,15 @@ if __name__ == "__main__": acctlist = [] for a in accounts: branches = [] - for forchange in range(2): + for address_type in range(2): entries = [] for i in range(4): - entries.append(WalletViewEntry(rootpath, a, forchange, + entries.append(WalletViewEntry(rootpath, a, address_type, i, "DUMMYADDRESS"+str(i+a), [i*10000000, i*10000000])) branches.append(WalletViewBranch(rootpath, - a, forchange, branchentries=entries, - xpub="xpubDUMMYXPUB"+str(a+forchange))) + a, address_type, branchentries=entries, + xpub="xpubDUMMYXPUB"+str(a+address_type))) acctlist.append(WalletViewAccount(rootpath, a, branches=branches)) wallet = WalletView(rootpath + "/" + str(walletbranch), accounts=acctlist) diff --git a/jmclient/test/test_wallet.py b/jmclient/test/test_wallet.py index 97e2c50..f424d9f 100644 --- a/jmclient/test/test_wallet.py +++ b/jmclient/test/test_wallet.py @@ -520,7 +520,7 @@ def test_set_next_index(setup_wallet): def test_path_repr(setup_wallet): wallet = get_populated_wallet() - path = wallet.get_path(2, False, 0) + path = wallet.get_path(2, BIP32Wallet.ADDRESS_TYPE_EXTERNAL, 0) path_repr = wallet.get_path_repr(path) path_new = wallet.path_repr_to_path(path_repr)