You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
539 lines
21 KiB
539 lines
21 KiB
# -*- coding: utf-8 -*- |
|
# |
|
# Electrum - lightweight Bitcoin client |
|
# Copyright (C) 2018 The Electrum developers |
|
# |
|
# Permission is hereby granted, free of charge, to any person |
|
# obtaining a copy of this software and associated documentation files |
|
# (the "Software"), to deal in the Software without restriction, |
|
# including without limitation the rights to use, copy, modify, merge, |
|
# publish, distribute, sublicense, and/or sell copies of the Software, |
|
# and to permit persons to whom the Software is furnished to do so, |
|
# subject to the following conditions: |
|
# |
|
# The above copyright notice and this permission notice shall be |
|
# included in all copies or substantial portions of the Software. |
|
# |
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, |
|
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF |
|
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND |
|
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS |
|
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN |
|
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN |
|
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
|
# SOFTWARE. |
|
|
|
from datetime import datetime |
|
import time |
|
import random |
|
import queue |
|
import os |
|
import json |
|
import threading |
|
import concurrent |
|
from collections import defaultdict |
|
from typing import Sequence, List, Tuple, Optional, Dict, NamedTuple, TYPE_CHECKING, Set |
|
import binascii |
|
import base64 |
|
|
|
|
|
from .sql_db import SqlDB, sql |
|
from . import constants |
|
from .util import bh2u, profiler, get_headers_dir, bfh, is_ip_address, list_enabled_bits, print_msg, chunks |
|
from .logging import Logger |
|
from .storage import JsonDB |
|
from .lnverifier import LNChannelVerifier, verify_sig_for_channel_update |
|
from .crypto import sha256d |
|
from . import ecc |
|
from .lnutil import (LN_GLOBAL_FEATURES_KNOWN_SET, LNPeerAddr, NUM_MAX_EDGES_IN_PAYMENT_PATH, |
|
NotFoundChanAnnouncementForUpdate) |
|
from .lnverifier import verify_sig_for_channel_update |
|
from .lnmsg import encode_msg |
|
|
|
if TYPE_CHECKING: |
|
from .lnchannel import Channel |
|
from .network import Network |
|
|
|
class UnknownEvenFeatureBits(Exception):
    """Signals that a feature vector enables an even bit outside the known set.

    Raised by validate_features(); callers in this module catch it to skip the
    offending announcement.
    """
|
|
|
def validate_features(features : int):
    """Reject feature vectors that enable an unknown even bit.

    Walks every bit index reported by list_enabled_bits(features).  Odd
    indices are never a reason to reject; an even index is accepted only when
    its mask (1 << index) is a member of LN_GLOBAL_FEATURES_KNOWN_SET.

    Returns None.  Raises UnknownEvenFeatureBits on the first offending bit.
    """
    for bit_index in list_enabled_bits(features):
        if bit_index % 2:
            # odd bit: ignored by this check
            continue
        if (1 << bit_index) in LN_GLOBAL_FEATURES_KNOWN_SET:
            continue
        raise UnknownEvenFeatureBits()
|
|
|
|
|
# Bit masks applied to the integer 'channel_flags' value of a channel update.
# FLAG_DIRECTION: masked value 0 -> the update is attributed to node1 of the
#                 channel, non-zero -> node2 (see ChannelDB.add_channel_updates).
# FLAG_DISABLE:   Policy.is_disabled() is truthy when this bit is set.
FLAG_DIRECTION = 0b01
FLAG_DISABLE = 0b10
|
|
|
class ChannelInfo(NamedTuple):
    """Immutable record describing one announced channel.

    The field order is significant: rows are written to the database with
    list(channel_info) and rebuilt positionally with ChannelInfo(*row).
    """
    short_channel_id: bytes
    node1_id: bytes
    node2_id: bytes
    capacity_sat: int  # from_msg() always stores None here

    @staticmethod
    def from_msg(payload):
        """Build a ChannelInfo from a channel announcement payload.

        Reads the 'features', 'short_channel_id', 'node_id_1' and 'node_id_2'
        entries of *payload*.

        Raises UnknownEvenFeatureBits (via validate_features) when the feature
        vector is not acceptable, and AssertionError when the two node ids are
        not given in ascending order.
        """
        validate_features(int.from_bytes(payload['features'], 'big'))
        scid = payload['short_channel_id']
        first = payload['node_id_1']
        second = payload['node_id_2']
        # the announcement is expected to list the two node ids sorted
        assert sorted((first, second)) == [first, second]
        # capacity is not part of the announcement, hence None
        return ChannelInfo(scid, first, second, None)
