Browse Source

storage: encapsulate type conversions of stored objects using

decorators (instead of overloading JsonDB._convert_dict and
 _convert_value)
 - stored_in for elements of a StoreDict
 - stored_as for singletons
 - extra register methods are defined for key conversions

This commit was adapted from the jsonpatch branch
master
ThomasV 4 years ago
parent
commit
295734fc53
  1. 4
      electrum/invoices.py
  2. 65
      electrum/json_db.py
  3. 42
      electrum/lnutil.py
  4. 3
      electrum/submarine_swaps.py
  5. 1
      electrum/transaction.py
  6. 3
      electrum/util.py
  7. 76
      electrum/wallet_db.py

4
electrum/invoices.py

@ -4,7 +4,7 @@ from decimal import Decimal
import attr
from .json_db import StoredObject
from .json_db import StoredObject, stored_in
from .i18n import _
from .util import age, InvoiceError, format_satoshis
from .lnutil import hex_to_bytes
@ -244,6 +244,7 @@ class BaseInvoice(StoredObject):
return d
@stored_in('invoices')
@attr.s
class Invoice(BaseInvoice):
lightning_invoice = attr.ib(type=str, kw_only=True) # type: Optional[str]
@ -303,6 +304,7 @@ class Invoice(BaseInvoice):
return d
@stored_in('payment_requests')
@attr.s
class Request(BaseInvoice):
payment_hash = attr.ib(type=bytes, kw_only=True, converter=hex_to_bytes) # type: Optional[bytes]

65
electrum/json_db.py

@ -45,6 +45,28 @@ def locked(func):
return wrapper
registered_names = {}
registered_dicts = {}
registered_dict_keys = {}
registered_parent_keys = {}
def stored_as(name, _type=dict):
""" decorator that indicates the storage key of a stored object"""
def decorator(func):
registered_names[name] = func, _type
return func
return decorator
def stored_in(name, _type=dict):
""" decorator that indicates the storage key of an element in a StoredDict"""
def decorator(func):
registered_dicts[name] = func, _type
return func
return decorator
class StoredObject:
db = None
@ -195,3 +217,46 @@ class JsonDB(Logger):
def _should_convert_to_stored_dict(self, key) -> bool:
return True
def register_dict(self, name, method, _type):
registered_dicts[name] = method, _type
def register_name(self, name, method, _type):
registered_names[name] = method, _type
def register_dict_key(self, name, method):
registered_dict_keys[name] = method
def register_parent_key(self, name, method):
registered_parent_keys[name] = method
def _convert_dict(self, path, key, v):
if key in registered_dicts:
constructor, _type = registered_dicts[key]
if _type == dict:
v = dict((k, constructor(**x)) for k, x in v.items())
elif _type == tuple:
v = dict((k, constructor(*x)) for k, x in v.items())
else:
v = dict((k, constructor(x)) for k, x in v.items())
if key in registered_dict_keys:
convert_key = registered_dict_keys[key]
elif path and path[-1] in registered_parent_keys:
convert_key = registered_parent_keys.get(path[-1])
else:
convert_key = None
if convert_key:
v = dict((convert_key(k), x) for k, x in v.items())
return v
def _convert_value(self, path, key, v):
if key in registered_names:
constructor, _type = registered_names[key]
if _type == dict:
v = constructor(**v)
else:
v = constructor(v)
return v

42
electrum/lnutil.py

