Browse Source

wizard: typing

master
Sander van Grieken 2 years ago
parent
commit
b072f5d243
  1. 73
      electrum/wizard.py

73
electrum/wizard.py

@ -1,7 +1,7 @@
import copy import copy
import os import os
from typing import List, NamedTuple, Any, Dict, Optional, Tuple from typing import List, NamedTuple, Any, Dict, Optional, Tuple, TYPE_CHECKING
from electrum.i18n import _ from electrum.i18n import _
from electrum.interface import ServerAddr from electrum.interface import ServerAddr
@ -16,6 +16,11 @@ from electrum import keystore, mnemonic
from electrum import bitcoin from electrum import bitcoin
from electrum.mnemonic import is_any_2fa_seed_type from electrum.mnemonic import is_any_2fa_seed_type
if TYPE_CHECKING:
from electrum.daemon import Daemon
from electrum.plugin import Plugins
from electrum.keystore import Hardware_KeyStore
class WizardViewState(NamedTuple): class WizardViewState(NamedTuple):
view: Optional[str] view: Optional[str]
@ -40,7 +45,7 @@ class AbstractWizard:
self._current = WizardViewState(None, {}, {}) self._current = WizardViewState(None, {}, {})
self._stack = [] # type: List[WizardViewState] self._stack = [] # type: List[WizardViewState]
def navmap_merge(self, additional_navmap): def navmap_merge(self, additional_navmap: dict):
# NOTE: only merges one level deep. Deeper dict levels will overwrite # NOTE: only merges one level deep. Deeper dict levels will overwrite
for k, v in additional_navmap.items(): for k, v in additional_navmap.items():
if k in self.navmap: if k in self.navmap:
@ -55,7 +60,7 @@ class AbstractWizard:
# view params are transient, meant for extra configuration of a view (e.g. info # view params are transient, meant for extra configuration of a view (e.g. info
# msg in a generic choice dialog) # msg in a generic choice dialog)
# exception: stay on this view # exception: stay on this view
def resolve_next(self, view, wizard_data) -> WizardViewState: def resolve_next(self, view: str, wizard_data: dict) -> WizardViewState:
assert view assert view
self._logger.debug(f'view={view}') self._logger.debug(f'view={view}')
assert view in self.navmap assert view in self.navmap
@ -127,7 +132,7 @@ class AbstractWizard:
return self._current return self._current
# check if this view is the final view # check if this view is the final view
def is_last_view(self, view, wizard_data): def is_last_view(self, view: str, wizard_data: dict) -> bool:
assert view assert view
assert view in self.navmap assert view in self.navmap
@ -149,7 +154,7 @@ class AbstractWizard:
else: else:
raise Exception(f'last handler for view {view} is not callable nor a bool literal') raise Exception(f'last handler for view {view} is not callable nor a bool literal')
def finished(self, wizard_data): def finished(self, wizard_data: dict):
self._logger.debug('finished.') self._logger.debug('finished.')
def reset(self): def reset(self):
@ -182,7 +187,7 @@ class AbstractWizard:
return result return result
return sanitize(_stack_item) return sanitize(_stack_item)
def get_wizard_data(self): def get_wizard_data(self) -> dict:
return copy.deepcopy(self._current.wizard_data) return copy.deepcopy(self._current.wizard_data)
@ -190,7 +195,7 @@ class NewWalletWizard(AbstractWizard):
_logger = get_logger(__name__) _logger = get_logger(__name__)
def __init__(self, daemon, plugins): def __init__(self, daemon: 'Daemon', plugins: 'Plugins'):
AbstractWizard.__init__(self) AbstractWizard.__init__(self)
self.navmap = { self.navmap = {
'wallet_name': { 'wallet_name': {
@ -264,36 +269,36 @@ class NewWalletWizard(AbstractWizard):
self._daemon = daemon self._daemon = daemon
self.plugins = plugins self.plugins = plugins
def start(self, initial_data=None): def start(self, initial_data: dict = None) -> WizardViewState:
if initial_data is None: if initial_data is None:
initial_data = {} initial_data = {}
self.reset() self.reset()
self._current = WizardViewState('wallet_name', initial_data, {}) self._current = WizardViewState('wallet_name', initial_data, {})
return self._current return self._current
def is_single_password(self): def is_single_password(self) -> bool:
raise NotImplementedError() raise NotImplementedError()
# returns (sub)dict of current cosigner (or root if first) # returns (sub)dict of current cosigner (or root if first)
def current_cosigner(self, wizard_data): def current_cosigner(self, wizard_data: dict) -> dict:
wdata = wizard_data wdata = wizard_data
if wizard_data['wallet_type'] == 'multisig' and 'multisig_current_cosigner' in wizard_data: if wizard_data['wallet_type'] == 'multisig' and 'multisig_current_cosigner' in wizard_data:
cosigner = wizard_data['multisig_current_cosigner'] cosigner = wizard_data['multisig_current_cosigner']
wdata = wizard_data['multisig_cosigner_data'][str(cosigner)] wdata = wizard_data['multisig_cosigner_data'][str(cosigner)]
return wdata return wdata
def needs_derivation_path(self, wizard_data): def needs_derivation_path(self, wizard_data: dict) -> bool:
wdata = self.current_cosigner(wizard_data) wdata = self.current_cosigner(wizard_data)
return 'seed_variant' in wdata and wdata['seed_variant'] in ['bip39', 'slip39'] return 'seed_variant' in wdata and wdata['seed_variant'] in ['bip39', 'slip39']
def wants_ext(self, wizard_data): def wants_ext(self, wizard_data: dict) -> bool:
wdata = self.current_cosigner(wizard_data) wdata = self.current_cosigner(wizard_data)
return 'seed_variant' in wdata and wdata['seed_extend'] return 'seed_variant' in wdata and wdata['seed_extend']
def is_multisig(self, wizard_data): def is_multisig(self, wizard_data: dict) -> bool:
return wizard_data['wallet_type'] == 'multisig' return wizard_data['wallet_type'] == 'multisig'
def on_wallet_type(self, wizard_data): def on_wallet_type(self, wizard_data: dict) -> str:
t = wizard_data['wallet_type'] t = wizard_data['wallet_type']
return { return {
'standard': 'keystore_type', 'standard': 'keystore_type',
@ -302,7 +307,7 @@ class NewWalletWizard(AbstractWizard):
'imported': 'imported' 'imported': 'imported'
}.get(t) }.get(t)
def on_keystore_type(self, wizard_data): def on_keystore_type(self, wizard_data: dict) -> str:
t = wizard_data['keystore_type'] t = wizard_data['keystore_type']
return { return {
'createseed': 'create_seed', 'createseed': 'create_seed',
@ -311,19 +316,19 @@ class NewWalletWizard(AbstractWizard):
'hardware': 'choose_hardware_device' 'hardware': 'choose_hardware_device'
}.get(t) }.get(t)
def is_hardware(self, wizard_data): def is_hardware(self, wizard_data: dict) -> bool:
return wizard_data['keystore_type'] == 'hardware' return wizard_data['keystore_type'] == 'hardware'
def wallet_password_view(self, wizard_data): def wallet_password_view(self, wizard_data: dict) -> str:
return 'wallet_password_hardware' if self.is_hardware(wizard_data) else 'wallet_password' return 'wallet_password_hardware' if self.is_hardware(wizard_data) else 'wallet_password'
def on_hardware_device(self, wizard_data): def on_hardware_device(self, wizard_data: dict) -> str:
_type, _info = wizard_data['hardware_device'] _type, _info = wizard_data['hardware_device']
run_hook('init_wallet_wizard', self) run_hook('init_wallet_wizard', self)
plugin = self.plugins.get_plugin(_type) plugin = self.plugins.get_plugin(_type)
return plugin.wizard_entry_for_device(_info) return plugin.wizard_entry_for_device(_info)
def on_have_or_confirm_seed(self, wizard_data): def on_have_or_confirm_seed(self, wizard_data: dict) -> str:
if self.needs_derivation_path(wizard_data): if self.needs_derivation_path(wizard_data):
return 'script_and_derivation' return 'script_and_derivation'
elif self.is_multisig(wizard_data): elif self.is_multisig(wizard_data):
@ -331,7 +336,7 @@ class NewWalletWizard(AbstractWizard):
else: else:
return 'wallet_password' return 'wallet_password'
def maybe_master_pubkey(self, wizard_data): def maybe_master_pubkey(self, wizard_data: dict):
self._logger.debug('maybe_master_pubkey') self._logger.debug('maybe_master_pubkey')
if self.needs_derivation_path(wizard_data) and 'derivation_path' not in wizard_data: if self.needs_derivation_path(wizard_data) and 'derivation_path' not in wizard_data:
self._logger.debug('deferred, missing derivation_path') self._logger.debug('deferred, missing derivation_path')
@ -339,7 +344,7 @@ class NewWalletWizard(AbstractWizard):
wizard_data['multisig_master_pubkey'] = self.keystore_from_data(wizard_data['wallet_type'], wizard_data).get_master_public_key() wizard_data['multisig_master_pubkey'] = self.keystore_from_data(wizard_data['wallet_type'], wizard_data).get_master_public_key()
def on_cosigner_keystore_type(self, wizard_data): def on_cosigner_keystore_type(self, wizard_data: dict) -> str:
t = wizard_data['cosigner_keystore_type'] t = wizard_data['cosigner_keystore_type']
return { return {
'key': 'multisig_cosigner_key', 'key': 'multisig_cosigner_key',
@ -347,7 +352,7 @@ class NewWalletWizard(AbstractWizard):
'hardware': 'multisig_cosigner_hardware' 'hardware': 'multisig_cosigner_hardware'
}.get(t) }.get(t)
def on_have_cosigner_seed(self, wizard_data): def on_have_cosigner_seed(self, wizard_data: dict) -> str:
current_cosigner = self.current_cosigner(wizard_data) current_cosigner = self.current_cosigner(wizard_data)
if self.needs_derivation_path(wizard_data) and 'derivation_path' not in current_cosigner: if self.needs_derivation_path(wizard_data) and 'derivation_path' not in current_cosigner:
return 'multisig_cosigner_script_and_derivation' return 'multisig_cosigner_script_and_derivation'
@ -356,7 +361,7 @@ class NewWalletWizard(AbstractWizard):
else: else:
return 'multisig_cosigner_keystore' return 'multisig_cosigner_keystore'
def last_cosigner(self, wizard_data): def last_cosigner(self, wizard_data: dict) -> bool:
# check if we have the final number of cosigners. Doesn't check if cosigner data itself is complete # check if we have the final number of cosigners. Doesn't check if cosigner data itself is complete
# (should be validated by wizardcomponents) # (should be validated by wizardcomponents)
if not self.is_multisig(wizard_data): if not self.is_multisig(wizard_data):
@ -367,7 +372,7 @@ class NewWalletWizard(AbstractWizard):
return True return True
def has_duplicate_masterkeys(self, wizard_data) -> bool: def has_duplicate_masterkeys(self, wizard_data: dict) -> bool:
"""Multisig wallets need distinct master keys. If True, need to prevent wallet-creation.""" """Multisig wallets need distinct master keys. If True, need to prevent wallet-creation."""
xpubs = [self.keystore_from_data(wizard_data['wallet_type'], wizard_data).get_master_public_key()] xpubs = [self.keystore_from_data(wizard_data['wallet_type'], wizard_data).get_master_public_key()]
for cosigner in wizard_data['multisig_cosigner_data']: for cosigner in wizard_data['multisig_cosigner_data']:
@ -376,7 +381,7 @@ class NewWalletWizard(AbstractWizard):
assert xpubs assert xpubs
return len(xpubs) != len(set(xpubs)) return len(xpubs) != len(set(xpubs))
def has_heterogeneous_masterkeys(self, wizard_data) -> bool: def has_heterogeneous_masterkeys(self, wizard_data: dict) -> bool:
"""Multisig wallets need homogeneous master keys. """Multisig wallets need homogeneous master keys.
All master keys need to be bip32, and e.g. Ypub cannot be mixed with Zpub. All master keys need to be bip32, and e.g. Ypub cannot be mixed with Zpub.
If True, need to prevent wallet-creation. If True, need to prevent wallet-creation.
@ -399,7 +404,7 @@ class NewWalletWizard(AbstractWizard):
return True return True
return False return False
def keystore_from_data(self, wallet_type, data): def keystore_from_data(self, wallet_type: str, data: dict):
if 'seed' in data: if 'seed' in data:
if data['seed_variant'] == 'electrum': if data['seed_variant'] == 'electrum':
return keystore.from_seed(data['seed'], data['seed_extra_words'], True) return keystore.from_seed(data['seed'], data['seed_extra_words'], True)
@ -426,7 +431,7 @@ class NewWalletWizard(AbstractWizard):
else: else:
raise Exception('no seed or master_key in data') raise Exception('no seed or master_key in data')
def is_current_cosigner_hardware(self, wizard_data): def is_current_cosigner_hardware(self, wizard_data: dict) -> bool:
cosigner_data = self.current_cosigner(wizard_data) cosigner_data = self.current_cosigner(wizard_data)
cosigner_is_hardware = cosigner_data == wizard_data and wizard_data['keystore_type'] == 'hardware' cosigner_is_hardware = cosigner_data == wizard_data and wizard_data['keystore_type'] == 'hardware'
if 'cosigner_keystore_type' in wizard_data and wizard_data['cosigner_keystore_type'] == 'hardware': if 'cosigner_keystore_type' in wizard_data and wizard_data['cosigner_keystore_type'] == 'hardware':
@ -467,7 +472,7 @@ class NewWalletWizard(AbstractWizard):
return multisig_keys_valid, user_info return multisig_keys_valid, user_info
def validate_seed(self, seed, seed_variant, wallet_type): def validate_seed(self, seed: str, seed_variant: str, wallet_type: str):
seed_type = '' seed_type = ''
seed_valid = False seed_valid = False
validation_message = '' validation_message = ''
@ -506,7 +511,7 @@ class NewWalletWizard(AbstractWizard):
return seed_valid, seed_type, validation_message return seed_valid, seed_type, validation_message
def create_storage(self, path, data): def create_storage(self, path: str, data: dict):
assert data['wallet_type'] in ['standard', '2fa', 'imported', 'multisig'] assert data['wallet_type'] in ['standard', '2fa', 'imported', 'multisig']
if os.path.exists(path): if os.path.exists(path):
@ -635,7 +640,7 @@ class NewWalletWizard(AbstractWizard):
db.load_plugins() db.load_plugins()
db.write() db.write()
def hw_keystore(self, data): def hw_keystore(self, data: dict) -> 'Hardware_KeyStore':
return hardware_keystore({ return hardware_keystore({
'type': 'hardware', 'type': 'hardware',
'hw_type': data['hw_type'], 'hw_type': data['hw_type'],
@ -673,7 +678,7 @@ class ServerConnectWizard(AbstractWizard):
} }
self._daemon = daemon self._daemon = daemon
def do_configure_proxy(self, wizard_data): def do_configure_proxy(self, wizard_data: dict):
proxy_settings = wizard_data['proxy'] proxy_settings = wizard_data['proxy']
if not self._daemon.network: if not self._daemon.network:
self._logger.debug('not configuring proxy, electrum config wants offline mode') self._logger.debug('not configuring proxy, electrum config wants offline mode')
@ -685,7 +690,7 @@ class ServerConnectWizard(AbstractWizard):
net_params = net_params._replace(proxy=proxy_settings) net_params = net_params._replace(proxy=proxy_settings)
self._daemon.network.run_from_another_thread(self._daemon.network.set_parameters(net_params)) self._daemon.network.run_from_another_thread(self._daemon.network.set_parameters(net_params))
def do_configure_server(self, wizard_data): def do_configure_server(self, wizard_data: dict):
self._logger.debug(f'configuring server: {wizard_data!r}') self._logger.debug(f'configuring server: {wizard_data!r}')
net_params = self._daemon.network.get_parameters() net_params = self._daemon.network.get_parameters()
try: try:
@ -697,12 +702,12 @@ class ServerConnectWizard(AbstractWizard):
net_params = net_params._replace(server=server, auto_connect=wizard_data['autoconnect']) net_params = net_params._replace(server=server, auto_connect=wizard_data['autoconnect'])
self._daemon.network.run_from_another_thread(self._daemon.network.set_parameters(net_params)) self._daemon.network.run_from_another_thread(self._daemon.network.set_parameters(net_params))
def do_configure_autoconnect(self, wizard_data): def do_configure_autoconnect(self, wizard_data: dict):
self._logger.debug(f'configuring autoconnect: {wizard_data!r}') self._logger.debug(f'configuring autoconnect: {wizard_data!r}')
if self._daemon.config.cv.NETWORK_AUTO_CONNECT.is_modifiable(): if self._daemon.config.cv.NETWORK_AUTO_CONNECT.is_modifiable():
self._daemon.config.NETWORK_AUTO_CONNECT = wizard_data['autoconnect'] self._daemon.config.NETWORK_AUTO_CONNECT = wizard_data['autoconnect']
def start(self, initial_data=None): def start(self, initial_data: dict = None) -> WizardViewState:
if initial_data is None: if initial_data is None:
initial_data = {} initial_data = {}
self.reset() self.reset()

Loading…
Cancel
Save