|
|
|
|
|
class Policy(NamedTuple):
    """Routing policy announced for one direction of a channel.

    'key' is the concatenation short_channel_id + start node id; the two
    halves are exposed through the short_channel_id and start_node
    properties.  The field order is significant because rows are stored with
    list(policy) and rebuilt with Policy(*row).
    """
    key: bytes
    cltv_expiry_delta: int
    htlc_minimum_msat: int
    htlc_maximum_msat: int  # None when the payload has no 'htlc_maximum_msat'
    fee_base_msat: int
    fee_proportional_millionths: int
    channel_flags: int
    timestamp: int

    @staticmethod
    def from_msg(payload):
        """Build a Policy from a channel update payload.

        *payload* must already contain a 'start_node' entry next to the raw
        message fields; every numeric field is decoded as a big-endian
        unsigned integer.  Raises KeyError when a mandatory entry is missing.
        """
        def as_int(field_name):
            return int.from_bytes(payload[field_name], "big")

        return Policy(
            key=payload['short_channel_id'] + payload['start_node'],
            cltv_expiry_delta=as_int('cltv_expiry_delta'),
            htlc_minimum_msat=as_int('htlc_minimum_msat'),
            htlc_maximum_msat=as_int('htlc_maximum_msat') if 'htlc_maximum_msat' in payload else None,
            fee_base_msat=as_int('fee_base_msat'),
            fee_proportional_millionths=as_int('fee_proportional_millionths'),
            channel_flags=as_int('channel_flags'),
            timestamp=as_int('timestamp'),
        )

    def is_disabled(self):
        """Return channel_flags masked with FLAG_DISABLE (truthy when disabled)."""
        return self.channel_flags & FLAG_DISABLE

    @property
    def short_channel_id(self):
        """First 8 bytes of the key."""
        return self.key[:8]

    @property
    def start_node(self):
        """Everything in the key after the first 8 bytes."""
        return self.key[8:]
|
|
|
|
|
|
|
class NodeInfo(NamedTuple):
    """Immutable record describing one announced node.

    The field order is significant because rows are stored with
    list(node_info).
    """
    node_id: bytes
    features: int
    timestamp: int
    # NOTE(review): from_msg() stores the raw alias bytes with trailing NUL
    # padding stripped, not a decoded str -- confirm what consumers expect.
    alias: str

    @staticmethod
    def from_msg(payload):
        """Parse a node announcement payload.

        Reads the 'node_id', 'features', 'addresses', 'alias' and 'timestamp'
        entries of *payload*.

        Returns a 2-tuple (NodeInfo, list of Address); every Address carries
        last_connected_date=None.
        Raises UnknownEvenFeatureBits (via validate_features) when the feature
        vector is not acceptable.
        """
        node_id = payload['node_id']
        features = int.from_bytes(payload['features'], "big")
        validate_features(features)
        addresses = NodeInfo.parse_addresses_field(payload['addresses'])
        alias = payload['alias'].rstrip(b'\x00')
        timestamp = int.from_bytes(payload['timestamp'], "big")
        node_info = NodeInfo(node_id=node_id, features=features, timestamp=timestamp, alias=alias)
        node_addresses = [
            Address(host=host, port=port, node_id=node_id, last_connected_date=None)
            for host, port in addresses]
        return node_info, node_addresses

    @staticmethod
    def parse_addresses_field(addresses_field):
        """Decode the packed 'addresses' field into a list of (host, port).

        The field is a sequence of descriptors, each introduced by a one-byte
        type: 0 (no payload, skipped), 1 (IPv4: 4 + 2 bytes), 2 (IPv6:
        16 + 2 bytes), 3 (onion v2: 10 + 2 bytes), 4 (onion v3: 35 + 2 bytes).

        Parsing stops at the first unknown type (its length is unknown) and
        at the first descriptor that is cut short, so truncated input never
        produces a partially decoded host.  IPv4/IPv6 entries are kept only
        when is_ip_address() accepts them and the port is non-zero.
        """
        buf = addresses_field

        def read(n):
            nonlocal buf
            data, buf = buf[0:n], buf[n:]
            return data

        def truncated(host_len):
            # a descriptor needs its host bytes plus a 2-byte port
            return len(buf) < host_len + 2

        addresses = []
        while buf:
            atype = ord(read(1))
            if atype == 0:
                continue
            if atype == 1:  # IPv4
                if truncated(4):
                    break
                ipv4_addr = '.'.join(str(octet) for octet in read(4))
                port = int.from_bytes(read(2), 'big')
                if is_ip_address(ipv4_addr) and port != 0:
                    addresses.append((ipv4_addr, port))
            elif atype == 2:  # IPv6
                if truncated(16):
                    break
                raw = read(16)
                groups = [binascii.hexlify(raw[i:i + 2]) for i in range(0, 16, 2)]
                ipv6_addr = b':'.join(groups).decode('ascii')
                port = int.from_bytes(read(2), 'big')
                if is_ip_address(ipv6_addr) and port != 0:
                    addresses.append((ipv6_addr, port))
            elif atype in (3, 4):  # onion v2 / onion v3
                host_len = 10 if atype == 3 else 35
                if truncated(host_len):
                    break
                host = base64.b32encode(read(host_len)) + b'.onion'
                host = host.decode('ascii').lower()
                port = int.from_bytes(read(2), 'big')
                addresses.append((host, port))
            else:
                # unknown address type
                # we don't know how long it is -> have to escape
                # if there are other addresses we could have parsed later, they are lost.
                break
        return addresses
