Browse Source

storage: encapsulate type conversions of stored objects using

decorators (instead of overloading JsonDB._convert_dict and
 _convert_value)
 - stored_in for elements of a StoreDict
 - stored_as for singletons
 - extra register methods are defined for key conversions

This commit was adapted from the jsonpatch branch
master
ThomasV 4 years ago
parent
commit
295734fc53
  1. 4
      electrum/invoices.py
  2. 65
      electrum/json_db.py
  3. 42
      electrum/lnutil.py
  4. 3
      electrum/submarine_swaps.py
  5. 1
      electrum/transaction.py
  6. 3
      electrum/util.py
  7. 76
      electrum/wallet_db.py

4
electrum/invoices.py

@ -4,7 +4,7 @@ from decimal import Decimal
import attr import attr
from .json_db import StoredObject from .json_db import StoredObject, stored_in
from .i18n import _ from .i18n import _
from .util import age, InvoiceError, format_satoshis from .util import age, InvoiceError, format_satoshis
from .lnutil import hex_to_bytes from .lnutil import hex_to_bytes
@ -244,6 +244,7 @@ class BaseInvoice(StoredObject):
return d return d
@stored_in('invoices')
@attr.s @attr.s
class Invoice(BaseInvoice): class Invoice(BaseInvoice):
lightning_invoice = attr.ib(type=str, kw_only=True) # type: Optional[str] lightning_invoice = attr.ib(type=str, kw_only=True) # type: Optional[str]
@ -303,6 +304,7 @@ class Invoice(BaseInvoice):
return d return d
@stored_in('payment_requests')
@attr.s @attr.s
class Request(BaseInvoice): class Request(BaseInvoice):
payment_hash = attr.ib(type=bytes, kw_only=True, converter=hex_to_bytes) # type: Optional[bytes] payment_hash = attr.ib(type=bytes, kw_only=True, converter=hex_to_bytes) # type: Optional[bytes]

65
electrum/json_db.py

@ -45,6 +45,28 @@ def locked(func):
return wrapper return wrapper
registered_names = {}
registered_dicts = {}
registered_dict_keys = {}
registered_parent_keys = {}
def stored_as(name, _type=dict):
""" decorator that indicates the storage key of a stored object"""
def decorator(func):
registered_names[name] = func, _type
return func
return decorator
def stored_in(name, _type=dict):
""" decorator that indicates the storage key of an element in a StoredDict"""
def decorator(func):
registered_dicts[name] = func, _type
return func
return decorator
class StoredObject: class StoredObject:
db = None db = None
@ -195,3 +217,46 @@ class JsonDB(Logger):
def _should_convert_to_stored_dict(self, key) -> bool: def _should_convert_to_stored_dict(self, key) -> bool:
return True return True
def register_dict(self, name, method, _type):
registered_dicts[name] = method, _type
def register_name(self, name, method, _type):
registered_names[name] = method, _type
def register_dict_key(self, name, method):
registered_dict_keys[name] = method
def register_parent_key(self, name, method):
registered_parent_keys[name] = method
def _convert_dict(self, path, key, v):
if key in registered_dicts:
constructor, _type = registered_dicts[key]
if _type == dict:
v = dict((k, constructor(**x)) for k, x in v.items())
elif _type == tuple:
v = dict((k, constructor(*x)) for k, x in v.items())
else:
v = dict((k, constructor(x)) for k, x in v.items())
if key in registered_dict_keys:
convert_key = registered_dict_keys[key]
elif path and path[-1] in registered_parent_keys:
convert_key = registered_parent_keys.get(path[-1])
else:
convert_key = None
if convert_key:
v = dict((convert_key(k), x) for k, x in v.items())
return v
def _convert_value(self, path, key, v):
if key in registered_names:
constructor, _type = registered_names[key]
if _type == dict:
v = constructor(**v)
else:
v = constructor(v)
return v

42
electrum/lnutil.py