@ -52,7 +52,7 @@ DUST_LIMIT_MAX = 1000
def ln_dummy_address():
return redeem_script_to_address('p2wsh', '')
from .json_db import StoredObject
from .json_db import StoredObject, stored_in, stored_as
def channel_id_from_funding_tx(funding_txid: str, funding_index: int) -> Tuple[bytes, bytes]:
@ -181,6 +181,7 @@ class ChannelConfig(StoredObject):
raise Exception(f"feerate lower than min relay fee. {initial_feerate_per_kw} sat/kw.")
@stored_as('local_config')
@attr.s
class LocalConfig(ChannelConfig):
channel_seed = attr.ib(type=bytes, converter=hex_to_bytes) # type: Optional[bytes]
@ -214,17 +215,20 @@ class LocalConfig(ChannelConfig):
if self.htlc_minimum_msat < HTLC_MINIMUM_MSAT_MIN:
raise Exception(f"{conf_name}. htlc_minimum_msat too low: {self.htlc_minimum_msat} msat < {HTLC_MINIMUM_MSAT_MIN}")
@stored_as('remote_config')
@attr.s
class RemoteConfig(ChannelConfig):
next_per_commitment_point = attr.ib(type=bytes, converter=hex_to_bytes)
current_per_commitment_point = attr.ib(default=None, type=bytes, converter=hex_to_bytes)
@stored_in('fee_updates')
@attr.s
class FeeUpdate(StoredObject):
rate = attr.ib(type=int) # in sat/kw
ctn_local = attr.ib(default=None, type=int)
ctn_remote = attr.ib(default=None, type=int)
@stored_as('constraints')
@attr.s
class ChannelConstraints(StoredObject):
capacity = attr.ib(type=int) # in sat
@ -248,10 +252,12 @@ class ChannelBackupStorage(StoredObject):
chan_id, _ = channel_id_from_funding_tx(self.funding_txid, self.funding_index)
return chan_id
@stored_in('onchain_channel_backups')
@attr.s
class OnchainChannelBackupStorage(ChannelBackupStorage):
node_id_prefix = attr.ib(type=bytes, converter=hex_to_bytes)
@stored_in('imported_channel_backups')
@attr.s
class ImportedChannelBackupStorage(ChannelBackupStorage):
node_id = attr.ib(type=bytes, converter=hex_to_bytes)
@ -320,6 +326,7 @@ class ScriptHtlc(NamedTuple):
# FIXME duplicate of TxOutpoint in transaction.py??
@stored_as('funding_outpoint')
@attr.s
class Outpoint(StoredObject):
txid = attr.ib(type=str)
@ -484,8 +491,17 @@ def shachain_derive(element, to_index):
get_per_commitment_secret_from_seed(element.secret, to_index, zeros),
to_index)
ShachainElement = namedtuple("ShachainElement", ["secret", "index"])
ShachainElement.__str__ = lambda self: f"ShachainElement({self.secret.hex()},{self.index})"
class ShachainElement(NamedTuple):
secret: bytes
index: int
def __str__(self):
return "ShachainElement(" + self.secret.hex() + "," + str(self.index) + ")"
@stored_in('buckets', tuple)
def read(*x):
return ShachainElement(bfh(x[0]), int(x[1]))
def get_per_commitment_secret_from_seed(seed: bytes, i: int, bits: int = 48) -> bytes:
"""Generate per commitment secret."""
@ -1226,6 +1242,7 @@ class LnFeatures(IntFlag):
return hex(self._value_)
@stored_as('channel_type', _type=None)
class ChannelType(IntFlag):
OPTION_LEGACY_CHANNEL = 0
OPTION_STATIC_REMOTEKEY = 1 << 12
@ -1546,15 +1563,16 @@ class UpdateAddHtlc:
timestamp = attr.ib(type=int, kw_only=True)
htlc_id = attr.ib(type=int, kw_only=True, default=None)
@classmethod
def from_tuple(cls, amount_msat, payment_hash, cltv_expiry, htlc_id, timestamp) -> 'UpdateAddHtlc':
return cls(amount_msat=amount_msat,
payment_hash=payment_hash,
cltv_expiry=cltv_expiry,
htlc_id=htlc_id,
timestamp=timestamp)
def to_tuple(self):
@stored_in('adds', tuple)
def from_tuple(amount_msat, payment_hash, cltv_expiry, htlc_id, timestamp) -> 'UpdateAddHtlc':
return UpdateAddHtlc(
amount_msat=amount_msat,
payment_hash=payment_hash,
cltv_expiry=cltv_expiry,
htlc_id=htlc_id,
timestamp=timestamp)
def to_json(self):
return (self.amount_msat, self.payment_hash, self.cltv_expiry, self.htlc_id, self.timestamp)

3
electrum/submarine_swaps.py

@ -19,7 +19,7 @@ from .lnutil import REDEEM_AFTER_DOUBLE_SPENT_DELAY, ln_dummy_address
from .bitcoin import dust_threshold
from .logging import Logger
from .lnutil import hex_to_bytes
from .json_db import StoredObject
from .json_db import StoredObject, stored_in
from . import constants
from .address_synchronizer import TX_HEIGHT_LOCAL
from .i18n import _
@ -87,6 +87,7 @@ class SwapServerError(Exception):
return _("The swap server errored or is unreachable.")
@stored_in('submarine_swaps')
@attr.s
class SwapData(StoredObject):
is_reverse = attr.ib(type=bool)

1
electrum/transaction.py

@ -53,6 +53,7 @@ from .crypto import sha256d
from .logging import get_logger
from .util import ShortID, OldTaskGroup
from .descriptor import Descriptor, MissingSolutionPiece, create_dummy_descriptor_from_address
from .json_db import stored_in
if TYPE_CHECKING:
from .wallet import Abstract_Wallet

3
electrum/util.py

@ -297,9 +297,6 @@ class MyEncoder(json.JSONEncoder):
def default(self, obj):
# note: this does not get called for namedtuples :( https://bugs.python.org/issue30343
from .transaction import Transaction, TxOutput
from .lnutil import UpdateAddHtlc
if isinstance(obj, UpdateAddHtlc):
return obj.to_tuple()
if isinstance(obj, Transaction):
return obj.serialize()
if isinstance(obj, TxOutput):

76
electrum/wallet_db.py