|
|
|
|
|
class Address(NamedTuple):
    """One network endpoint advertised for a node.

    NodeInfo.from_msg() creates these with last_connected_date=None.
    """
    node_id: bytes
    host: str
    port: int
    last_connected_date: int  # None for addresses taken from an announcement
|
|
|
# Schema of the table written by ChannelDB.save_channel().  The column order
# matches the ChannelInfo field order: load_data() rebuilds rows with
# ChannelInfo(*row).
create_channel_info = """
CREATE TABLE IF NOT EXISTS channel_info (
short_channel_id VARCHAR(64),
node1_id VARCHAR(66),
node2_id VARCHAR(66),
capacity_sat INTEGER,
PRIMARY KEY(short_channel_id)
)"""

# Schema of the table written by ChannelDB.save_policy().  'key' holds
# short_channel_id + start node id (see Policy.from_msg and delete_policy);
# the column order matches the Policy field order.
create_policy = """
CREATE TABLE IF NOT EXISTS policy (
key VARCHAR(66),
cltv_expiry_delta INTEGER NOT NULL,
htlc_minimum_msat INTEGER NOT NULL,
htlc_maximum_msat INTEGER,
fee_base_msat INTEGER NOT NULL,
fee_proportional_millionths INTEGER NOT NULL,
channel_flags INTEGER NOT NULL,
timestamp INTEGER NOT NULL,
PRIMARY KEY(key)
)"""

# Schema of the per-node address table.  save_node_addresses() inserts rows
# with timestamp 0; save_node_address() stores the time passed in by
# add_recent_peer().
create_address = """
CREATE TABLE IF NOT EXISTS address (
node_id VARCHAR(66),
host STRING(256),
port INTEGER NOT NULL,
timestamp INTEGER,
PRIMARY KEY(node_id, host, port)
)"""

