|
|
|
|
@ -51,6 +51,7 @@ from .crypto import sha256d
|
|
|
|
|
from . import ecc |
|
|
|
|
from .lnutil import (LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, NUM_MAX_EDGES_IN_PAYMENT_PATH, |
|
|
|
|
NotFoundChanAnnouncementForUpdate) |
|
|
|
|
from .lnverifier import verify_sig_for_channel_update |
|
|
|
|
from .lnmsg import encode_msg |
|
|
|
|
|
|
|
|
|
if TYPE_CHECKING: |
|
|
|
|
@ -70,85 +71,83 @@ Base = declarative_base()
|
|
|
|
|
FLAG_DISABLE = 1 << 1 |
|
|
|
|
FLAG_DIRECTION = 1 << 0 |
|
|
|
|
|
|
|
|
|
class ChannelInfo(Base):
    # SQLAlchemy ORM model for a channel_announcement row.
    # NOTE(review): this name is redefined as a NamedTuple further down in this
    # file (looks like merge/diff residue) — the later definition shadows this
    # one at import time; confirm which version is intended.
    __tablename__ = 'channel_info'
    short_channel_id = Column(String(64), primary_key=True)  # hex-encoded scid (see from_msg below: .hex())
    node1_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False)  # hex-encoded pubkey
    node2_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False)  # hex-encoded pubkey
    capacity_sat = Column(Integer)  # nullable: unknown until the channel is verified on-chain
    msg_payload_hex = Column(String(1024), nullable=False)  # raw channel_announcement message, hex-encoded
    trusted = Column(Boolean, nullable=False)  # whether the announcement has been verified / comes from a trusted source