@ -52,7 +52,7 @@ DUST_LIMIT_MAX = 1000
def ln_dummy_address(): def ln_dummy_address():
return redeem_script_to_address('p2wsh', '') return redeem_script_to_address('p2wsh', '')
from .json_db import StoredObject from .json_db import StoredObject, stored_in, stored_as
def channel_id_from_funding_tx(funding_txid: str, funding_index: int) -> Tuple[bytes, bytes]: def channel_id_from_funding_tx(funding_txid: str, funding_index: int) -> Tuple[bytes, bytes]:
@ -181,6 +181,7 @@ class ChannelConfig(StoredObject):
raise Exception(f"feerate lower than min relay fee. {initial_feerate_per_kw} sat/kw.") raise Exception(f"feerate lower than min relay fee. {initial_feerate_per_kw} sat/kw.")
@stored_as('local_config')
@attr.s @attr.s
class LocalConfig(ChannelConfig): class LocalConfig(ChannelConfig):
channel_seed = attr.ib(type=bytes, converter=hex_to_bytes) # type: Optional[bytes] channel_seed = attr.ib(type=bytes, converter=hex_to_bytes) # type: Optional[bytes]
@ -214,17 +215,20 @@ class LocalConfig(ChannelConfig):
if self.htlc_minimum_msat < HTLC_MINIMUM_MSAT_MIN: if self.htlc_minimum_msat < HTLC_MINIMUM_MSAT_MIN:
raise Exception(f"{conf_name}. htlc_minimum_msat too low: {self.htlc_minimum_msat} msat < {HTLC_MINIMUM_MSAT_MIN}") raise Exception(f"{conf_name}. htlc_minimum_msat too low: {self.htlc_minimum_msat} msat < {HTLC_MINIMUM_MSAT_MIN}")
@stored_as('remote_config')
@attr.s @attr.s
class RemoteConfig(ChannelConfig): class RemoteConfig(ChannelConfig):
next_per_commitment_point = attr.ib(type=bytes, converter=hex_to_bytes) next_per_commitment_point = attr.ib(type=bytes, converter=hex_to_bytes)
current_per_commitment_point = attr.ib(default=None, type=bytes, converter=hex_to_bytes) current_per_commitment_point = attr.ib(default=None, type=bytes, converter=hex_to_bytes)
@stored_in('fee_updates')
@attr.s @attr.s
class FeeUpdate(StoredObject): class FeeUpdate(StoredObject):
rate = attr.ib(type=int) # in sat/kw rate = attr.ib(type=int) # in sat/kw
ctn_local = attr.ib(default=None, type=int) ctn_local = attr.ib(default=None, type=int)
ctn_remote = attr.ib(default=None, type=int) ctn_remote = attr.ib(default=None, type=int)
@stored_as('constraints')
@attr.s @attr.s
class ChannelConstraints(StoredObject): class ChannelConstraints(StoredObject):
capacity = attr.ib(type=int) # in sat capacity = attr.ib(type=int) # in sat
@ -248,10 +252,12 @@ class ChannelBackupStorage(StoredObject):
chan_id, _ = channel_id_from_funding_tx(self.funding_txid, self.funding_index) chan_id, _ = channel_id_from_funding_tx(self.funding_txid, self.funding_index)
return chan_id return chan_id
@stored_in('onchain_channel_backups')
@attr.s @attr.s
class OnchainChannelBackupStorage(ChannelBackupStorage): class OnchainChannelBackupStorage(ChannelBackupStorage):
node_id_prefix = attr.ib(type=bytes, converter=hex_to_bytes) node_id_prefix = attr.ib(type=bytes, converter=hex_to_bytes)
@stored_in('imported_channel_backups')
@attr.s @attr.s
class ImportedChannelBackupStorage(ChannelBackupStorage): class ImportedChannelBackupStorage(ChannelBackupStorage):
node_id = attr.ib(type=bytes, converter=hex_to_bytes) node_id = attr.ib(type=bytes, converter=hex_to_bytes)
@ -320,6 +326,7 @@ class ScriptHtlc(NamedTuple):
# FIXME duplicate of TxOutpoint in transaction.py?? # FIXME duplicate of TxOutpoint in transaction.py??
@stored_as('funding_outpoint')
@attr.s @attr.s
class Outpoint(StoredObject): class Outpoint(StoredObject):
txid = attr.ib(type=str) txid = attr.ib(type=str)
@ -484,8 +491,17 @@ def shachain_derive(element, to_index):
get_per_commitment_secret_from_seed(element.secret, to_index, zeros), get_per_commitment_secret_from_seed(element.secret, to_index, zeros),
to_index) to_index)
ShachainElement = namedtuple("ShachainElement", ["secret", "index"]) class ShachainElement(NamedTuple):
ShachainElement.__str__ = lambda self: f"ShachainElement({self.secret.hex()},{self.index})" secret: bytes
index: int
def __str__(self):
return "ShachainElement(" + self.secret.hex() + "," + str(self.index) + ")"
@stored_in('buckets', tuple)
def read(*x):
return ShachainElement(bfh(x[0]), int(x[1]))
def get_per_commitment_secret_from_seed(seed: bytes, i: int, bits: int = 48) -> bytes: def get_per_commitment_secret_from_seed(seed: bytes, i: int, bits: int = 48) -> bytes:
"""Generate per commitment secret.""" """Generate per commitment secret."""
@ -1226,6 +1242,7 @@ class LnFeatures(IntFlag):
return hex(self._value_) return hex(self._value_)
@stored_as('channel_type', _type=None)
class ChannelType(IntFlag): class ChannelType(IntFlag):
OPTION_LEGACY_CHANNEL = 0 OPTION_LEGACY_CHANNEL = 0
OPTION_STATIC_REMOTEKEY = 1 << 12 OPTION_STATIC_REMOTEKEY = 1 << 12
@ -1546,15 +1563,16 @@ class UpdateAddHtlc:
timestamp = attr.ib(type=int, kw_only=True) timestamp = attr.ib(type=int, kw_only=True)
htlc_id = attr.ib(type=int, kw_only=True, default=None) htlc_id = attr.ib(type=int, kw_only=True, default=None)
@classmethod @stored_in('adds', tuple)
def from_tuple(cls, amount_msat, payment_hash, cltv_expiry, htlc_id, timestamp) -> 'UpdateAddHtlc': def from_tuple(amount_msat, payment_hash, cltv_expiry, htlc_id, timestamp) -> 'UpdateAddHtlc':
return cls(amount_msat=amount_msat, return UpdateAddHtlc(
payment_hash=payment_hash, amount_msat=amount_msat,
cltv_expiry=cltv_expiry, payment_hash=payment_hash,
htlc_id=htlc_id, cltv_expiry=cltv_expiry,
timestamp=timestamp) htlc_id=htlc_id,
timestamp=timestamp)
def to_tuple(self):
def to_json(self):
return (self.amount_msat, self.payment_hash, self.cltv_expiry, self.htlc_id, self.timestamp) return (self.amount_msat, self.payment_hash, self.cltv_expiry, self.htlc_id, self.timestamp)