# Schema of the table written by ChannelDB.save_node(); the column order
# matches the NodeInfo field order.
create_node_info = """
CREATE TABLE IF NOT EXISTS node_info (
node_id VARCHAR(66),
features INTEGER NOT NULL,
timestamp INTEGER NOT NULL,
alias STRING(64),
PRIMARY KEY(node_id)
)"""
|
|
|
|
|
class ChannelDB(SqlDB):
    """In-memory view of the channel graph, mirrored into an SQL database.

    Lookups are served from dictionaries kept on the instance:

    * _channels:          short_channel_id -> ChannelInfo
    * _policies:          (start_node_id, short_channel_id) -> Policy
    * _nodes:             node_id -> NodeInfo
    * _addresses:         node_id -> set of (host, port, timestamp)
    * _channels_for_node: node_id -> set of short_channel_id

    Methods decorated with @sql touch the database connection; everything
    else only works on the in-memory structures.
    """

    NUM_MAX_RECENT_PEERS = 20

    def __init__(self, network: 'Network'):
        """Open (or create) the 'channel_db' file in the headers directory."""
        path = os.path.join(get_headers_dir(network.config), 'channel_db')
        super().__init__(network, path, commit_interval=100)
        self.num_nodes = 0
        self.num_channels = 0
        self.num_policies = 0
        self._channel_updates_for_private_channels = {}  # type: Dict[Tuple[bytes, bytes], dict]
        self.ca_verifier = LNChannelVerifier(network, self)
        # initialized in load_data
        self._channels = {}
        self._policies = {}
        self._nodes = {}
        self._addresses = defaultdict(set)
        self._channels_for_node = defaultdict(set)

    def update_counts(self):
        """Refresh the num_* counters from the in-memory dictionaries."""
        self.num_channels = len(self._channels)
        self.num_policies = len(self._policies)
        self.num_nodes = len(self._nodes)

    def get_channel_ids(self):
        """Return a new set with every known short_channel_id."""
        return set(self._channels.keys())

    def add_recent_peer(self, peer: LNPeerAddr):
        """Record *peer* as reachable now, in memory and in the database."""
        now = int(time.time())
        node_id = peer.pubkey
        known = self._addresses[node_id]
        # The database row is REPLACEd by (node_id, host, port); mirror that
        # here so the set holds one entry per endpoint instead of one per call.
        stale = {entry for entry in known
                 if entry[0] == peer.host and entry[1] == peer.port}
        known.difference_update(stale)
        known.add((peer.host, peer.port, now))
        self.save_node_address(node_id, peer, now)

    def get_200_randomly_sorted_nodes_not_in(self, node_ids):
        """Return up to 200 known node ids, randomly ordered, excluding *node_ids*."""
        # random.sample() needs a sequence: sets are rejected on Python >= 3.11
        candidates = list(set(self._nodes.keys()).difference(node_ids))
        return random.sample(candidates, min(200, len(candidates)))

    def get_last_good_address(self, node_id) -> Optional[LNPeerAddr]:
        """Return the address of *node_id* with the newest timestamp, or None.

        Addresses learnt from node announcements carry timestamp 0, so any
        address recorded through add_recent_peer() takes precedence.
        """
        r = self._addresses.get(node_id)
        if not r:
            return None
        host, port, timestamp = max(r, key=lambda x: x[2])
        return LNPeerAddr(host, port, node_id)

    def get_recent_peers(self):
        """Return at most NUM_MAX_RECENT_PEERS peer addresses.

        NOTE(review): the selection follows the insertion order of _addresses,
        not the connection time, and includes nodes that were never connected
        to -- confirm this is what callers expect.
        """
        r = [self.get_last_good_address(x) for x in self._addresses.keys()]
        r = r[-self.NUM_MAX_RECENT_PEERS:]
        return r

    def add_channel_announcement(self, msg_payloads, trusted=True):
        """Add one channel announcement payload (dict) or a list of them.

        Announcements for already known channels, for another chain, or with
        unknown even feature bits are skipped.  When *trusted* is false each
        added channel is also handed to the channel verifier.
        """
        if isinstance(msg_payloads, dict):
            msg_payloads = [msg_payloads]
        added = 0
        for msg in msg_payloads:
            short_channel_id = msg['short_channel_id']
            if short_channel_id in self._channels:
                continue
            if constants.net.rev_genesis_bytes() != msg['chain_hash']:
                self.logger.info("ChanAnn has unexpected chain_hash {}".format(bh2u(msg['chain_hash'])))
                continue
            try:
                channel_info = ChannelInfo.from_msg(msg)
            except UnknownEvenFeatureBits:
                self.logger.info("unknown feature bits")
                continue
            added += 1
            self._channels[short_channel_id] = channel_info
            self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id)
            self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id)
            self.save_channel(channel_info)
            if not trusted:
                self.ca_verifier.add_new_channel_info(channel_info.short_channel_id, msg)

        self.update_counts()
        self.logger.debug('add_channel_announcement: %d/%d'%(added, len(msg_payloads)))

    def print_change(self, old_policy, new_policy):
        """Log every policy field whose value differs between the two policies."""
        compared_fields = (
            'cltv_expiry_delta',
            'htlc_minimum_msat',
            'htlc_maximum_msat',
            'fee_base_msat',
            'fee_proportional_millionths',
            'channel_flags',
        )
        for field in compared_fields:
            old_value = getattr(old_policy, field)
            new_value = getattr(new_policy, field)
            if old_value != new_value:
                self.logger.info(f'{field}: {old_value} -> {new_value}')

    def add_channel_updates(self, payloads, max_age=None, verify=True):
        """Apply a batch of channel update payloads.

        Each accepted payload gets a 'start_node' entry added in place.

        Returns the tuple (orphaned, expired, deprecated, good, to_delete):
        orphaned/expired/deprecated are lists of short_channel_ids, good is
        the list of payloads that were stored.  to_delete is never filled by
        this method and is returned only to keep the tuple shape.

        Raises whatever verify_channel_update() raises when *verify* is true.
        """
        orphaned = []    # no channel announcement for channel update
        expired = []     # update older than max_age
        deprecated = []  # update older than database entry
        good = []        # good updates
        to_delete = []   # database entries to delete
        # filter orphaned and expired first
        known = []
        now = int(time.time())
        for payload in payloads:
            short_channel_id = payload['short_channel_id']
            timestamp = int.from_bytes(payload['timestamp'], "big")
            if max_age and now - timestamp > max_age:
                expired.append(short_channel_id)
                continue
            channel_info = self._channels.get(short_channel_id)
            if not channel_info:
                orphaned.append(short_channel_id)
                continue
            flags = int.from_bytes(payload['channel_flags'], 'big')
            direction = flags & FLAG_DIRECTION
            start_node = channel_info.node1_id if direction == 0 else channel_info.node2_id
            payload['start_node'] = start_node
            known.append(payload)
        # compare updates to existing database entries
        for payload in known:
            timestamp = int.from_bytes(payload['timestamp'], "big")
            start_node = payload['start_node']
            short_channel_id = payload['short_channel_id']
            key = (start_node, short_channel_id)
            old_policy = self._policies.get(key)
            if old_policy and timestamp <= old_policy.timestamp:
                deprecated.append(short_channel_id)
                continue
            good.append(payload)
            if verify:
                self.verify_channel_update(payload)
            policy = Policy.from_msg(payload)
            self._policies[key] = policy
            self.save_policy(policy)
        #
        self.update_counts()
        return orphaned, expired, deprecated, good, to_delete

    def add_channel_update(self, payload):
        """Apply a single channel update without signature verification."""
        orphaned, expired, deprecated, good, to_delete = self.add_channel_updates([payload], verify=False)
        assert len(good) == 1

    def create_database(self):
        """Create the four tables if they do not exist yet, then commit."""
        c = self.conn.cursor()
        c.execute(create_node_info)
        c.execute(create_address)
        c.execute(create_policy)
        c.execute(create_channel_info)
        self.conn.commit()

    @sql
    def save_policy(self, policy):
        """Insert or replace the database row of *policy*."""
        c = self.conn.cursor()
        c.execute("""REPLACE INTO policy (key, cltv_expiry_delta, htlc_minimum_msat, htlc_maximum_msat, fee_base_msat, fee_proportional_millionths, channel_flags, timestamp) VALUES (?,?,?,?,?,?, ?, ?)""", list(policy))

    @sql
    def delete_policy(self, node_id, short_channel_id):
        """Delete the policy row keyed by short_channel_id + node_id."""
        key = short_channel_id + node_id
        c = self.conn.cursor()
        c.execute("""DELETE FROM policy WHERE key=?""", (key,))

    @sql
    def save_channel(self, channel_info):
        """Insert or replace the database row of *channel_info*."""
        c = self.conn.cursor()
        c.execute("REPLACE INTO channel_info (short_channel_id, node1_id, node2_id, capacity_sat) VALUES (?,?,?,?)", list(channel_info))

    @sql
    def delete_channel(self, short_channel_id):
        """Delete the channel_info row of *short_channel_id*."""
        c = self.conn.cursor()
        c.execute("""DELETE FROM channel_info WHERE short_channel_id=?""", (short_channel_id,))

    @sql
    def save_node(self, node_info):
        """Insert or replace the database row of *node_info*."""
        c = self.conn.cursor()
        c.execute("REPLACE INTO node_info (node_id, features, timestamp, alias) VALUES (?,?,?,?)", list(node_info))

    @sql
    def save_node_address(self, node_id, peer, now):
        """Insert or replace the address row of *peer* with timestamp *now*."""
        c = self.conn.cursor()
        c.execute("REPLACE INTO address (node_id, host, port, timestamp) VALUES (?,?,?,?)", (node_id, peer.host, peer.port, now))

    @sql
    def save_node_addresses(self, node_id, node_addresses):
        """Insert the given addresses with timestamp 0 unless already stored.

        Existing rows are left alone so that a timestamp written by
        save_node_address() is not reset.
        """
        c = self.conn.cursor()
        for addr in node_addresses:
            c.execute("SELECT * FROM address WHERE node_id=? AND host=? AND port=?", (addr.node_id, addr.host, addr.port))
            if c.fetchone() is None:
                c.execute("INSERT INTO address (node_id, host, port, timestamp) VALUES (?,?,?,?)", (addr.node_id, addr.host, addr.port, 0))

    def verify_channel_update(self, payload):
        """Check the chain hash and signature of a channel update payload.

        *payload* must already contain its 'start_node' entry.
        Raises Exception when the chain hash or the signature is wrong.
        """
        if constants.net.rev_genesis_bytes() != payload['chain_hash']:
            raise Exception('wrong chain hash')
        if not verify_sig_for_channel_update(payload, payload['start_node']):
            # a plain Exception (not BaseException) so that ordinary
            # 'except Exception' handlers see the failure
            raise Exception('verify error')

    def add_node_announcement(self, msg_payloads):
        """Add one node announcement payload (dict) or a list of them.

        Announcements are skipped when they carry unknown even feature bits,
        when the node has no known channel, or when they are not newer than
        the stored NodeInfo.
        """
        if isinstance(msg_payloads, dict):
            msg_payloads = [msg_payloads]
        added = 0
        for msg_payload in msg_payloads:
            try:
                node_info, node_addresses = NodeInfo.from_msg(msg_payload)
            except UnknownEvenFeatureBits:
                continue
            node_id = node_info.node_id
            # Ignore node if it has no associated channel (DoS protection)
            if not self._channels_for_node.get(node_id):
                #self.logger.info('ignoring orphan node_announcement')
                continue
            # _nodes is updated right away, so this also filters older
            # duplicates that appear later in the same batch
            node = self._nodes.get(node_id)
            if node and node.timestamp >= node_info.timestamp:
                continue
            # save
            self._nodes[node_id] = node_info
            self.save_node(node_info)
            # mirror save_node_addresses(): keep an already known endpoint
            # (and its timestamp) instead of adding a second entry for it
            known_endpoints = {(entry[0], entry[1])
                               for entry in self._addresses.get(node_id, ())}
            for addr in node_addresses:
                if (addr.host, addr.port) not in known_endpoints:
                    self._addresses[node_id].add((addr.host, addr.port, 0))
            self.save_node_addresses(node_id, node_addresses)
            added += 1

        self.logger.debug("on_node_announcement: %d/%d"%(added, len(msg_payloads)))
        self.update_counts()

    def get_routing_policy_for_channel(self, start_node_id: bytes,
                                       short_channel_id: bytes) -> Optional['Policy']:
        """Return the policy of *start_node_id* for the channel, or None.

        Known channels are answered from the stored policies; otherwise the
        updates recorded for private channels are consulted.
        """
        if not start_node_id or not short_channel_id: return None
        channel_info = self.get_channel_info(short_channel_id)
        if channel_info is not None:
            return self.get_policy_for_node(short_channel_id, start_node_id)
        msg = self._channel_updates_for_private_channels.get((start_node_id, short_channel_id))
        if not msg:
            return None
        return Policy.from_msg(msg)  # won't actually be written to DB

    def get_old_policies(self, delta):
        """Return the keys of policies whose timestamp is at least *delta* seconds old."""
        now = int(time.time())
        return [k for k, v in list(self._policies.items()) if v.timestamp <= now - delta]

    def prune_old_policies(self, delta):
        """Drop policies older than *delta* seconds from memory and database."""
        l = self.get_old_policies(delta)
        for k in l:
            self._policies.pop(k)
            # k is (start_node_id, short_channel_id)
            self.delete_policy(*k)
        if l:
            self.logger.info(f'Deleting {len(l)} old policies')

    def get_orphaned_channels(self):
        """Return the short_channel_ids of channels that have no policy at all."""
        ids = set(x[1] for x in self._policies.keys())
        return [x for x in self._channels.keys() if x not in ids]

    def prune_orphaned_channels(self):
        """Remove channels without any policy from memory and database."""
        l = self.get_orphaned_channels()
        for short_channel_id in l:
            self.remove_channel(short_channel_id)
            self.delete_channel(short_channel_id)
        self.update_counts()
        if l:
            self.logger.info(f'Deleting {len(l)} orphaned channels')

    def add_channel_update_for_private_channel(self, msg_payload: dict, start_node_id: bytes):
        """Remember a channel update for a channel that was not announced.

        The payload is ignored when its signature does not verify; otherwise
        it gets a 'start_node' entry and is kept in memory only.
        """
        if not verify_sig_for_channel_update(msg_payload, start_node_id):
            return  # ignore
        short_channel_id = msg_payload['short_channel_id']
        msg_payload['start_node'] = start_node_id
        self._channel_updates_for_private_channels[(start_node_id, short_channel_id)] = msg_payload

    def remove_channel(self, short_channel_id):
        """Forget a channel in memory (the database row is not touched)."""
        channel_info = self._channels.pop(short_channel_id, None)
        if channel_info is None:
            return
        for node_id in (channel_info.node1_id, channel_info.node2_id):
            channels = self._channels_for_node.get(node_id)
            if channels is None:
                continue
            channels.discard(channel_info.short_channel_id)
            if not channels:
                # drop the empty set so the node no longer counts as
                # "has a channel" in add_node_announcement()
                del self._channels_for_node[node_id]

    def get_node_addresses(self, node_id):
        """Return the set of (host, port, timestamp) of *node_id*, or None."""
        return self._addresses.get(node_id)

    @sql
    @profiler
    def load_data(self):
        """Populate the in-memory structures from the database tables."""
        c = self.conn.cursor()
        c.execute("""SELECT * FROM address""")
        for x in c:
            node_id, host, port, timestamp = x
            self._addresses[node_id].add((str(host), int(port), int(timestamp or 0)))
        c.execute("""SELECT * FROM channel_info""")
        for x in c:
            ci = ChannelInfo(*x)
            self._channels[ci.short_channel_id] = ci
        # node rows are written by save_node(); without reading them back
        # _nodes would stay empty after a restart
        c.execute("""SELECT * FROM node_info""")
        for x in c:
            ni = NodeInfo(*x)
            self._nodes[ni.node_id] = ni
        c.execute("""SELECT * FROM policy""")
        for x in c:
            p = Policy(*x)
            self._policies[(p.start_node, p.short_channel_id)] = p
        for channel_info in self._channels.values():
            self._channels_for_node[channel_info.node1_id].add(channel_info.short_channel_id)
            self._channels_for_node[channel_info.node2_id].add(channel_info.short_channel_id)
        self.logger.info(f'load data {len(self._channels)} {len(self._policies)} {len(self._channels_for_node)}')
        self.update_counts()

    def get_policy_for_node(self, short_channel_id: bytes, node_id: bytes) -> Optional['Policy']:
        """Return the policy announced by *node_id* for the channel, or None."""
        return self._policies.get((node_id, short_channel_id))

    def get_channel_info(self, channel_id: bytes):
        """Return the ChannelInfo stored under *channel_id*, or None."""
        return self._channels.get(channel_id)

    def get_channels_for_node(self, node_id) -> Set[bytes]:
        """Returns the set of channels that have node_id as one of the endpoints."""
        return self._channels_for_node.get(node_id) or set()
|
|
|