|
|
|
|
class ChannelInfo(NamedTuple):
    """In-memory record of a channel_announcement.

    All identifiers are raw bytes (not hex strings): ``short_channel_id`` is
    the 8-byte scid, ``node1_id``/``node2_id`` are 33-byte pubkeys.

    Fix: removed merge residue — duplicate ``.hex()`` assignments, a return
    statement using the removed ``msg_payload_hex`` keyword, and a
    ``msg_payload`` property that clashed with the field of the same name
    (typing.NamedTuple refuses to overwrite a field, so the class definition
    itself would fail).
    """
    short_channel_id: bytes
    node1_id: bytes
    node2_id: bytes
    capacity_sat: int    # may be None until the channel is verified on-chain
    msg_payload: bytes   # raw serialized channel_announcement message
    trusted: bool        # whether the announcement has been verified / trusted

    @staticmethod
    def from_msg(payload):
        """Build a ChannelInfo from a decoded channel_announcement payload.

        Raises (via validate_features) on unknown even feature bits. The
        message is re-encoded and kept verbatim in ``msg_payload``.
        """
        features = int.from_bytes(payload['features'], 'big')
        validate_features(features)
        channel_id = payload['short_channel_id']
        node_id_1 = payload['node_id_1']
        node_id_2 = payload['node_id_2']
        # BOLT-7 requires the node ids in ascending order
        assert list(sorted([node_id_1, node_id_2])) == [node_id_1, node_id_2]
        msg_payload = encode_msg('channel_announcement', **payload)
        # capacity is unknown from the gossip message alone; filled in on verification
        capacity_sat = None
        return ChannelInfo(
            short_channel_id=channel_id,
            node1_id=node_id_1,
            node2_id=node_id_2,
            capacity_sat=capacity_sat,
            msg_payload=msg_payload,
            trusted=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Policy(Base):
    # SQLAlchemy ORM model for a channel_update (directed routing policy).
    # NOTE(review): this name is redefined as a NamedTuple further down in this
    # file (looks like merge/diff residue); the later definition shadows this one.
    __tablename__ = 'policy'
    # one row per (direction-source node, channel) pair
    start_node = Column(String(66), ForeignKey('node_info.node_id'), primary_key=True)
    short_channel_id = Column(String(64), ForeignKey('channel_info.short_channel_id'), primary_key=True)
    cltv_expiry_delta = Column(Integer, nullable=False)
    htlc_minimum_msat = Column(Integer, nullable=False)
    htlc_maximum_msat = Column(Integer)  # nullable: optional field in the gossip message
    fee_base_msat = Column(Integer, nullable=False)
    fee_proportional_millionths = Column(Integer, nullable=False)
    channel_flags = Column(Integer, nullable=False)  # carries direction and disable bits
    timestamp = Column(Integer, nullable=False)
|
|
|
|
return ChannelInfo( |
|
|
|
|
short_channel_id = channel_id, |
|
|
|
|
node1_id = node_id_1, |
|
|
|
|
node2_id = node_id_2, |
|
|
|
|
capacity_sat = capacity_sat, |
|
|
|
|
msg_payload = msg_payload, |
|
|
|
|
trusted = False) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Policy(NamedTuple):
    """In-memory record of a channel_update (directed routing policy).

    ``key`` is the 8-byte short_channel_id concatenated with the 33-byte
    start-node pubkey, so there is one Policy per channel direction.

    Fix: removed merge residue in ``from_msg`` — the first return passed
    ``start_node=``/``short_channel_id=`` keywords that do not exist on this
    NamedTuple (TypeError at runtime) and made the real return unreachable.
    """
    key: bytes
    cltv_expiry_delta: int
    htlc_minimum_msat: int
    htlc_maximum_msat: int   # None when the optional field is absent from the message
    fee_base_msat: int
    fee_proportional_millionths: int
    channel_flags: int       # carries direction and disable bits
    timestamp: int

    @staticmethod
    def from_msg(payload):
        """Build a Policy from a decoded channel_update payload.

        ``payload['start_node']`` is not part of the wire message; the caller
        must have filled it in (see add_channel_updates).
        """
        return Policy(
            key=payload['short_channel_id'] + payload['start_node'],
            cltv_expiry_delta=int.from_bytes(payload['cltv_expiry_delta'], "big"),
            htlc_minimum_msat=int.from_bytes(payload['htlc_minimum_msat'], "big"),
            # htlc_maximum_msat is an optional field
            htlc_maximum_msat=int.from_bytes(payload['htlc_maximum_msat'], "big") if 'htlc_maximum_msat' in payload else None,
            fee_base_msat=int.from_bytes(payload['fee_base_msat'], "big"),
            fee_proportional_millionths=int.from_bytes(payload['fee_proportional_millionths'], "big"),
            channel_flags=int.from_bytes(payload['channel_flags'], "big"),
            timestamp=int.from_bytes(payload['timestamp'], "big"),
        )

    def is_disabled(self):
        """Return truthy iff the disable bit is set in channel_flags."""
        return self.channel_flags & FLAG_DISABLE
|
|
|
|
|
|
|
|
|
class NodeInfo(Base):
    # SQLAlchemy ORM model for a node_announcement.
    # NOTE(review): this name is redefined as a NamedTuple further down in this
    # file (looks like merge/diff residue); the later definition shadows this one.
    __tablename__ = 'node_info'
    # REPLACE on conflict: re-announcements overwrite the stored row
    node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
    features = Column(Integer, nullable=False)
    timestamp = Column(Integer, nullable=False)
    alias = Column(String(64), nullable=False)
|
|
|
|
@property
def short_channel_id(self):
    # First 8 bytes of the composite key (scid || start_node pubkey).
    # NOTE(review): this property reads self.key and appears stranded here by
    # a bad merge — it presumably belongs on the Policy NamedTuple; confirm.
    return self.key[0:8]
|
|
|
|
|
|
|
|
|
@property
def start_node(self):
    # Remainder of the composite key after the 8-byte scid prefix, i.e. the
    # 33-byte start-node pubkey.
    # NOTE(review): appears stranded here by a bad merge — presumably belongs
    # on the Policy NamedTuple; confirm.
    return self.key[8:]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NodeInfo(NamedTuple):
    """In-memory record of a node_announcement.

    Fix: removed merge residue in ``from_msg`` — dead ``.hex()`` assignments
    to ``node_id`` and ``alias`` that were immediately overwritten by the
    raw-bytes versions; net behavior is unchanged.
    """
    node_id: bytes   # 33-byte pubkey
    features: int
    timestamp: int
    # NOTE(review): was annotated `str` but from_msg stores raw bytes;
    # annotation corrected to match — confirm no caller expects str
    alias: bytes

    @staticmethod
    def from_msg(payload):
        """Build (NodeInfo, [Address, ...]) from a decoded node_announcement.

        Raises (via validate_features) on unknown even feature bits.
        parse_addresses_field is defined elsewhere in this class (outside this
        hunk).
        """
        node_id = payload['node_id']
        features = int.from_bytes(payload['features'], "big")
        validate_features(features)
        addresses = NodeInfo.parse_addresses_field(payload['addresses'])
        alias = payload['alias'].rstrip(b'\x00')  # strip zero padding
        timestamp = int.from_bytes(payload['timestamp'], "big")
        return NodeInfo(node_id=node_id, features=features, timestamp=timestamp, alias=alias), [
            Address(host=host, port=port, node_id=node_id, last_connected_date=None)
            for host, port in addresses]
|
|
|
|
@ -193,110 +192,136 @@ class NodeInfo(Base):
|
|
|
|
|
break |
|
|
|
|
return addresses |
|
|
|
|
|
|
|
|
|
class Address(Base): |
|
|
|
|
|
|
|
|
|
class Address(NamedTuple):
    # A network address advertised for (or used to reach) a Lightning node.
    node_id: bytes            # 33-byte pubkey of the node
    host: str
    port: int
    last_connected_date: int  # unix time; None when never connected (NodeInfo.from_msg passes None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ChannelInfoBase(Base):
    # SQLAlchemy persistence model backing the in-memory ChannelInfo NamedTuple.
    __tablename__ = 'channel_info'
    # REPLACE on conflict: re-announcements overwrite the stored row
    short_channel_id = Column(String(64), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
    node1_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False)
    node2_id = Column(String(66), ForeignKey('node_info.node_id'), nullable=False)
    capacity_sat = Column(Integer)  # nullable: unknown until verified
    msg_payload = Column(String(1024), nullable=False)  # raw channel_announcement message
    trusted = Column(Boolean, nullable=False)

    def to_nametuple(self):
        # Convert this ORM row into the immutable ChannelInfo used in memory.
        return ChannelInfo(
            short_channel_id=self.short_channel_id,
            node1_id=self.node1_id,
            node2_id=self.node2_id,
            capacity_sat=self.capacity_sat,
            msg_payload=self.msg_payload,
            trusted=self.trusted
        )
|
|
|
|
|
|
|
|
|
class PolicyBase(Base):
    # SQLAlchemy persistence model backing the in-memory Policy NamedTuple.
    __tablename__ = 'policy'
    # key = short_channel_id || start_node; REPLACE on conflict so newer
    # channel_updates overwrite the stored row.
    # NOTE(review): length 66 looks copied from the pubkey columns; the key is
    # 8 + 33 bytes — confirm the declared length is adequate for the storage
    # backend in use.
    key = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
    cltv_expiry_delta = Column(Integer, nullable=False)
    htlc_minimum_msat = Column(Integer, nullable=False)
    htlc_maximum_msat = Column(Integer)  # nullable: optional gossip field
    fee_base_msat = Column(Integer, nullable=False)
    fee_proportional_millionths = Column(Integer, nullable=False)
    channel_flags = Column(Integer, nullable=False)  # carries direction and disable bits
    timestamp = Column(Integer, nullable=False)

    def to_nametuple(self):
        # Convert this ORM row into the immutable Policy used in memory.
        return Policy(
            key=self.key,
            cltv_expiry_delta=self.cltv_expiry_delta,
            htlc_minimum_msat=self.htlc_minimum_msat,
            htlc_maximum_msat=self.htlc_maximum_msat,
            fee_base_msat= self.fee_base_msat,
            fee_proportional_millionths = self.fee_proportional_millionths,
            channel_flags=self.channel_flags,
            timestamp=self.timestamp
        )
|
|
|
|
|
|
|
|
|
class NodeInfoBase(Base):
    # SQLAlchemy persistence model backing the in-memory NodeInfo NamedTuple.
    __tablename__ = 'node_info'
    # REPLACE on conflict: re-announcements overwrite the stored row
    node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
    features = Column(Integer, nullable=False)
    timestamp = Column(Integer, nullable=False)
    alias = Column(String(64), nullable=False)
|
|
|
|
|
|
|
|
|
class AddressBase(Base):
    """SQLAlchemy persistence model for node network addresses.

    Composite primary key (node_id, host, port): a node may advertise many
    addresses.

    Fix: ``node_id`` was assigned twice (merge residue — an older
    ForeignKey-based definition followed by the newer one, which silently
    overrode it). Kept the REPLACE-on-conflict definition, consistent with the
    other *Base models in this file.
    """
    __tablename__ = 'address'
    node_id = Column(String(66), primary_key=True, sqlite_on_conflict_primary_key='REPLACE')
    host = Column(String(256), primary_key=True)
    port = Column(Integer, primary_key=True)
    last_connected_date = Column(Integer(), nullable=True)  # unix time of last successful connection
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ChannelDB(SqlDB): |
|
|
|
|
|
|
|
|
|
NUM_MAX_RECENT_PEERS = 20 |
|
|
|
|
|
|
|
|
|
def __init__(self, network: 'Network'): |
|
|
|
|
path = os.path.join(get_headers_dir(network.config), 'channel_db') |
|
|
|
|
super().__init__(network, path, Base) |
|
|
|
|
super().__init__(network, path, Base, commit_interval=100) |
|
|
|
|
self.num_nodes = 0 |
|
|
|
|
self.num_channels = 0 |
|
|
|
|
self._channel_updates_for_private_channels = {} # type: Dict[Tuple[bytes, bytes], dict] |
|
|
|
|
self.ca_verifier = LNChannelVerifier(network, self) |
|
|
|
|
self.update_counts() |
|
|
|
|
# initialized in load_data |
|
|
|
|
self._channels = {} |
|
|
|
|
self._policies = {} |
|
|
|
|
self._nodes = {} |
|
|
|
|
self._addresses = defaultdict(set) |
|
|
|
|
self._channels_for_node = defaultdict(set) |
|
|
|
|
|
|
|
|
|
@sql |
|
|
|
|
def update_counts(self): |
|
|
|
|
self._update_counts() |
|
|
|
|
self.num_channels = len(self._channels) |
|
|
|
|
self.num_policies = len(self._policies) |
|
|
|
|
self.num_nodes = len(self._nodes) |
|
|
|
|
|
|
|
|
|
def _update_counts(self): |
|
|
|
|
self.num_channels = self.DBSession.query(ChannelInfo).count() |
|
|
|
|
self.num_policies = self.DBSession.query(Policy).count() |
|
|
|
|
self.num_nodes = self.DBSession.query(NodeInfo).count() |
|
|
|
|
def get_channel_ids(self):
    """Return a snapshot set of all known short channel ids."""
    return {scid for scid in self._channels}
|
|
|
|
|
|
|
|
|
@sql |
|
|
|
|
def known_ids(self): |
|
|
|
|
known = self.DBSession.query(ChannelInfo.short_channel_id).all() |
|
|
|
|
return set(bfh(r.short_channel_id) for r in known) |
|
|
|
|
|
|
|
|
|
@sql |
|
|
|
|
def add_recent_peer(self, peer: LNPeerAddr): |
|
|
|
|
now = int(time.time()) |
|
|
|
|
node_id = peer.pubkey.hex() |
|
|
|
|
addr = self.DBSession.query(Address).filter_by(node_id=node_id, host=peer.host, port=peer.port).one_or_none() |
|
|
|
|
node_id = peer.pubkey |
|
|
|
|
self._addresses[node_id].add((peer.host, peer.port, now)) |
|
|
|
|
self.save_address(node_id, peer, now) |
|
|
|
|
|
|
|
|
|
@sql |
|
|
|
|
def save_address(self, node_id, peer, now): |
|
|
|
|
addr = self.DBSession.query(AddressBase).filter_by(node_id=node_id, host=peer.host, port=peer.port).one_or_none() |
|
|
|
|
if addr: |
|
|
|
|
addr.last_connected_date = now |
|
|
|
|
else: |
|
|
|
|
addr = Address(node_id=node_id, host=peer.host, port=peer.port, last_connected_date=now) |
|
|
|
|
addr = AddressBase(node_id=node_id, host=peer.host, port=peer.port, last_connected_date=now) |
|
|
|
|
self.DBSession.add(addr) |
|
|
|
|
self.DBSession.commit() |
|
|
|
|
|
|
|
|
|
@sql |
|
|
|
|
def get_200_randomly_sorted_nodes_not_in(self, node_ids_bytes): |
|
|
|
|
unshuffled = self.DBSession \ |
|
|
|
|
.query(NodeInfo) \ |
|
|
|
|
.filter(not_(NodeInfo.node_id.in_(x.hex() for x in node_ids_bytes))) \ |
|
|
|
|
.limit(200) \ |
|
|
|
|
.all() |
|
|
|
|
return random.sample(unshuffled, len(unshuffled)) |
|
|
|
|
|
|
|
|
|
@sql |
|
|
|
|
def nodes_get(self, node_id): |
|
|
|
|
return self.DBSession \ |
|
|
|
|
.query(NodeInfo) \ |
|
|
|
|
.filter_by(node_id = node_id.hex()) \ |
|
|
|
|
.one_or_none() |
|
|
|
|
def get_200_randomly_sorted_nodes_not_in(self, node_ids): |
|
|
|
|
unshuffled = set(self._nodes.keys()) - node_ids |
|
|
|
|
return random.sample(unshuffled, min(200, len(unshuffled))) |
|
|
|
|
|
|
|
|
|
@sql |
|
|
|
|
def get_last_good_address(self, node_id) -> Optional[LNPeerAddr]: |
|
|
|
|
r = self.DBSession.query(Address).filter_by(node_id=node_id.hex()).order_by(Address.last_connected_date.desc()).all() |
|
|
|
|
r = self._addresses.get(node_id) |
|
|
|
|
if not r: |
|
|
|
|
return None |
|
|
|
|
addr = r[0] |
|
|
|
|
return LNPeerAddr(addr.host, addr.port, bytes.fromhex(addr.node_id)) |
|
|
|
|
addr = sorted(list(r), key=lambda x: x[2])[0] |
|
|
|
|
host, port, timestamp = addr |
|
|
|
|
return LNPeerAddr(host, port, node_id) |
|
|
|
|
|
|
|
|
|
@sql |
|
|
|
|
def get_recent_peers(self): |
|
|
|
|
r = self.DBSession.query(Address).filter(Address.last_connected_date.isnot(None)).order_by(Address.last_connected_date.desc()).limit(self.NUM_MAX_RECENT_PEERS).all() |
|
|
|
|
return [LNPeerAddr(x.host, x.port, bytes.fromhex(x.node_id)) for x in r] |
|
|
|
|
|
|
|
|
|
@sql |
|
|
|
|
def missing_channel_announcements(self) -> Set[int]: |
|
|
|
|
expr = not_(Policy.short_channel_id.in_(self.DBSession.query(ChannelInfo.short_channel_id))) |
|
|
|
|
return set(x[0] for x in self.DBSession.query(Policy.short_channel_id).filter(expr).all()) |
|
|
|
|
r = [self.get_last_good_address(x) for x in self._addresses.keys()] |
|
|
|
|
r = r[-self.NUM_MAX_RECENT_PEERS:] |
|
|
|
|
return r |
|
|
|
|
|
|
|
|
|
@sql |
|
|
|
|
def missing_channel_updates(self) -> Set[int]: |
|
|
|
|
expr = not_(ChannelInfo.short_channel_id.in_(self.DBSession.query(Policy.short_channel_id))) |
|
|
|
|
return set(x[0] for x in self.DBSession.query(ChannelInfo.short_channel_id).filter(expr).all()) |
|
|
|
|
|
|
|
|
|
@sql |
|
|
|
|
def add_verified_channel_info(self, short_id, capacity): |
|
|
|
|
# called from lnchannelverifier |
|
|
|
|
channel_info = self.DBSession.query(ChannelInfo).filter_by(short_channel_id = short_id.hex()).one_or_none() |
|
|
|
|
channel_info.trusted = True |
|
|
|
|
channel_info.capacity = capacity |
|
|
|
|
self.DBSession.commit() |
|
|
|
|
|
|
|
|
|
@sql |
|
|
|
|
@profiler |
|
|
|
|
def on_channel_announcement(self, msg_payloads, trusted=True): |
|
|
|
|
def add_channel_announcement(self, msg_payloads, trusted=True): |
|
|
|
|
if type(msg_payloads) is dict: |
|
|
|
|
msg_payloads = [msg_payloads] |
|
|
|
|
new_channels = {} |
|
|
|
|
added = 0 |
|
|
|
|
for msg in msg_payloads: |
|
|
|
|
short_channel_id = bh2u(msg['short_channel_id']) |
|
|
|
|
if self.DBSession.query(ChannelInfo).filter_by(short_channel_id=short_channel_id).count(): |
|
|
|
|
short_channel_id = msg['short_channel_id'] |
|
|
|
|
if short_channel_id in self._channels: |
|
|
|
|
continue |
|
|
|
|
if constants.net.rev_genesis_bytes() != msg['chain_hash']: |
|
|
|
|
self.logger.info("ChanAnn has unexpected chain_hash {}".format(bh2u(msg['chain_hash']))) |
|
|
|
|
@ -306,24 +331,24 @@ class ChannelDB(SqlDB):
|
|
|
|
|
except UnknownEvenFeatureBits: |
|
|
|
|
self.logger.info("unknown feature bits") |
|
|
|
|
continue |
|
|
|
|
channel_info.trusted = trusted |
|
|
|
|
new_channels[short_channel_id] = channel_info |
|
|
|
|
#channel_info.trusted = trusted |
|
|
|
|
added += 1 |
|
|
|
|
self._channels[short_channel_id] = channel_info |
|
|
|
|
self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id) |
|
|
|
|
self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id) |
|
|
|
|
self.save_channel(channel_info) |
|
|
|
|
if not trusted: |
|
|
|
|
self.ca_verifier.add_new_channel_info(channel_info.short_channel_id, channel_info.msg_payload) |
|
|
|
|
for channel_info in new_channels.values(): |
|
|
|
|
self.DBSession.add(channel_info) |
|
|
|
|
self.DBSession.commit() |
|
|
|
|
self._update_counts() |
|
|
|
|
self.logger.debug('on_channel_announcement: %d/%d'%(len(new_channels), len(msg_payloads))) |
|
|
|
|
|
|
|
|
|
@sql |
|
|
|
|
def get_last_timestamp(self): |
|
|
|
|
return self._get_last_timestamp() |
|
|
|
|
self.update_counts() |
|
|
|
|
self.logger.debug('add_channel_announcement: %d/%d'%(added, len(msg_payloads))) |
|
|
|
|
|
|
|
|
|
def _get_last_timestamp(self):
    # Return the most recent channel_update timestamp stored in the policy
    # table, or 0 when the table is empty (MAX over no rows yields NULL).
    from sqlalchemy.sql import func
    r = self.DBSession.query(func.max(Policy.timestamp).label('max_timestamp')).one()
    return r.max_timestamp or 0
