@ -26,6 +26,7 @@ import threading
import copy
import json
from typing import TYPE_CHECKING
import jsonpatch
from . import util
from . util import WalletFileException , profiler
@ -80,22 +81,35 @@ def stored_in(name, _type=dict):
return decorator
def key_path ( path , key ) :
def to_str ( x ) :
if isinstance ( x , int ) :
return str ( int ( x ) )
else :
assert isinstance ( x , str )
return x
return ' / ' + ' / ' . join ( [ to_str ( x ) for x in path + [ to_str ( key ) ] ] )
class StoredObject :
db = None
path = None
def __setattr__ ( self , key , value ) :
if self . db :
self . db . set_modified ( True )
if self . db and key not in [ ' path ' , ' db ' ] and not key . startswith ( ' _ ' ) :
if value != getattr ( self , key ) :
self . db . add_patch ( { ' op ' : ' replace ' , ' path ' : key_path ( self . path , key ) , ' value ' : value } )
object . __setattr__ ( self , key , value )
def set_db ( self , db ) :
def set_db ( self , db , path ) :
self . db = db
self . path = path
def to_json ( self ) :
d = dict ( vars ( self ) )
d . pop ( ' db ' , None )
d . pop ( ' path ' , None )
# don't expose/store private stuff
d = { k : v for k , v in d . items ( )
if not k . startswith ( ' _ ' ) }
@ -112,20 +126,22 @@ class StoredDict(dict):
self . path = path
# recursively convert dicts to StoredDict
for k , v in list ( data . items ( ) ) :
self . __setitem__ ( k , v )
self . __setitem__ ( k , v , patch = False )
@locked
def __setitem__ ( self , key , v ) :
def __setitem__ ( self , key , v , patch = True ) :
is_new = key not in self
# early return to prevent unnecessary disk writes
if not is_new and self [ key ] == v :
return
if not is_new and patch :
if self . db and json . dumps ( v , cls = self . db . encoder ) == json . dumps ( self [ key ] , cls = self . db . encoder ) :
return
# recursively set db and path
if isinstance ( v , StoredDict ) :
#assert v.db is None
v . db = self . db
v . path = self . path + [ key ]
for k , vv in v . items ( ) :
v [ k ] = vv
v . __setitem__ ( k , vv , patch = False )
# recursively convert dict to StoredDict.
# _convert_dict is called breadth-first
elif isinstance ( v , dict ) :
@ -139,29 +155,57 @@ class StoredDict(dict):
v = self . db . _convert_value ( self . path , key , v )
# set parent of StoredObject
if isinstance ( v , StoredObject ) :
v . set_db ( self . db )
v . set_db ( self . db , self . path + [ key ] )
# convert lists
if isinstance ( v , list ) :
v = StoredList ( v , self . db , self . path + [ key ] )
# set item
dict . __setitem__ ( self , key , v )
if self . db :
self . db . set_modified ( True )
if self . db and patch :
op = ' add ' if is_new else ' replace '
self . db . add_patch ( { ' op ' : op , ' path ' : key_path ( self . path , key ) , ' value ' : v } )
@locked
def __delitem__ ( self , key ) :
dict . __delitem__ ( self , key )
if self . db :
self . db . set_modified ( True )
self . db . add_patch ( { ' op ' : ' remove ' , ' path ' : key_path ( self . path , key ) } )
@locked
def pop ( self , key , v = _RaiseKeyError ) :
if v is _RaiseKeyError :
r = dict . pop ( self , key )
else :
r = dict . pop ( self , key , v )
if key not in self :
if v is _RaiseKeyError :
raise KeyError ( key )
else :
return v
r = dict . pop ( self , key )
if self . db :
self . db . set_modified ( True )
self . db . add_patch ( { ' op ' : ' remove ' , ' path ' : key_path ( self . path , key ) } )
return r
class StoredList ( list ) :
def __init__ ( self , data , db , path ) :
list . __init__ ( self , data )
self . db = db
self . lock = self . db . lock if self . db else threading . RLock ( )
self . path = path
@locked
def append ( self , item ) :
n = len ( self )
list . append ( self , item )
if self . db :
self . db . add_patch ( { ' op ' : ' add ' , ' path ' : key_path ( self . path , ' %d ' % n ) , ' value ' : item } )
@locked
def remove ( self , item ) :
n = self . index ( item )
list . remove ( self , item )
if self . db :
self . db . add_patch ( { ' op ' : ' remove ' , ' path ' : key_path ( self . path , ' %d ' % n ) } )
class JsonDB ( Logger ) :
@ -171,34 +215,39 @@ class JsonDB(Logger):
self . lock = threading . RLock ( )
self . storage = storage
self . encoder = encoder
self . pending_changes = [ ]
self . _modified = False
# load data
data = self . load_data ( s )
if upgrader :
data , was_upgraded = upgrader ( data )
else :
was_upgraded = False
self . _modified | = was_upgraded
# convert to StoredDict
self . data = StoredDict ( data , self , [ ] )
# note: self._modified may have been affected by StoredDict
self . _modified = was_upgraded
# write file in case there was a db upgrade
if self . storage and self . storage . file_exists ( ) :
self . write ( )
self . _ write( )
def load_data ( self , s : str ) - > dict :
""" overloaded in wallet_db """
if s == ' ' :
return { }
try :
data = json . loads ( s )
data = json . loads ( ' [ ' + s + ' ] ' )
data , patches = data [ 0 ] , data [ 1 : ]
except Exception :
if r := self . maybe_load_ast_data ( s ) :
data = r
data , patches = r , [ ]
else :
raise WalletFileException ( " Cannot read wallet file. (parsing failed) " )
if not isinstance ( data , dict ) :
raise WalletFileException ( " Malformed wallet file (not dict) " )
if patches :
# apply patches
self . logger . info ( ' found %d patches ' % len ( patches ) )
patch = jsonpatch . JsonPatch ( patches )
data = patch . apply ( data )
self . set_modified ( True )
return data
def maybe_load_ast_data ( self , s ) :
@ -227,6 +276,11 @@ class JsonDB(Logger):
def modified ( self ) :
return self . _modified
@locked
def add_patch ( self , patch ) :
self . pending_changes . append ( json . dumps ( patch , cls = self . encoder ) )
self . set_modified ( True )
@locked
def get ( self , key , default = None ) :
v = self . data . get ( key )
@ -259,6 +313,12 @@ class JsonDB(Logger):
self . data [ name ] = { }
return self . data [ name ]
@locked
def get_stored_item ( self , key , default ) - > dict :
if key not in self . data :
self . data [ key ] = default
return self . data [ key ]
@locked
def dump ( self , * , human_readable : bool = True ) - > str :
""" Serializes the DB as a string.
@ -302,10 +362,27 @@ class JsonDB(Logger):
v = constructor ( v )
return v
@locked
def write ( self ) :
with self . lock :
if self . storage . file_exists ( ) and not self . storage . is_encrypted ( ) :
self . _append_pending_changes ( )
else :
self . _write ( )
@locked
def _append_pending_changes ( self ) :
if threading . current_thread ( ) . daemon :
self . logger . warning ( ' daemon thread cannot write db ' )
return
if not self . pending_changes :
self . logger . info ( ' no pending changes ' )
return
self . logger . info ( f ' appending { len ( self . pending_changes ) } pending changes ' )
s = ' ' . join ( [ ' , \n ' + x for x in self . pending_changes ] )
self . storage . append ( s )
self . pending_changes = [ ]
@locked
@profiler
def _write ( self ) :
if threading . current_thread ( ) . daemon :
@ -315,4 +392,5 @@ class JsonDB(Logger):
return
json_str = self . dump ( human_readable = not self . storage . is_encrypted ( ) )
self . storage . write ( json_str )
self . pending_changes = [ ]
self . set_modified ( False )