@ -41,12 +41,10 @@ from .invoices import Invoice, Request
from .keystore import bip44_derivation
from .transaction import Transaction, TxOutpoint, tx_from_any, PartialTransaction, PartialTxOutput
from .logging import Logger
from .lnutil import LOCAL, REMOTE, FeeUpdate, UpdateAddHtlc, LocalConfig, RemoteConfig, ChannelType
from .lnutil import ImportedChannelBackupStorage, OnchainChannelBackupStorage
from .lnutil import ChannelConstraints, Outpoint, ShachainElement
from .json_db import StoredDict, JsonDB, locked, modifier, StoredObject
from .lnutil import LOCAL, REMOTE, HTLCOwner, ChannelType
from .json_db import StoredDict, JsonDB, locked, modifier, StoredObject, stored_in, stored_as
from .plugin import run_hook, plugin_loaders
from .submarine_swaps import SwapData
from .version import ELECTRUM_VERSION
if TYPE_CHECKING:
@ -61,12 +59,14 @@ FINAL_SEED_VERSION = 52 # electrum >= 2.7 will set this to prevent
# old versions from overwriting new format
@stored_in('tx_fees', tuple)
class TxFeesValue(NamedTuple):
fee: Optional[int] = None
is_calculated_by_us: bool = False
num_inputs: Optional[int] = None
@stored_as('db_metadata')
@attr.s
class DBMetadata(StoredObject):
creation_timestamp = attr.ib(default=None, type=int)
@ -91,6 +91,20 @@ class WalletDB(JsonDB):
def __init__(self, raw, *, manual_upgrades: bool):
JsonDB.__init__(self, {})
# register dicts that require value conversions not handled by constructor
self.register_dict('transactions', lambda x: tx_from_any(x, deserialize=False), None)
self.register_dict('prevouts_by_scripthash', lambda x: set(tuple(k) for k in x), None)
self.register_dict('data_loss_protect_remote_pcp', lambda x: bytes.fromhex(x), None)
# register dicts that require key conversion
for key in [
'adds', 'locked_in', 'settles', 'fails', 'fee_updates', 'buckets',
'unacked_updates', 'unfulfilled_htlcs', 'fail_htlc_reasons', 'onion_keys']:
self.register_dict_key(key, int)
for key in ['log']:
self.register_dict_key(key, lambda x: HTLCOwner(int(x)))
for key in ['locked_in', 'fails', 'settles']:
self.register_parent_key(key, lambda x: HTLCOwner(int(x)))
self._manual_upgrades = manual_upgrades
self._called_after_upgrade_tasks = False
if raw: # loading existing db
@ -1560,58 +1574,6 @@ class WalletDB(JsonDB):
self.tx_fees.clear()
self._prevouts_by_scripthash.clear()
def _convert_dict(self, path, key, v):
if key == 'transactions':
# note: for performance, "deserialize=False" so that we will deserialize these on-demand
v = dict((k, tx_from_any(x, deserialize=False)) for k, x in v.items())
if key == 'invoices':
v = dict((k, Invoice(**x)) for k, x in v.items())
if key == 'payment_requests':
v = dict((k, Request(**x)) for k, x in v.items())
elif key == 'adds':
v = dict((k, UpdateAddHtlc.from_tuple(*x)) for k, x in v.items())
elif key == 'fee_updates':
v = dict((k, FeeUpdate(**x)) for k, x in v.items())
elif key == 'submarine_swaps':
v = dict((k, SwapData(**x)) for k, x in v.items())
elif key == 'imported_channel_backups':
v = dict((k, ImportedChannelBackupStorage(**x)) for k, x in v.items())
elif key == 'onchain_channel_backups':
v = dict((k, OnchainChannelBackupStorage(**x)) for k, x in v.items())
elif key == 'tx_fees':
v = dict((k, TxFeesValue(*x)) for k, x in v.items())
elif key == 'prevouts_by_scripthash':
v = dict((k, {(prevout, value) for (prevout, value) in x}) for k, x in v.items())
elif key == 'buckets':
v = dict((k, ShachainElement(bfh(x[0]), int(x[1]))) for k, x in v.items())
elif key == 'data_loss_protect_remote_pcp':
v = dict((k, bfh(x)) for k, x in v.items())
# convert htlc_id keys to int
if key in ['adds', 'locked_in', 'settles', 'fails', 'fee_updates', 'buckets',
'unacked_updates', 'unfulfilled_htlcs', 'fail_htlc_reasons', 'onion_keys']:
v = dict((int(k), x) for k, x in v.items())
# convert keys to HTLCOwner
if key == 'log' or (path and path[-1] in ['locked_in', 'fails', 'settles']):
if "1" in v:
v[LOCAL] = v.pop("1")
v[REMOTE] = v.pop("-1")
return v
def _convert_value(self, path, key, v):
if key == 'local_config':
v = LocalConfig(**v)
elif key == 'remote_config':
v = RemoteConfig(**v)
elif key == 'constraints':
v = ChannelConstraints(**v)
elif key == 'funding_outpoint':
v = Outpoint(**v)
elif key == 'channel_type':
v = ChannelType(v)
elif key == 'db_metadata':
v = DBMetadata(**v)
return v
def _should_convert_to_stored_dict(self, key) -> bool:
if key == 'keystore':
return False

Loading…
Cancel
Save