|
|
|
|
|
|
|
|
|
#def add_verified_channel_info(self, short_id, capacity): |
|
|
|
|
# # called from lnchannelverifier |
|
|
|
|
# channel_info = self.DBSession.query(ChannelInfoBase).filter_by(short_channel_id = short_id).one_or_none() |
|
|
|
|
# channel_info.trusted = True |
|
|
|
|
# channel_info.capacity = capacity |
|
|
|
|
|
|
|
|
|
def print_change(self, old_policy, new_policy): |
|
|
|
|
# print what changed between policies |
|
|
|
|
@ -340,89 +365,74 @@ class ChannelDB(SqlDB):
|
|
|
|
|
if old_policy.channel_flags != new_policy.channel_flags: |
|
|
|
|
self.logger.info(f'channel_flags: {old_policy.channel_flags} -> {new_policy.channel_flags}') |
|
|
|
|
|
|
|
|
|
@sql |
|
|
|
|
def get_info_for_updates(self, payloads): |
|
|
|
|
short_channel_ids = [payload['short_channel_id'].hex() for payload in payloads] |
|
|
|
|
channel_infos_list = self.DBSession.query(ChannelInfo).filter(ChannelInfo.short_channel_id.in_(short_channel_ids)).all() |
|
|
|
|
channel_infos = {bfh(x.short_channel_id): x for x in channel_infos_list} |
|
|
|
|
return channel_infos |
|
|
|
|
|
|
|
|
|
@sql |
|
|
|
|
def get_policies_for_updates(self, payloads): |
|
|
|
|
out = {} |
|
|
|
|
for payload in payloads: |
|
|
|
|
short_channel_id = payload['short_channel_id'].hex() |
|
|
|
|
start_node = payload['start_node'].hex() |
|
|
|
|
policy = self.DBSession.query(Policy).filter_by(short_channel_id=short_channel_id, start_node=start_node).one_or_none() |
|
|
|
|
if policy: |
|
|
|
|
out[short_channel_id+start_node] = policy |
|
|
|
|
return out |
|
|
|
|
|
|
|
|
|
@profiler |
|
|
|
|
def filter_channel_updates(self, payloads, max_age=None): |
|
|
|
|
def add_channel_updates(self, payloads, max_age=None, verify=True): |
|
|
|
|
orphaned = [] # no channel announcement for channel update |
|
|
|
|
expired = [] # update older than two weeks |
|
|
|
|
deprecated = [] # update older than database entry |
|
|
|
|
good = {} # good updates |
|
|
|
|
good = [] # good updates |
|
|
|
|
to_delete = [] # database entries to delete |
|
|
|
|
# filter orphaned and expired first |
|
|
|
|
known = [] |
|
|
|
|
now = int(time.time()) |
|
|
|
|
channel_infos = self.get_info_for_updates(payloads) |
|
|
|
|
for payload in payloads: |
|
|
|
|
short_channel_id = payload['short_channel_id'] |
|
|
|
|
timestamp = int.from_bytes(payload['timestamp'], "big") |
|
|
|
|
if max_age and now - timestamp > max_age: |
|
|
|
|
expired.append(short_channel_id) |
|
|
|
|
continue |
|
|
|
|
channel_info = channel_infos.get(short_channel_id) |
|
|
|
|
channel_info = self._channels.get(short_channel_id) |
|
|
|
|
if not channel_info: |
|
|
|
|
orphaned.append(short_channel_id) |
|
|
|
|
continue |
|
|
|
|
flags = int.from_bytes(payload['channel_flags'], 'big') |
|
|
|
|
direction = flags & FLAG_DIRECTION |
|
|
|
|
start_node = channel_info.node1_id if direction == 0 else channel_info.node2_id |
|
|
|
|
payload['start_node'] = bfh(start_node) |
|
|
|
|
payload['start_node'] = start_node |
|
|
|
|
known.append(payload) |
|
|
|
|
# compare updates to existing database entries |
|
|
|
|
old_policies = self.get_policies_for_updates(known) |
|
|
|
|
for payload in known: |
|
|
|
|
timestamp = int.from_bytes(payload['timestamp'], "big") |
|
|
|
|
start_node = payload['start_node'] |
|
|
|
|
short_channel_id = payload['short_channel_id'] |
|
|
|
|
key = (short_channel_id+start_node).hex() |
|
|
|
|
old_policy = old_policies.get(key) |
|
|
|
|
if old_policy: |
|
|
|
|
if timestamp <= old_policy.timestamp: |
|
|
|
|
deprecated.append(short_channel_id) |
|
|
|
|
else: |
|
|
|
|
good[key] = payload |
|
|
|
|
to_delete.append(old_policy) |
|
|
|
|
else: |
|
|
|
|
good[key] = payload |
|
|
|
|
good = list(good.values()) |
|
|
|
|
key = (start_node, short_channel_id) |
|
|
|
|
old_policy = self._policies.get(key) |
|
|
|
|
if old_policy and timestamp <= old_policy.timestamp: |
|
|
|
|
deprecated.append(short_channel_id) |
|
|
|
|
continue |
|
|
|
|
good.append(payload) |
|
|
|
|
if verify: |
|
|
|
|
self.verify_channel_update(payload) |
|
|
|
|
policy = Policy.from_msg(payload) |
|
|
|
|
self._policies[key] = policy |
|
|
|
|
self.save_policy(policy) |
|
|
|
|
# |
|
|
|
|
self.update_counts() |
|
|
|
|
return orphaned, expired, deprecated, good, to_delete |
|
|
|
|
|
|
|
|
|
def add_channel_update(self, payload): |
|
|
|
|
orphaned, expired, deprecated, good, to_delete = self.filter_channel_updates([payload]) |
|
|
|
|
orphaned, expired, deprecated, good, to_delete = self.add_channel_updates([payload], verify=False) |
|
|
|
|
assert len(good) == 1 |
|
|
|
|
self.update_policies(good, to_delete) |
|
|
|
|
|
|
|
|
|
@sql |
|
|
|
|
@profiler |
|
|
|
|
def update_policies(self, to_add, to_delete): |
|
|
|
|
for policy in to_delete: |
|
|
|
|
self.DBSession.delete(policy) |
|
|
|
|
self.DBSession.commit() |
|
|
|
|
for payload in to_add: |
|
|
|
|
policy = Policy.from_msg(payload) |
|
|
|
|
self.DBSession.add(policy) |
|
|
|
|
self.DBSession.commit() |
|
|
|
|
self._update_counts() |
|
|
|
|
def save_policy(self, policy): |
|
|
|
|
self.DBSession.execute(PolicyBase.__table__.insert().values(policy)) |
|
|
|
|
|
|
|
|
|
@sql |
|
|
|
|
@profiler |
|
|
|
|
def on_node_announcement(self, msg_payloads): |
|
|
|
|
def delete_policy(self, short_channel_id, node_id): |
|
|
|
|
self.DBSession.execute(PolicyBase.__table__.delete().values(policy)) |
|
|
|
|
|
|
|
|
|
@sql |
|
|
|
|
def save_channel(self, channel_info): |
|
|
|
|
self.DBSession.execute(ChannelInfoBase.__table__.insert().values(channel_info)) |
|
|
|
|
|
|
|
|
|
def verify_channel_update(self, payload):
    """Validate a channel_update payload: chain hash and signature.

    ``payload['start_node']`` must already be set by the caller. Raises
    Exception on wrong chain or bad signature; returns None on success.
    """
    if constants.net.rev_genesis_bytes() != payload['chain_hash']:
        raise Exception('wrong chain hash')
    if not verify_sig_for_channel_update(payload, payload['start_node']):
        # fix: was `raise BaseException` — BaseException escapes `except
        # Exception` handlers (it is the base of SystemExit/KeyboardInterrupt);
        # raise Exception for consistency with the chain-hash check above.
        raise Exception('verify error')