3
electrum/submarine_swaps.py

@ -19,7 +19,7 @@ from .lnutil import REDEEM_AFTER_DOUBLE_SPENT_DELAY, ln_dummy_address
from .bitcoin import dust_threshold from .bitcoin import dust_threshold
from .logging import Logger from .logging import Logger
from .lnutil import hex_to_bytes from .lnutil import hex_to_bytes
from .json_db import StoredObject from .json_db import StoredObject, stored_in
from . import constants from . import constants
from .address_synchronizer import TX_HEIGHT_LOCAL from .address_synchronizer import TX_HEIGHT_LOCAL
from .i18n import _ from .i18n import _
@ -87,6 +87,7 @@ class SwapServerError(Exception):
return _("The swap server errored or is unreachable.") return _("The swap server errored or is unreachable.")
@stored_in('submarine_swaps')
@attr.s @attr.s
class SwapData(StoredObject): class SwapData(StoredObject):
is_reverse = attr.ib(type=bool) is_reverse = attr.ib(type=bool)

1
electrum/transaction.py

@ -53,6 +53,7 @@ from .crypto import sha256d
from .logging import get_logger from .logging import get_logger
from .util import ShortID, OldTaskGroup from .util import ShortID, OldTaskGroup
from .descriptor import Descriptor, MissingSolutionPiece, create_dummy_descriptor_from_address from .descriptor import Descriptor, MissingSolutionPiece, create_dummy_descriptor_from_address
from .json_db import stored_in
if TYPE_CHECKING: if TYPE_CHECKING:
from .wallet import Abstract_Wallet from .wallet import Abstract_Wallet

