diff --git a/electrum/invoices.py b/electrum/invoices.py index c90ad75a6..b184388b3 100644 --- a/electrum/invoices.py +++ b/electrum/invoices.py @@ -4,7 +4,7 @@ from decimal import Decimal import attr -from .json_db import StoredObject +from .json_db import StoredObject, stored_in from .i18n import _ from .util import age, InvoiceError, format_satoshis from .lnutil import hex_to_bytes @@ -244,6 +244,7 @@ class BaseInvoice(StoredObject): return d +@stored_in('invoices') @attr.s class Invoice(BaseInvoice): lightning_invoice = attr.ib(type=str, kw_only=True) # type: Optional[str] @@ -303,6 +304,7 @@ class Invoice(BaseInvoice): return d +@stored_in('payment_requests') @attr.s class Request(BaseInvoice): payment_hash = attr.ib(type=bytes, kw_only=True, converter=hex_to_bytes) # type: Optional[bytes] diff --git a/electrum/json_db.py b/electrum/json_db.py index 00f249c16..41fadbe8d 100644 --- a/electrum/json_db.py +++ b/electrum/json_db.py @@ -45,6 +45,28 @@ def locked(func): return wrapper +registered_names = {} +registered_dicts = {} +registered_dict_keys = {} +registered_parent_keys = {} + + +def stored_as(name, _type=dict): + """ decorator that indicates the storage key of a stored object""" + def decorator(func): + registered_names[name] = func, _type + return func + return decorator + +def stored_in(name, _type=dict): + """ decorator that indicates the storage key of an element in a StoredDict""" + def decorator(func): + registered_dicts[name] = func, _type + return func + return decorator + + + class StoredObject: db = None @@ -195,3 +217,46 @@ class JsonDB(Logger): def _should_convert_to_stored_dict(self, key) -> bool: return True + + def register_dict(self, name, method, _type): + registered_dicts[name] = method, _type + + def register_name(self, name, method, _type): + registered_names[name] = method, _type + + def register_dict_key(self, name, method): + registered_dict_keys[name] = method + + def register_parent_key(self, name, method): + registered_parent_keys[name] = method + + def _convert_dict(self, path, key, v): + + if key in registered_dicts: + constructor, _type = registered_dicts[key] + if _type == dict: + v = dict((k, constructor(**x)) for k, x in v.items()) + elif _type == tuple: + v = dict((k, constructor(*x)) for k, x in v.items()) + else: + v = dict((k, constructor(x)) for k, x in v.items()) + + if key in registered_dict_keys: + convert_key = registered_dict_keys[key] + elif path and path[-1] in registered_parent_keys: + convert_key = registered_parent_keys.get(path[-1]) + else: + convert_key = None + if convert_key: + v = dict((convert_key(k), x) for k, x in v.items()) + + return v + + def _convert_value(self, path, key, v): + if key in registered_names: + constructor, _type = registered_names[key] + if _type == dict: + v = constructor(**v) + else: + v = constructor(v) + return v diff --git a/electrum/lnutil.py b/electrum/lnutil.py index 4ab8e3949..93ec18b1e 100644 --- a/electrum/lnutil.py +++ b/electrum/lnutil.py @@ -52,7 +52,7 @@ DUST_LIMIT_MAX = 1000 def ln_dummy_address(): return redeem_script_to_address('p2wsh', '') -from .json_db import StoredObject +from .json_db import StoredObject, stored_in, stored_as def channel_id_from_funding_tx(funding_txid: str, funding_index: int) -> Tuple[bytes, bytes]: @@ -181,6 +181,7 @@ class ChannelConfig(StoredObject): raise Exception(f"feerate lower than min relay fee. {initial_feerate_per_kw} sat/kw.") +@stored_as('local_config') @attr.s class LocalConfig(ChannelConfig): channel_seed = attr.ib(type=bytes, converter=hex_to_bytes) # type: Optional[bytes] @@ -214,17 +215,20 @@ class LocalConfig(ChannelConfig): if self.htlc_minimum_msat < HTLC_MINIMUM_MSAT_MIN: raise Exception(f"{conf_name}. htlc_minimum_msat too low: {self.htlc_minimum_msat} msat < {HTLC_MINIMUM_MSAT_MIN}") +@stored_as('remote_config') @attr.s class RemoteConfig(ChannelConfig): next_per_commitment_point = attr.ib(type=bytes, converter=hex_to_bytes) current_per_commitment_point = attr.ib(default=None, type=bytes, converter=hex_to_bytes) +@stored_in('fee_updates') @attr.s class FeeUpdate(StoredObject): rate = attr.ib(type=int) # in sat/kw ctn_local = attr.ib(default=None, type=int) ctn_remote = attr.ib(default=None, type=int) +@stored_as('constraints') @attr.s class ChannelConstraints(StoredObject): capacity = attr.ib(type=int) # in sat @@ -248,10 +252,12 @@ class ChannelBackupStorage(StoredObject): chan_id, _ = channel_id_from_funding_tx(self.funding_txid, self.funding_index) return chan_id +@stored_in('onchain_channel_backups') @attr.s class OnchainChannelBackupStorage(ChannelBackupStorage): node_id_prefix = attr.ib(type=bytes, converter=hex_to_bytes) +@stored_in('imported_channel_backups') @attr.s class ImportedChannelBackupStorage(ChannelBackupStorage): node_id = attr.ib(type=bytes, converter=hex_to_bytes) @@ -320,6 +326,7 @@ class ScriptHtlc(NamedTuple): # FIXME duplicate of TxOutpoint in transaction.py?? +@stored_as('funding_outpoint') @attr.s class Outpoint(StoredObject): txid = attr.ib(type=str) @@ -484,8 +491,17 @@ def shachain_derive(element, to_index): get_per_commitment_secret_from_seed(element.secret, to_index, zeros), to_index) -ShachainElement = namedtuple("ShachainElement", ["secret", "index"]) -ShachainElement.__str__ = lambda self: f"ShachainElement({self.secret.hex()},{self.index})" +class ShachainElement(NamedTuple): + secret: bytes + index: int + + def __str__(self): + return "ShachainElement(" + self.secret.hex() + "," + str(self.index) + ")" + + @stored_in('buckets', tuple) + def read(*x): + return ShachainElement(bfh(x[0]), int(x[1])) + def get_per_commitment_secret_from_seed(seed: bytes, i: int, bits: int = 48) -> bytes: """Generate per commitment secret.""" @@ -1226,6 +1242,7 @@ class LnFeatures(IntFlag): return hex(self._value_) +@stored_as('channel_type', _type=None) class ChannelType(IntFlag): OPTION_LEGACY_CHANNEL = 0 OPTION_STATIC_REMOTEKEY = 1 << 12 @@ -1546,15 +1563,16 @@ class UpdateAddHtlc: timestamp = attr.ib(type=int, kw_only=True) htlc_id = attr.ib(type=int, kw_only=True, default=None) - @classmethod - def from_tuple(cls, amount_msat, payment_hash, cltv_expiry, htlc_id, timestamp) -> 'UpdateAddHtlc': - return cls(amount_msat=amount_msat, - payment_hash=payment_hash, - cltv_expiry=cltv_expiry, - htlc_id=htlc_id, - timestamp=timestamp) - - def to_tuple(self): + @stored_in('adds', tuple) + def from_tuple(amount_msat, payment_hash, cltv_expiry, htlc_id, timestamp) -> 'UpdateAddHtlc': + return UpdateAddHtlc( + amount_msat=amount_msat, + payment_hash=payment_hash, + cltv_expiry=cltv_expiry, + htlc_id=htlc_id, + timestamp=timestamp) + + def to_json(self): return (self.amount_msat, self.payment_hash, self.cltv_expiry, self.htlc_id, self.timestamp) diff --git a/electrum/submarine_swaps.py b/electrum/submarine_swaps.py index a7f966e7e..7918283d5 100644 --- a/electrum/submarine_swaps.py +++ b/electrum/submarine_swaps.py @@ -19,7 +19,7 @@ from .lnutil import REDEEM_AFTER_DOUBLE_SPENT_DELAY, ln_dummy_address from .bitcoin import dust_threshold from .logging import Logger from .lnutil import hex_to_bytes -from .json_db import StoredObject +from .json_db import StoredObject, stored_in from . import constants from .address_synchronizer import TX_HEIGHT_LOCAL from .i18n import _ @@ -87,6 +87,7 @@ class SwapServerError(Exception): return _("The swap server errored or is unreachable.") +@stored_in('submarine_swaps') @attr.s class SwapData(StoredObject): is_reverse = attr.ib(type=bool) diff --git a/electrum/transaction.py b/electrum/transaction.py index 154ab55a7..248642905 100644 --- a/electrum/transaction.py +++ b/electrum/transaction.py @@ -53,6 +53,7 @@ from .crypto import sha256d from .logging import get_logger from .util import ShortID, OldTaskGroup from .descriptor import Descriptor, MissingSolutionPiece, create_dummy_descriptor_from_address +from .json_db import stored_in if TYPE_CHECKING: from .wallet import Abstract_Wallet diff --git a/electrum/util.py b/electrum/util.py index d09a31a18..c31ba5711 100644 --- a/electrum/util.py +++ b/electrum/util.py @@ -297,9 +297,6 @@ class MyEncoder(json.JSONEncoder): def default(self, obj): # note: this does not get called for namedtuples :( https://bugs.python.org/issue30343 from .transaction import Transaction, TxOutput - from .lnutil import UpdateAddHtlc - if isinstance(obj, UpdateAddHtlc): - return obj.to_tuple() if isinstance(obj, Transaction): return obj.serialize() if isinstance(obj, TxOutput): diff --git a/electrum/wallet_db.py b/electrum/wallet_db.py index 1336b609e..2b38ad7de 100644 --- a/electrum/wallet_db.py +++ b/electrum/wallet_db.py @@ -41,12 +41,10 @@ from .invoices import Invoice, Request from .keystore import bip44_derivation from .transaction import Transaction, TxOutpoint, tx_from_any, PartialTransaction, PartialTxOutput from .logging import Logger -from .lnutil import LOCAL, REMOTE, FeeUpdate, UpdateAddHtlc, LocalConfig, RemoteConfig, ChannelType -from .lnutil import ImportedChannelBackupStorage, OnchainChannelBackupStorage -from .lnutil import ChannelConstraints, Outpoint, ShachainElement -from .json_db import StoredDict, JsonDB, locked, modifier, StoredObject + +from .lnutil import LOCAL, REMOTE, HTLCOwner, ChannelType +from .json_db import StoredDict, JsonDB, locked, modifier, StoredObject, stored_in, stored_as from .plugin import run_hook, plugin_loaders -from .submarine_swaps import SwapData from .version import ELECTRUM_VERSION if TYPE_CHECKING: @@ -61,12 +59,14 @@ FINAL_SEED_VERSION = 52 # electrum >= 2.7 will set this to prevent # old versions from overwriting new format +@stored_in('tx_fees', tuple) class TxFeesValue(NamedTuple): fee: Optional[int] = None is_calculated_by_us: bool = False num_inputs: Optional[int] = None +@stored_as('db_metadata') @attr.s class DBMetadata(StoredObject): creation_timestamp = attr.ib(default=None, type=int) @@ -91,6 +91,20 @@ class WalletDB(JsonDB): def __init__(self, raw, *, manual_upgrades: bool): JsonDB.__init__(self, {}) + # register dicts that require value conversions not handled by constructor + self.register_dict('transactions', lambda x: tx_from_any(x, deserialize=False), None) + self.register_dict('prevouts_by_scripthash', lambda x: set(tuple(k) for k in x), None) + self.register_dict('data_loss_protect_remote_pcp', lambda x: bytes.fromhex(x), None) + # register dicts that require key conversion + for key in [ + 'adds', 'locked_in', 'settles', 'fails', 'fee_updates', 'buckets', + 'unacked_updates', 'unfulfilled_htlcs', 'fail_htlc_reasons', 'onion_keys']: + self.register_dict_key(key, int) + for key in ['log']: + self.register_dict_key(key, lambda x: HTLCOwner(int(x))) + for key in ['locked_in', 'fails', 'settles']: + self.register_parent_key(key, lambda x: HTLCOwner(int(x))) + self._manual_upgrades = manual_upgrades self._called_after_upgrade_tasks = False if raw: # loading existing db @@ -1560,58 +1574,6 @@ class WalletDB(JsonDB): self.tx_fees.clear() self._prevouts_by_scripthash.clear() - def _convert_dict(self, path, key, v): - if key == 'transactions': - # note: for performance, "deserialize=False" so that we will deserialize these on-demand - v = dict((k, tx_from_any(x, deserialize=False)) for k, x in v.items()) - if key == 'invoices': - v = dict((k, Invoice(**x)) for k, x in v.items()) - if key == 'payment_requests': - v = dict((k, Request(**x)) for k, x in v.items()) - elif key == 'adds': - v = dict((k, UpdateAddHtlc.from_tuple(*x)) for k, x in v.items()) - elif key == 'fee_updates': - v = dict((k, FeeUpdate(**x)) for k, x in v.items()) - elif key == 'submarine_swaps': - v = dict((k, SwapData(**x)) for k, x in v.items()) - elif key == 'imported_channel_backups': - v = dict((k, ImportedChannelBackupStorage(**x)) for k, x in v.items()) - elif key == 'onchain_channel_backups': - v = dict((k, OnchainChannelBackupStorage(**x)) for k, x in v.items()) - elif key == 'tx_fees': - v = dict((k, TxFeesValue(*x)) for k, x in v.items()) - elif key == 'prevouts_by_scripthash': - v = dict((k, {(prevout, value) for (prevout, value) in x}) for k, x in v.items()) - elif key == 'buckets': - v = dict((k, ShachainElement(bfh(x[0]), int(x[1]))) for k, x in v.items()) - elif key == 'data_loss_protect_remote_pcp': - v = dict((k, bfh(x)) for k, x in v.items()) - # convert htlc_id keys to int - if key in ['adds', 'locked_in', 'settles', 'fails', 'fee_updates', 'buckets', - 'unacked_updates', 'unfulfilled_htlcs', 'fail_htlc_reasons', 'onion_keys']: - v = dict((int(k), x) for k, x in v.items()) - # convert keys to HTLCOwner - if key == 'log' or (path and path[-1] in ['locked_in', 'fails', 'settles']): - if "1" in v: - v[LOCAL] = v.pop("1") - v[REMOTE] = v.pop("-1") - return v - - def _convert_value(self, path, key, v): - if key == 'local_config': - v = LocalConfig(**v) - elif key == 'remote_config': - v = RemoteConfig(**v) - elif key == 'constraints': - v = ChannelConstraints(**v) - elif key == 'funding_outpoint': - v = Outpoint(**v) - elif key == 'channel_type': - v = ChannelType(v) - elif key == 'db_metadata': - v = DBMetadata(**v) - return v - def _should_convert_to_stored_dict(self, key) -> bool: if key == 'keystore': return False