|
|
|
|
|
|
|
|
|
def add_node_announcement(self, msg_payloads): |
|
|
|
|
if type(msg_payloads) is dict: |
|
|
|
|
msg_payloads = [msg_payloads] |
|
|
|
|
old_addr = None |
|
|
|
|
@ -435,29 +445,35 @@ class ChannelDB(SqlDB):
|
|
|
|
|
continue |
|
|
|
|
node_id = node_info.node_id |
|
|
|
|
# Ignore node if it has no associated channel (DoS protection) |
|
|
|
|
# FIXME this is slow |
|
|
|
|
expr = or_(ChannelInfo.node1_id==node_id, ChannelInfo.node2_id==node_id) |
|
|
|
|
if len(self.DBSession.query(ChannelInfo.short_channel_id).filter(expr).limit(1).all()) == 0: |
|
|
|
|
if node_id not in self._channels_for_node: |
|
|
|
|
#self.logger.info('ignoring orphan node_announcement') |
|
|
|
|
continue |
|
|
|
|
node = self.DBSession.query(NodeInfo).filter_by(node_id=node_id).one_or_none() |
|
|
|
|
node = self._nodes.get(node_id) |
|
|
|
|
if node and node.timestamp >= node_info.timestamp: |
|
|
|
|
continue |
|
|
|
|
node = new_nodes.get(node_id) |
|
|
|
|
if node and node.timestamp >= node_info.timestamp: |
|
|
|
|
continue |
|
|
|
|
new_nodes[node_id] = node_info |
|
|
|
|
# save |
|
|
|
|
self._nodes[node_id] = node_info |
|
|
|
|
self.save_node(node_info) |
|
|
|
|
for addr in node_addresses: |
|
|
|
|
new_addresses[(addr.node_id,addr.host,addr.port)] = addr |
|
|
|
|
self._addresses[node_id].add((addr.host, addr.port, 0)) |
|
|
|
|
self.save_node_addresses(node_id, node_addresses) |
|
|
|
|
|
|
|
|
|
self.logger.debug("on_node_announcement: %d/%d"%(len(new_nodes), len(msg_payloads))) |
|
|
|
|
for node_info in new_nodes.values(): |
|
|
|
|
self.DBSession.add(node_info) |
|
|
|
|
for new_addr in new_addresses.values(): |
|
|
|
|
old_addr = self.DBSession.query(Address).filter_by(node_id=new_addr.node_id, host=new_addr.host, port=new_addr.port).one_or_none() |
|
|
|
|
self.update_counts() |
|
|
|
|
|
|
|
|
|
@sql |
|
|
|
|
def save_node_addresses(self, node_if, node_addresses): |
|
|
|
|
for new_addr in node_addresses: |
|
|
|
|
old_addr = self.DBSession.query(AddressBase).filter_by(node_id=new_addr.node_id, host=new_addr.host, port=new_addr.port).one_or_none() |
|
|
|
|
if not old_addr: |
|
|
|
|
self.DBSession.add(new_addr) |
|
|
|
|
self.DBSession.commit() |
|
|
|
|
self._update_counts() |
|
|
|
|
self.DBSession.execute(AddressBase.__table__.insert().values(new_addr)) |
|
|
|
|
|
|
|
|
|
@sql |
|
|
|
|
def save_node(self, node_info): |
|
|
|
|
self.DBSession.execute(NodeInfoBase.__table__.insert().values(node_info)) |
|
|
|
|
|
|
|
|
|
def get_routing_policy_for_channel(self, start_node_id: bytes, |
|
|
|
|
short_channel_id: bytes) -> Optional[bytes]: |
|
|
|
|
@ -470,41 +486,28 @@ class ChannelDB(SqlDB):
|
|
|
|
|
return None |
|
|
|
|
return Policy.from_msg(msg) # won't actually be written to DB |
|
|
|
|
|
|
|
|
|
@sql |
|
|
|
|
@profiler |
|
|
|
|
def get_old_policies(self, delta): |
|
|
|
|
timestamp = int(time.time()) - delta |
|
|
|
|
old_policies = self.DBSession.query(Policy.short_channel_id).filter(Policy.timestamp <= timestamp) |
|
|
|
|
return old_policies.distinct().count() |
|
|
|
|
now = int(time.time()) |
|
|
|
|
return list(k for k, v in list(self._policies.items()) if v.timestamp <= now - delta) |
|
|
|
|
|
|
|
|
|
@sql |
|
|
|
|
@profiler |
|
|
|
|
def prune_old_policies(self, delta): |
|
|
|
|
# note: delete queries are order sensitive |
|
|
|
|
timestamp = int(time.time()) - delta |
|
|
|
|
old_policies = self.DBSession.query(Policy.short_channel_id).filter(Policy.timestamp <= timestamp) |
|
|
|
|
delete_old_channels = ChannelInfo.__table__.delete().where(ChannelInfo.short_channel_id.in_(old_policies)) |
|
|
|
|
delete_old_policies = Policy.__table__.delete().where(Policy.timestamp <= timestamp) |
|
|
|
|
self.DBSession.execute(delete_old_channels) |
|
|
|
|
self.DBSession.execute(delete_old_policies) |
|
|
|
|
self.DBSession.commit() |
|
|
|
|
self._update_counts() |
|
|
|
|
l = self.get_old_policies(delta) |
|
|
|
|
for k in l: |
|
|
|
|
self._policies.pop(k) |
|
|
|
|
if l: |
|
|
|
|
self.logger.info(f'Deleting {len(l)} old policies') |
|
|
|
|
|
|
|
|
|
@sql |
|
|
|
|
@profiler |
|
|
|
|
def get_orphaned_channels(self): |
|
|
|
|
subquery = self.DBSession.query(Policy.short_channel_id) |
|
|
|
|
orphaned = self.DBSession.query(ChannelInfo).filter(not_(ChannelInfo.short_channel_id.in_(subquery))) |
|
|
|
|
return orphaned.count() |
|
|
|
|
ids = set(x[1] for x in self._policies.keys()) |
|
|
|
|
return list(x for x in self._channels.keys() if x not in ids) |
|
|
|
|
|
|
|
|
|
@sql |
|
|
|
|
@profiler |
|
|
|
|
def prune_orphaned_channels(self): |
|
|
|
|
subquery = self.DBSession.query(Policy.short_channel_id) |
|
|
|
|
delete_orphaned = ChannelInfo.__table__.delete().where(not_(ChannelInfo.short_channel_id.in_(subquery))) |
|
|
|
|
self.DBSession.execute(delete_orphaned) |
|
|
|
|
self.DBSession.commit() |
|
|
|
|
self._update_counts() |
|
|
|
|
l = self.get_orphaned_channels() |
|
|
|
|
for k in l: |
|
|
|
|
self._channels.pop(k) |
|
|
|
|
self.update_counts() |
|
|
|
|
if l: |
|
|
|
|
self.logger.info(f'Deleting {len(l)} orphaned channels') |
|
|
|
|
|
|
|
|
|
def add_channel_update_for_private_channel(self, msg_payload: dict, start_node_id: bytes): |
|
|
|
|
if not verify_sig_for_channel_update(msg_payload, start_node_id): |
|
|
|
|
@ -513,67 +516,27 @@ class ChannelDB(SqlDB):
|
|
|
|
|
msg_payload['start_node'] = start_node_id |
|
|
|
|
self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload |
|
|
|
|
|
|
|
|
|
@sql |
|
|
|
|
def remove_channel(self, short_channel_id): |
|
|
|
|
r = self.DBSession.query(ChannelInfo).filter_by(short_channel_id = short_channel_id.hex()).one_or_none() |
|
|
|
|
if not r: |
|
|
|
|
return |
|
|
|
|
self.DBSession.delete(r) |
|
|
|
|
self.DBSession.commit() |
|
|
|
|
|
|
|
|
|
def print_graph(self, full_ids=False): |
|
|
|
|
# used for debugging. |
|
|
|
|
# FIXME there is a race here - iterables could change size from another thread |
|
|
|
|
def other_node_id(node_id, channel_id): |
|
|
|
|
channel_info = self.get_channel_info(channel_id) |
|
|
|
|
if node_id == channel_info.node1_id: |
|
|
|
|
other = channel_info.node2_id |
|
|
|
|
else: |
|
|
|
|
other = channel_info.node1_id |
|
|
|
|
return other if full_ids else other[-4:] |
|
|
|
|
|
|
|
|
|
print_msg('nodes') |
|
|
|
|
for node in self.DBSession.query(NodeInfo).all(): |
|
|
|
|
print_msg(node) |
|
|
|
|
|
|
|
|
|
print_msg('channels') |
|
|
|
|
for channel_info in self.DBSession.query(ChannelInfo).all(): |
|
|
|
|
short_channel_id = channel_info.short_channel_id |
|
|
|
|
node1 = channel_info.node1_id |
|
|
|
|
node2 = channel_info.node2_id |
|
|
|
|
direction1 = self.get_policy_for_node(channel_info, node1) is not None |
|
|
|
|
direction2 = self.get_policy_for_node(channel_info, node2) is not None |
|
|
|
|
if direction1 and direction2: |
|
|
|
|
direction = 'both' |
|
|
|
|
elif direction1: |
|
|
|
|
direction = 'forward' |
|
|
|
|
elif direction2: |
|
|
|
|
direction = 'backward' |
|
|
|
|
else: |
|
|
|
|
direction = 'none' |
|
|
|
|
print_msg('{}: {}, {}, {}' |
|
|
|
|
.format(bh2u(short_channel_id), |
|
|
|
|
bh2u(node1) if full_ids else bh2u(node1[-4:]), |
|
|
|
|
bh2u(node2) if full_ids else bh2u(node2[-4:]), |
|
|
|
|
direction)) |
|
|
|
|
self._channels.pop(short_channel_id, None) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@sql |
|
|
|
|
def get_node_addresses(self, node_info): |
|
|
|
|
return self.DBSession.query(Address).join(NodeInfo).filter_by(node_id = node_info.node_id).all() |
|
|
|
|
def get_node_addresses(self, node_id): |
|
|
|
|
return self._addresses.get(node_id) |
|
|
|
|
|
|
|
|
|
@sql |
|
|
|
|
@profiler |
|
|
|
|
def load_data(self): |
|
|
|
|
r = self.DBSession.query(ChannelInfo).all() |
|
|
|
|
self._channels = dict([(bfh(x.short_channel_id), x) for x in r]) |
|
|
|
|
r = self.DBSession.query(Policy).filter_by().all() |
|
|
|
|
self._policies = dict([((bfh(x.start_node), bfh(x.short_channel_id)), x) for x in r]) |
|
|
|
|
self._channels_for_node = defaultdict(set) |
|
|
|
|
for x in self.DBSession.query(AddressBase).all(): |
|
|
|
|
self._addresses[x.node_id].add((str(x.host), int(x.port), int(x.last_connected_date or 0))) |
|
|
|
|
for x in self.DBSession.query(ChannelInfoBase).all(): |
|
|
|
|
self._channels[x.short_channel_id] = x.to_nametuple() |
|
|
|
|
for x in self.DBSession.query(PolicyBase).filter_by().all(): |
|
|
|
|
p = x.to_nametuple() |
|
|
|
|
self._policies[(p.start_node, p.short_channel_id)] = p |
|
|
|
|
for channel_info in self._channels.values(): |
|
|
|
|
self._channels_for_node[bfh(channel_info.node1_id)].add(bfh(channel_info.short_channel_id)) |
|
|
|
|
self._channels_for_node[bfh(channel_info.node2_id)].add(bfh(channel_info.short_channel_id)) |
|
|
|
|
self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id) |
|
|
|
|
self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id) |
|
|
|
|
self.logger.info(f'load data {len(self._channels)} {len(self._policies)} {len(self._channels_for_node)}') |
|
|
|
|
self.update_counts() |
|
|
|
|
|
|
|
|
|
def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes) -> Optional['Policy']:
    # Look up the directed policy (channel_update) announced by node_id for
    # this channel; None if no update has been received for that direction.
    return self._policies.get((node_id, short_channel_id))
|
|
|
|
@ -584,6 +547,3 @@ class ChannelDB(SqlDB):
|
|
|
|
|
def get_channels_for_node(self, node_id) -> Set[bytes]:
    """Returns the set of channels that have node_id as one of the endpoints."""
    channels = self._channels_for_node.get(node_id)
    if not channels:
        return set()
    return channels
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|