3
electrum/util.py

@ -297,9 +297,6 @@ class MyEncoder(json.JSONEncoder):
def default(self, obj): def default(self, obj):
# note: this does not get called for namedtuples :( https://bugs.python.org/issue30343 # note: this does not get called for namedtuples :( https://bugs.python.org/issue30343
from .transaction import Transaction, TxOutput from .transaction import Transaction, TxOutput
from .lnutil import UpdateAddHtlc
if isinstance(obj, UpdateAddHtlc):
return obj.to_tuple()
if isinstance(obj, Transaction): if isinstance(obj, Transaction):
return obj.serialize() return obj.serialize()
if isinstance(obj, TxOutput): if isinstance(obj, TxOutput):

76
electrum/wallet_db.py

@ -41,12 +41,10 @@ from .invoices import Invoice, Request
from .keystore import bip44_derivation from .keystore import bip44_derivation
from .transaction import Transaction, TxOutpoint, tx_from_any, PartialTransaction, PartialTxOutput from .transaction import Transaction, TxOutpoint, tx_from_any, PartialTransaction, PartialTxOutput
from .logging import Logger from .logging import Logger
from .lnutil import LOCAL, REMOTE, FeeUpdate, UpdateAddHtlc, LocalConfig, RemoteConfig, ChannelType
from .lnutil import ImportedChannelBackupStorage, OnchainChannelBackupStorage from .lnutil import LOCAL, REMOTE, HTLCOwner, ChannelType
from .lnutil import ChannelConstraints, Outpoint, ShachainElement from .json_db import StoredDict, JsonDB, locked, modifier, StoredObject, stored_in, stored_as
from .json_db import StoredDict, JsonDB, locked, modifier, StoredObject
from .plugin import run_hook, plugin_loaders from .plugin import run_hook, plugin_loaders
from .submarine_swaps import SwapData
from .version import ELECTRUM_VERSION from .version import ELECTRUM_VERSION
if TYPE_CHECKING: if TYPE_CHECKING:
@ -61,12 +59,14 @@ FINAL_SEED_VERSION = 52 # electrum >= 2.7 will set this to prevent
# old versions from overwriting new format # old versions from overwriting new format
@stored_in('tx_fees', tuple)
class TxFeesValue(NamedTuple): class TxFeesValue(NamedTuple):
fee: Optional[int] = None fee: Optional[int] = None
is_calculated_by_us: bool = False is_calculated_by_us: bool = False
num_inputs: Optional[int] = None num_inputs: Optional[int] = None
@stored_as('db_metadata')
@attr.s @attr.s
class DBMetadata(StoredObject): class DBMetadata(StoredObject):
creation_timestamp = attr.ib(default=None, type=int) creation_timestamp = attr.ib(default=None, type=int)
@ -91,6 +91,20 @@ class WalletDB(JsonDB):
def __init__(self, raw, *, manual_upgrades: bool): def __init__(self, raw, *, manual_upgrades: bool):
JsonDB.__init__(self, {}) JsonDB.__init__(self, {})
# register dicts that require value conversions not handled by constructor
self.register_dict('transactions', lambda x: tx_from_any(x, deserialize=False), None)
self.register_dict('prevouts_by_scripthash', lambda x: set(tuple(k) for k in x), None)
self.register_dict('data_loss_protect_remote_pcp', lambda x: bytes.fromhex(x), None)
# register dicts that require key conversion
for key in [
'adds', 'locked_in', 'settles', 'fails', 'fee_updates', 'buckets',
'unacked_updates', 'unfulfilled_htlcs', 'fail_htlc_reasons', 'onion_keys']:
self.register_dict_key(key, int)
for key in ['log']:
self.register_dict_key(key, lambda x: HTLCOwner(int(x)))
for key in ['locked_in', 'fails', 'settles']:
self.register_parent_key(key, lambda x: HTLCOwner(int(x)))
self._manual_upgrades = manual_upgrades self._manual_upgrades = manual_upgrades
self._called_after_upgrade_tasks = False self._called_after_upgrade_tasks = False
if raw: # loading existing db if raw: # loading existing db
@ -1560,58 +1574,6 @@ class WalletDB(JsonDB):
self.tx_fees.clear() self.tx_fees.clear()
self._prevouts_by_scripthash.clear() self._prevouts_by_scripthash.clear()
def _convert_dict(self, path, key, v):
if key == 'transactions':
# note: for performance, "deserialize=False" so that we will deserialize these on-demand
v = dict((k, tx_from_any(x, deserialize=False)) for k, x in v.items())
if key == 'invoices':
v = dict((k, Invoice(**x)) for k, x in v.items())
if key == 'payment_requests':
v = dict((k, Request(**x)) for k, x in v.items())
elif key == 'adds':
v = dict((k, UpdateAddHtlc.from_tuple(*x)) for k, x in v.items())
elif key == 'fee_updates':
v = dict((k, FeeUpdate(**x)) for k, x in v.items())
elif key == 'submarine_swaps':
v = dict((k, SwapData(**x)) for k, x in v.items())
elif key == 'imported_channel_backups':
v = dict((k, ImportedChannelBackupStorage(**x)) for k, x in v.items())
elif key == 'onchain_channel_backups':
v = dict((k, OnchainChannelBackupStorage(**x)) for k, x in v.items())
elif key == 'tx_fees':
v = dict((k, TxFeesValue(*x)) for k, x in v.items())
elif key == 'prevouts_by_scripthash':
v = dict((k, {(prevout, value) for (prevout, value) in x}) for k, x in v.items())
elif key == 'buckets':
v = dict((k, ShachainElement(bfh(x[0]), int(x[1]))) for k, x in v.items())
elif key == 'data_loss_protect_remote_pcp':
v = dict((k, bfh(x)) for k, x in v.items())
# convert htlc_id keys to int
if key in ['adds', 'locked_in', 'settles', 'fails', 'fee_updates', 'buckets',
'unacked_updates', 'unfulfilled_htlcs', 'fail_htlc_reasons', 'onion_keys']:
v = dict((int(k), x) for k, x in v.items())
# convert keys to HTLCOwner
if key == 'log' or (path and path[-1] in ['locked_in', 'fails', 'settles']):
if "1" in v:
v[LOCAL] = v.pop("1")
v[REMOTE] = v.pop("-1")
return v
def _convert_value(self, path, key, v):
if key == 'local_config':
v = LocalConfig(**v)
elif key == 'remote_config':
v = RemoteConfig(**v)
elif key == 'constraints':
v = ChannelConstraints(**v)
elif key == 'funding_outpoint':
v = Outpoint(**v)
elif key == 'channel_type':
v = ChannelType(v)
elif key == 'db_metadata':
v = DBMetadata(**v)
return v
def _should_convert_to_stored_dict(self, key) -> bool: def _should_convert_to_stored_dict(self, key) -> bool:
if key == 'keystore': if key == 'keystore':
return False return False

Loading…
Cancel
Save