Browse Source

Merge pull request #7202 from bitromortac/2104-mpp-channel-splitting

MPP splitting algorithm: redesign and split within channels
master
ghost43 4 years ago committed by GitHub
parent
commit
ce44a03c24
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 114
      electrum/lnworker.py
  2. 368
      electrum/mpp_split.py
  3. 30
      electrum/tests/test_lnpeer.py
  4. 71
      electrum/tests/test_mpp_split.py
  5. 11
      electrum/trampoline.py

114
electrum/lnworker.py

@ -8,7 +8,7 @@ from decimal import Decimal
import random import random
import time import time
from typing import (Optional, Sequence, Tuple, List, Set, Dict, TYPE_CHECKING, from typing import (Optional, Sequence, Tuple, List, Set, Dict, TYPE_CHECKING,
NamedTuple, Union, Mapping, Any, Iterable, AsyncGenerator) NamedTuple, Union, Mapping, Any, Iterable, AsyncGenerator, DefaultDict)
import threading import threading
import socket import socket
import aiohttp import aiohttp
@ -1159,7 +1159,7 @@ class LNWallet(LNWorker):
raise OnionRoutingFailure(code=OnionFailureCode.TRAMPOLINE_EXPIRY_TOO_SOON, data=b'') raise OnionRoutingFailure(code=OnionFailureCode.TRAMPOLINE_EXPIRY_TOO_SOON, data=b'')
self.logs[payment_hash.hex()] = log = [] self.logs[payment_hash.hex()] = log = []
trampoline_fee_level = self.INITIAL_TRAMPOLINE_FEE_LEVEL trampoline_fee_levels = defaultdict(lambda: self.INITIAL_TRAMPOLINE_FEE_LEVEL) # type: DefaultDict[bytes, int]
use_two_trampolines = True # only used for pay to legacy use_two_trampolines = True # only used for pay to legacy
amount_inflight = 0 # what we sent in htlcs (that receiver gets, without fees) amount_inflight = 0 # what we sent in htlcs (that receiver gets, without fees)
@ -1179,7 +1179,7 @@ class LNWallet(LNWorker):
full_path=full_path, full_path=full_path,
payment_hash=payment_hash, payment_hash=payment_hash,
payment_secret=payment_secret, payment_secret=payment_secret,
trampoline_fee_level=trampoline_fee_level, trampoline_fee_levels=trampoline_fee_levels,
use_two_trampolines=use_two_trampolines, use_two_trampolines=use_two_trampolines,
fwd_trampoline_onion=fwd_trampoline_onion fwd_trampoline_onion=fwd_trampoline_onion
) )
@ -1222,11 +1222,12 @@ class LNWallet(LNWorker):
# if we get a channel update, we might retry the same route and amount # if we get a channel update, we might retry the same route and amount
route = htlc_log.route route = htlc_log.route
sender_idx = htlc_log.sender_idx sender_idx = htlc_log.sender_idx
erring_node_id = route[sender_idx].node_id
failure_msg = htlc_log.failure_msg failure_msg = htlc_log.failure_msg
code, data = failure_msg.code, failure_msg.data code, data = failure_msg.code, failure_msg.data
self.logger.info(f"UPDATE_FAIL_HTLC. code={repr(code)}. " self.logger.info(f"UPDATE_FAIL_HTLC. code={repr(code)}. "
f"decoded_data={failure_msg.decode_data()}. data={data.hex()!r}") f"decoded_data={failure_msg.decode_data()}. data={data.hex()!r}")
self.logger.info(f"error reported by {bh2u(route[sender_idx].node_id)}") self.logger.info(f"error reported by {bh2u(erring_node_id)}")
if code == OnionFailureCode.MPP_TIMEOUT: if code == OnionFailureCode.MPP_TIMEOUT:
raise PaymentFailure(failure_msg.code_name()) raise PaymentFailure(failure_msg.code_name())
# trampoline # trampoline
@ -1238,7 +1239,7 @@ class LNWallet(LNWorker):
if code in (OnionFailureCode.TRAMPOLINE_FEE_INSUFFICIENT, if code in (OnionFailureCode.TRAMPOLINE_FEE_INSUFFICIENT,
OnionFailureCode.TRAMPOLINE_EXPIRY_TOO_SOON): OnionFailureCode.TRAMPOLINE_EXPIRY_TOO_SOON):
# todo: parse the node parameters here (not returned by eclair yet) # todo: parse the node parameters here (not returned by eclair yet)
trampoline_fee_level += 1 trampoline_fee_levels[erring_node_id] += 1
continue continue
elif use_two_trampolines: elif use_two_trampolines:
use_two_trampolines = False use_two_trampolines = False
@ -1458,7 +1459,7 @@ class LNWallet(LNWorker):
invoice_features: int, invoice_features: int,
payment_hash, payment_hash,
payment_secret, payment_secret,
trampoline_fee_level: int, trampoline_fee_levels: DefaultDict[bytes, int],
use_two_trampolines: bool, use_two_trampolines: bool,
fwd_trampoline_onion = None, fwd_trampoline_onion = None,
full_path: LNPaymentPath = None) -> AsyncGenerator[Tuple[LNPaymentRoute, int], None]: full_path: LNPaymentPath = None) -> AsyncGenerator[Tuple[LNPaymentRoute, int], None]:
@ -1468,16 +1469,13 @@ class LNWallet(LNWorker):
We first try to conduct the payment over a single channel. If that fails We first try to conduct the payment over a single channel. If that fails
and mpp is supported by the receiver, we will split the payment.""" and mpp is supported by the receiver, we will split the payment."""
# It could happen that the pathfinding uses a channel
# in the graph multiple times, meaning we could exhaust
# its capacity. This could be dealt with by temporarily
# iteratively blacklisting channels for this mpp attempt.
invoice_features = LnFeatures(invoice_features) invoice_features = LnFeatures(invoice_features)
trampoline_features = LnFeatures.VAR_ONION_OPT trampoline_features = LnFeatures.VAR_ONION_OPT
local_height = self.network.get_local_height() local_height = self.network.get_local_height()
my_active_channels = [chan for chan in self.channels.values() if my_active_channels = [chan for chan in self.channels.values() if
chan.is_active() and not chan.is_frozen_for_sending()] chan.is_active() and not chan.is_frozen_for_sending()]
try: try:
self.logger.info("trying single-part payment")
# try to send over a single channel # try to send over a single channel
if not self.channel_db: if not self.channel_db:
for chan in my_active_channels: for chan in my_active_channels:
@ -1502,7 +1500,7 @@ class LNWallet(LNWorker):
payment_hash=payment_hash, payment_hash=payment_hash,
payment_secret=payment_secret, payment_secret=payment_secret,
local_height=local_height, local_height=local_height,
trampoline_fee_level=trampoline_fee_level, trampoline_fee_levels=trampoline_fee_levels,
use_two_trampolines=use_two_trampolines) use_two_trampolines=use_two_trampolines)
trampoline_payment_secret = os.urandom(32) trampoline_payment_secret = os.urandom(32)
trampoline_total_msat = amount_with_fees trampoline_total_msat = amount_with_fees
@ -1522,7 +1520,7 @@ class LNWallet(LNWorker):
break break
else: else:
raise NoPathFound() raise NoPathFound()
else: else: # local single-part route computation
route = await run_in_thread( route = await run_in_thread(
partial( partial(
self.create_route_for_payment, self.create_route_for_payment,
@ -1536,58 +1534,72 @@ class LNWallet(LNWorker):
) )
) )
yield route, amount_msat, final_total_msat, amount_msat, min_cltv_expiry, payment_secret, fwd_trampoline_onion yield route, amount_msat, final_total_msat, amount_msat, min_cltv_expiry, payment_secret, fwd_trampoline_onion
except NoPathFound: except NoPathFound: # fall back to payment splitting
self.logger.info("no path found, trying multi-part payment")
if not invoice_features.supports(LnFeatures.BASIC_MPP_OPT): if not invoice_features.supports(LnFeatures.BASIC_MPP_OPT):
raise raise
channels_with_funds = {(chan.channel_id, chan.node_id): int(chan.available_to_spend(HTLCOwner.LOCAL)) channels_with_funds = {(chan.channel_id, chan.node_id): int(chan.available_to_spend(HTLCOwner.LOCAL))
for chan in my_active_channels} for chan in my_active_channels}
self.logger.info(f"channels_with_funds: {channels_with_funds}") self.logger.info(f"channels_with_funds: {channels_with_funds}")
if not self.channel_db:
# for trampoline mpp payments we have to restrict ourselves to pay # for trampoline mpp payments we have to restrict ourselves to pay
# to a single node due to some incompatibility in Eclair, see: # to a single node due to some incompatibility in Eclair, see:
# https://github.com/ACINQ/eclair/issues/1723 # https://github.com/ACINQ/eclair/issues/1723
use_singe_node = not self.channel_db and constants.net is constants.BitcoinMainnet use_singe_node = constants.net is constants.BitcoinMainnet
split_configurations = suggest_splits(amount_msat, channels_with_funds, single_node=use_singe_node) split_configurations = suggest_splits(
amount_msat,
channels_with_funds,
exclude_multinode_payments=use_singe_node,
exclude_single_part_payments=True,
# we don't split within a channel when sending to a trampoline node,
# the trampoline node will split for us
exclude_single_channel_splits=True,
)
self.logger.info(f'suggest_split {amount_msat} returned {len(split_configurations)} configurations') self.logger.info(f'suggest_split {amount_msat} returned {len(split_configurations)} configurations')
for s in split_configurations: for sc in split_configurations:
self.logger.info(f"trying split configuration: {s[0].values()} rating: {s[1]}")
try: try:
if not self.channel_db: self.logger.info(f"trying split configuration: {sc.config.values()} rating: {sc.rating}")
buckets = defaultdict(list) per_trampoline_channel_amounts = defaultdict(list)
for (chan_id, _), part_amount_msat in s[0].items(): # categorize by trampoline nodes for trampolin mpp construction
for (chan_id, _), part_amounts_msat in sc.config.items():
chan = self.channels[chan_id] chan = self.channels[chan_id]
if part_amount_msat: for part_amount_msat in part_amounts_msat:
buckets[chan.node_id].append((chan_id, part_amount_msat)) per_trampoline_channel_amounts[chan.node_id].append((chan_id, part_amount_msat))
for node_id, bucket in buckets.items(): # for each trampoline forwarder, construct mpp trampoline
bucket_amount_msat = sum([x[1] for x in bucket]) routes = []
trampoline_onion, bucket_amount_with_fees, bucket_cltv_delta = create_trampoline_route_and_onion( for trampoline_node_id, trampoline_parts in per_trampoline_channel_amounts.items():
amount_msat=bucket_amount_msat, per_trampoline_amount = sum([x[1] for x in trampoline_parts])
trampoline_onion, per_trampoline_amount_with_fees, per_trampoline_cltv_delta = create_trampoline_route_and_onion(
amount_msat=per_trampoline_amount,
total_msat=final_total_msat, total_msat=final_total_msat,
min_cltv_expiry=min_cltv_expiry, min_cltv_expiry=min_cltv_expiry,
my_pubkey=self.node_keypair.pubkey, my_pubkey=self.node_keypair.pubkey,
invoice_pubkey=invoice_pubkey, invoice_pubkey=invoice_pubkey,
invoice_features=invoice_features, invoice_features=invoice_features,
node_id=node_id, node_id=trampoline_node_id,
r_tags=r_tags, r_tags=r_tags,
payment_hash=payment_hash, payment_hash=payment_hash,
payment_secret=payment_secret, payment_secret=payment_secret,
local_height=local_height, local_height=local_height,
trampoline_fee_level=trampoline_fee_level, trampoline_fee_levels=trampoline_fee_levels,
use_two_trampolines=use_two_trampolines) use_two_trampolines=use_two_trampolines)
# node_features is only used to determine is_tlv # node_features is only used to determine is_tlv
bucket_payment_secret = os.urandom(32) per_trampoline_secret = os.urandom(32)
bucket_fees = bucket_amount_with_fees - bucket_amount_msat per_trampoline_fees = per_trampoline_amount_with_fees - per_trampoline_amount
self.logger.info(f'bucket_fees {bucket_fees}') self.logger.info(f'per trampoline fees: {per_trampoline_fees}')
for chan_id, part_amount_msat in bucket: for chan_id, part_amount_msat in trampoline_parts:
chan = self.channels[chan_id] chan = self.channels[chan_id]
margin = chan.available_to_spend(LOCAL, strict=True) - part_amount_msat margin = chan.available_to_spend(LOCAL, strict=True) - part_amount_msat
delta_fee = min(bucket_fees, margin) delta_fee = min(per_trampoline_fees, margin)
# TODO: distribute trampoline fee over several channels?
part_amount_msat_with_fees = part_amount_msat + delta_fee part_amount_msat_with_fees = part_amount_msat + delta_fee
bucket_fees -= delta_fee per_trampoline_fees -= delta_fee
route = [ route = [
RouteEdge( RouteEdge(
start_node=self.node_keypair.pubkey, start_node=self.node_keypair.pubkey,
end_node=node_id, end_node=trampoline_node_id,
short_channel_id=chan.short_channel_id, short_channel_id=chan.short_channel_id,
fee_base_msat=0, fee_base_msat=0,
fee_proportional_millionths=0, fee_proportional_millionths=0,
@ -1595,14 +1607,32 @@ class LNWallet(LNWorker):
node_features=trampoline_features) node_features=trampoline_features)
] ]
self.logger.info(f'adding route {part_amount_msat} {delta_fee} {margin}') self.logger.info(f'adding route {part_amount_msat} {delta_fee} {margin}')
yield route, part_amount_msat_with_fees, bucket_amount_with_fees, part_amount_msat, bucket_cltv_delta, bucket_payment_secret, trampoline_onion routes.append((route, part_amount_msat_with_fees, per_trampoline_amount_with_fees, part_amount_msat, per_trampoline_cltv_delta, per_trampoline_secret, trampoline_onion))
if bucket_fees != 0: if per_trampoline_fees != 0:
self.logger.info('not enough margin to pay trampoline fee') self.logger.info('not enough margin to pay trampoline fee')
raise NoPathFound() raise NoPathFound()
for route in routes:
yield route
return
except NoPathFound:
continue
else: else:
for (chan_id, _), part_amount_msat in s[0].items(): split_configurations = suggest_splits(
if part_amount_msat: amount_msat,
channels_with_funds,
exclude_single_part_payments=True,
)
# We atomically loop through a split configuration. If there was
# a failure to find a path for a single part, we give back control
# after exhausting the split configuration.
yielded_from_split_configuration = False
self.logger.info(f'suggest_split {amount_msat} returned {len(split_configurations)} configurations')
for sc in split_configurations:
self.logger.info(f"trying split configuration: {list(sc.config.values())} rating: {sc.rating}")
for (chan_id, _), part_amounts_msat in sc.config.items():
for part_amount_msat in part_amounts_msat:
channel = self.channels[chan_id] channel = self.channels[chan_id]
try:
route = await run_in_thread( route = await run_in_thread(
partial( partial(
self.create_route_for_payment, self.create_route_for_payment,
@ -1616,11 +1646,11 @@ class LNWallet(LNWorker):
) )
) )
yield route, part_amount_msat, final_total_msat, part_amount_msat, min_cltv_expiry, payment_secret, fwd_trampoline_onion yield route, part_amount_msat, final_total_msat, part_amount_msat, min_cltv_expiry, payment_secret, fwd_trampoline_onion
self.logger.info(f"found acceptable split configuration: {list(s[0].values())} rating: {s[1]}") yielded_from_split_configuration = True
break
except NoPathFound: except NoPathFound:
continue continue
else: if yielded_from_split_configuration:
return
raise NoPathFound() raise NoPathFound()
@profiler @profiler

368
electrum/mpp_split.py

@ -1,259 +1,195 @@
import random import random
import math import math
from typing import List, Tuple, Optional, Sequence, Dict, TYPE_CHECKING from typing import List, Tuple, Dict, NamedTuple
from collections import defaultdict from collections import defaultdict
from .util import profiler
from .lnutil import NoPathFound from .lnutil import NoPathFound
PART_PENALTY = 1.0 # 1.0 results in avoiding splits PART_PENALTY = 1.0 # 1.0 results in avoiding splits
MIN_PART_MSAT = 10_000_000 # we don't want to split indefinitely MIN_PART_SIZE_MSAT = 10_000_000 # we don't want to split indefinitely
EXHAUST_DECAY_FRACTION = 10 # fraction of the local balance that should be reserved if possible EXHAUST_DECAY_FRACTION = 10 # fraction of the local balance that should be reserved if possible
RELATIVE_SPLIT_SPREAD = 0.3 # deviation from the mean when splitting amounts into parts
# these parameters determine the granularity of the newly suggested configurations
REDISTRIBUTION_FRACTION = 50
SPLIT_FRACTION = 50
# these parameters affect the computational work in the probabilistic algorithm # these parameters affect the computational work in the probabilistic algorithm
STARTING_CONFIGS = 50 CANDIDATES_PER_LEVEL = 20
CANDIDATES_PER_LEVEL = 10 MAX_PARTS = 5 # maximum number of parts for splitting
REDISTRIBUTE = 20
# maximum number of parts for splitting
MAX_PARTS = 5
# maps a channel (channel_id, node_id) to a list of amounts
SplitConfig = Dict[Tuple[bytes, bytes], List[int]]
# maps a channel (channel_id, node_id) to the funds it has available
ChannelsFundsInfo = Dict[Tuple[bytes, bytes], int]
def unique_hierarchy(hierarchy: Dict[int, List[Dict[Tuple[bytes, bytes], int]]]) -> Dict[int, List[Dict[Tuple[bytes, bytes], int]]]:
new_hierarchy = defaultdict(list)
for number_parts, configs in hierarchy.items():
unique_configs = set()
for config in configs:
# config dict can be out of order, so sort, otherwise not unique
unique_configs.add(tuple((c, config[c]) for c in sorted(config.keys())))
for unique_config in sorted(unique_configs):
new_hierarchy[number_parts].append(
{t[0]: t[1] for t in unique_config})
return new_hierarchy
class SplitConfigRating(NamedTuple):
config: SplitConfig
rating: float
def single_node_hierarchy(hierarchy: Dict[int, List[Dict[Tuple[bytes, bytes], int]]]) -> Dict[int, List[Dict[Tuple[bytes, bytes], int]]]:
new_hierarchy = defaultdict(list)
for number_parts, configs in hierarchy.items():
for config in configs:
# determine number of nodes in configuration
if number_nonzero_nodes(config) > 1:
continue
new_hierarchy[number_parts].append(config)
return new_hierarchy
def split_amount_normal(total_amount: int, num_parts: int) -> List[int]:
"""Splits an amount into about `num_parts` parts, where the parts are split
randomly (normally distributed around amount/num_parts with certain spread)."""
parts = []
avg_amount = total_amount / num_parts
# roughly reach total_amount
while total_amount - sum(parts) > avg_amount:
amount_to_add = int(abs(random.gauss(avg_amount, RELATIVE_SPLIT_SPREAD * avg_amount)))
if sum(parts) + amount_to_add < total_amount:
parts.append(amount_to_add)
# add what's missing
parts.append(total_amount - sum(parts))
return parts
def number_nonzero_parts(configuration: Dict[Tuple[bytes, bytes], int]) -> int:
return len([v for v in configuration.values() if v])
def number_parts(config: SplitConfig) -> int:
return sum([len(v) for v in config.values() if sum(v)])
def number_nonzero_nodes(configuration: Dict[Tuple[bytes, bytes], int]) -> int:
return len({nodeid for (_, nodeid), amount in configuration.items() if amount > 0})
def number_nonzero_channels(config: SplitConfig) -> int:
return len([v for v in config.values() if sum(v)])
def create_starting_split_hierarchy(amount_msat: int, channels_with_funds: Dict[Tuple[bytes, bytes], int]):
"""Distributes the amount to send to a single or more channels in several
ways (randomly)."""
# TODO: find all possible starting configurations deterministically
# could try all permutations
split_hierarchy = defaultdict(list) def number_nonzero_nodes(config: SplitConfig) -> int:
channels_order = list(channels_with_funds.keys()) # using a set comprehension
return len({nodeid for (_, nodeid), amounts in config.items() if sum(amounts)})
for _ in range(STARTING_CONFIGS):
# shuffle to have different starting points
random.shuffle(channels_order)
configuration = {} def total_config_amount(config: SplitConfig) -> int:
amount_added = 0 return sum([sum(c) for c in config.values()])
for c in channels_order:
s = channels_with_funds[c]
if amount_added == amount_msat:
configuration[c] = 0
else:
amount_to_add = amount_msat - amount_added
amt = min(s, amount_to_add)
configuration[c] = amt
amount_added += amt
if amount_added != amount_msat:
raise NoPathFound("Channels don't have enough sending capacity.")
split_hierarchy[number_nonzero_parts(configuration)].append(configuration)
return unique_hierarchy(split_hierarchy)
def balances_are_not_ok(proposed_balance_from, channel_from, proposed_balance_to, channel_to, channels_with_funds):
check = (
proposed_balance_to < MIN_PART_MSAT or
proposed_balance_to > channels_with_funds[channel_to] or
proposed_balance_from < MIN_PART_MSAT or
proposed_balance_from > channels_with_funds[channel_from]
)
return check
def propose_new_configuration(channels_with_funds: Dict[Tuple[bytes, bytes], int], configuration: Dict[Tuple[bytes, bytes], int],
amount_msat: int, preserve_number_parts=True) -> Dict[Tuple[bytes, bytes], int]:
"""Randomly alters a split configuration. If preserve_number_parts, the
configuration stays within the same class of number of splits."""
# there are three basic operations to reach different split configurations:
# redistribute, split, swap
def redistribute(config: dict):
# we redistribute the amount from a nonzero channel to a nonzero channel
redistribution_amount = amount_msat // REDISTRIBUTION_FRACTION
nonzero = [ck for ck, cv in config.items() if
cv >= redistribution_amount]
if len(nonzero) == 1: # we only have a single channel, so we can't redistribute
return config
channel_from = random.choice(nonzero)
channel_to = random.choice(nonzero)
if channel_from == channel_to:
return config
proposed_balance_from = config[channel_from] - redistribution_amount
proposed_balance_to = config[channel_to] + redistribution_amount
if balances_are_not_ok(proposed_balance_from, channel_from, proposed_balance_to, channel_to, channels_with_funds):
return config
else:
config[channel_from] = proposed_balance_from
config[channel_to] = proposed_balance_to
assert sum([cv for cv in config.values()]) == amount_msat
return config
def split(config: dict):
# we split off a certain amount from a nonzero channel and put it into a
# zero channel
nonzero = [ck for ck, cv in config.items() if cv != 0]
zero = [ck for ck, cv in config.items() if cv == 0]
try:
channel_from = random.choice(nonzero)
channel_to = random.choice(zero)
except IndexError:
return config
delta = config[channel_from] // SPLIT_FRACTION
proposed_balance_from = config[channel_from] - delta
proposed_balance_to = config[channel_to] + delta
if balances_are_not_ok(proposed_balance_from, channel_from, proposed_balance_to, channel_to, channels_with_funds):
return config
else:
config[channel_from] = proposed_balance_from
config[channel_to] = proposed_balance_to
assert sum([cv for cv in config.values()]) == amount_msat
return config
def swap(config: dict):
# we swap the amounts from a single channel with another channel
nonzero = [ck for ck, cv in config.items() if cv != 0]
all = list(config.keys())
channel_from = random.choice(nonzero)
channel_to = random.choice(all)
proposed_balance_to = config[channel_from]
proposed_balance_from = config[channel_to]
if balances_are_not_ok(proposed_balance_from, channel_from, proposed_balance_to, channel_to, channels_with_funds):
return config
else:
config[channel_to] = proposed_balance_to
config[channel_from] = proposed_balance_from
return config
initial_number_parts = number_nonzero_parts(configuration)
for _ in range(REDISTRIBUTE): def is_any_amount_smaller_than_min_part_size(config: SplitConfig) -> bool:
configuration = redistribute(configuration) smaller = False
if not preserve_number_parts and number_nonzero_parts( for amounts in config.values():
configuration) == initial_number_parts: if any([amount < MIN_PART_SIZE_MSAT for amount in amounts]):
configuration = split(configuration) smaller |= True
configuration = swap(configuration) return smaller
return configuration
def remove_duplicates(configs: List[SplitConfig]) -> List[SplitConfig]:
unique_configs = set()
for config in configs:
# sort keys and values
config_sorted_values = {k: sorted(v) for k, v in config.items()}
config_sorted_keys = {k: config_sorted_values[k] for k in sorted(config_sorted_values.keys())}
hashable_config = tuple((c, tuple(sorted(config[c]))) for c in config_sorted_keys)
unique_configs.add(hashable_config)
unique_configs = [{c[0]: list(c[1]) for c in config} for config in unique_configs]
return unique_configs
@profiler
def suggest_splits(amount_msat: int, channels_with_funds: Dict[Tuple[bytes, bytes], int],
exclude_single_parts=True, single_node=False) \
-> Sequence[Tuple[Dict[Tuple[bytes, bytes], int], float]]:
"""Creates split configurations for a payment over channels. Single channel
payments are excluded by default. channels_with_funds is keyed by
(channelid, nodeid)."""
def rate_configuration(config: dict) -> float: def remove_multiple_nodes(configs: List[SplitConfig]) -> List[SplitConfig]:
"""Defines an objective function to rate a split configuration. return [config for config in configs if number_nonzero_nodes(config) == 1]
We calculate the normalized L2 norm for a split configuration and
def remove_single_part_configs(configs: List[SplitConfig]) -> List[SplitConfig]:
return [config for config in configs if number_parts(config) != 1]
def remove_single_channel_splits(configs: List[SplitConfig]) -> List[SplitConfig]:
filtered = []
for config in configs:
for v in config.values():
if len(v) > 1:
continue
filtered.append(config)
return filtered
def rate_config(
config: SplitConfig,
channels_with_funds: ChannelsFundsInfo) -> float:
"""Defines an objective function to rate a configuration.
We calculate the normalized L2 norm for a configuration and
add a part penalty for each nonzero amount. The consequence is that add a part penalty for each nonzero amount. The consequence is that
amounts that are equally distributed and have less parts are rated amounts that are equally distributed and have less parts are rated
lowest.""" lowest (best). A penalty depending on the total amount sent over a channel
F = 0 counteracts channel exhaustion."""
total_amount = sum([v for v in config.values()]) rating = 0
total_amount = total_config_amount(config)
for channel, amount in config.items(): for channel, amounts in config.items():
funds = channels_with_funds[channel] funds = channels_with_funds[channel]
if amount: if amounts:
F += amount * amount / (total_amount * total_amount) # a penalty to favor distribution of amounts for amount in amounts:
F += PART_PENALTY * PART_PENALTY # a penalty for each part rating += amount * amount / (total_amount * total_amount) # penalty to favor equal distribution of amounts
rating += PART_PENALTY * PART_PENALTY # penalty for each part
decay = funds / EXHAUST_DECAY_FRACTION decay = funds / EXHAUST_DECAY_FRACTION
F += math.exp((amount - funds) / decay) # a penalty for channel saturation rating += math.exp((sum(amounts) - funds) / decay) # penalty for channel exhaustion
return rating
return F
def rated_sorted_configurations(hierarchy: dict) -> Sequence[Tuple[Dict[Tuple[bytes, bytes], int], float]]: def suggest_splits(
"""Cleans up duplicate splittings, rates and sorts them according to amount_msat: int, channels_with_funds: ChannelsFundsInfo,
the rating. A lower rating is a better configuration.""" exclude_single_part_payments=False,
hierarchy = unique_hierarchy(hierarchy) exclude_multinode_payments=False,
rated_configs = [] exclude_single_channel_splits=False
for level, configs in hierarchy.items(): ) -> List[SplitConfigRating]:
for config in configs: """Breaks amount_msat into smaller pieces and distributes them over the
rated_configs.append((config, rate_configuration(config))) channels according to the funds they can send.
sorted_rated_configs = sorted(rated_configs, key=lambda c: c[1], reverse=False)
return sorted_rated_configs
# create initial guesses Individual channels may be assigned multiple parts. The split configurations
split_hierarchy = create_starting_split_hierarchy(amount_msat, channels_with_funds) are returned in sorted order, from best to worst rating.
# randomize initial guesses and generate splittings of different split Single part payments can be excluded, since they represent legacy payments.
# levels up to number of channels Split configurations that send via multiple nodes can be excluded as well.
for level in range(2, min(MAX_PARTS, len(channels_with_funds) + 1)): """
# generate a set of random configurations for each level
configs = []
channels_order = list(channels_with_funds.keys())
# generate multiple configurations to get more configurations (there is randomness in this loop)
for _ in range(CANDIDATES_PER_LEVEL): for _ in range(CANDIDATES_PER_LEVEL):
configurations = unique_hierarchy(split_hierarchy).get(level, None) # we want to have configurations with no splitting to many splittings
if configurations: # we have a splitting of the desired number of parts for target_parts in range(1, MAX_PARTS):
configuration = random.choice(configurations) config = defaultdict(list) # type: SplitConfig
# generate new splittings preserving the number of parts
configuration = propose_new_configuration( # randomly split amount into target_parts chunks
channels_with_funds, configuration, amount_msat, split_amounts = split_amount_normal(amount_msat, target_parts)
preserve_number_parts=True) # randomly distribute amounts over channels
for amount in split_amounts:
random.shuffle(channels_order)
# we check each channel and try to put the funds inside, break if we succeed
for c in channels_order:
if sum(config[c]) + amount <= channels_with_funds[c]:
config[c].append(amount)
break
# if we don't succeed to put the amount anywhere,
# we try to fill up channels and put the rest somewhere else
else: else:
# go one level lower and look for valid splittings, distribute_amount = amount
# try to go one level higher by splitting a single outgoing amount for c in channels_order:
configurations = unique_hierarchy(split_hierarchy).get(level - 1, None) funds_left = channels_with_funds[c] - sum(config[c])
if not configurations: # it would be good to not fill the full channel if possible
add_amount = min(funds_left, distribute_amount)
config[c].append(add_amount)
distribute_amount -= add_amount
if distribute_amount == 0:
break
if total_config_amount(config) != amount_msat:
raise NoPathFound('Cannot distribute payment over channels.')
if target_parts > 1 and is_any_amount_smaller_than_min_part_size(config):
continue continue
configuration = random.choice(configurations) assert total_config_amount(config) == amount_msat
# generate new splittings going one level higher in the number of parts configs.append(config)
configuration = propose_new_configuration(
channels_with_funds, configuration, amount_msat, configs = remove_duplicates(configs)
preserve_number_parts=False)
# we only take configurations that send via a single node (but there can be multiple parts)
# add the newly found configuration (doesn't matter if nothing changed) if exclude_multinode_payments:
split_hierarchy[number_nonzero_parts(configuration)].append(configuration) configs = remove_multiple_nodes(configs)
if exclude_single_parts: if exclude_single_part_payments:
# we only want to return configurations that have at least two parts configs = remove_single_part_configs(configs)
try:
del split_hierarchy[1] if exclude_single_channel_splits:
except: configs = remove_single_channel_splits(configs)
pass
rated_configs = [SplitConfigRating(
if single_node: config=c,
# we only take configurations that send to a single node rating=rate_config(c, channels_with_funds)
split_hierarchy = single_node_hierarchy(split_hierarchy) ) for c in configs]
rated_configs.sort(key=lambda x: x.rating)
return rated_sorted_configurations(split_hierarchy)
return rated_configs

30
electrum/tests/test_lnpeer.py

@ -205,7 +205,7 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]):
min_cltv_expiry=decoded_invoice.get_min_final_cltv_expiry(), min_cltv_expiry=decoded_invoice.get_min_final_cltv_expiry(),
r_tags=decoded_invoice.get_routing_info('r'), r_tags=decoded_invoice.get_routing_info('r'),
invoice_features=decoded_invoice.get_features(), invoice_features=decoded_invoice.get_features(),
trampoline_fee_level=0, trampoline_fee_levels=defaultdict[int],
use_two_trampolines=False, use_two_trampolines=False,
payment_hash=decoded_invoice.paymenthash, payment_hash=decoded_invoice.paymenthash,
payment_secret=decoded_invoice.payment_secret, payment_secret=decoded_invoice.payment_secret,
@ -888,15 +888,18 @@ class TestPeer(TestCaseForTestnet):
with self.assertRaises(PaymentDone): with self.assertRaises(PaymentDone):
run(f()) run(f())
def _run_mpp(self, graph, kwargs1, kwargs2): def _run_mpp(self, graph, fail_kwargs, success_kwargs):
"""Tests a multipart payment scenario for failing and successful cases."""
self.assertEqual(500_000_000_000, graph.chan_ab.balance(LOCAL)) self.assertEqual(500_000_000_000, graph.chan_ab.balance(LOCAL))
self.assertEqual(500_000_000_000, graph.chan_ac.balance(LOCAL)) self.assertEqual(500_000_000_000, graph.chan_ac.balance(LOCAL))
amount_to_pay = 600_000_000_000 amount_to_pay = 600_000_000_000
peers = graph.all_peers() peers = graph.all_peers()
async def pay(attempts=1, async def pay(
attempts=1,
alice_uses_trampoline=False, alice_uses_trampoline=False,
bob_forwarding=True, bob_forwarding=True,
mpp_invoice=True): mpp_invoice=True
):
if mpp_invoice: if mpp_invoice:
graph.w_d.features |= LnFeatures.BASIC_MPP_OPT graph.w_d.features |= LnFeatures.BASIC_MPP_OPT
if not bob_forwarding: if not bob_forwarding:
@ -930,22 +933,22 @@ class TestPeer(TestCaseForTestnet):
await group.spawn(pay(**kwargs)) await group.spawn(pay(**kwargs))
with self.assertRaises(NoPathFound): with self.assertRaises(NoPathFound):
run(f(kwargs1)) run(f(fail_kwargs))
with self.assertRaises(PaymentDone): with self.assertRaises(PaymentDone):
run(f(kwargs2)) run(f(success_kwargs))
@needs_test_with_all_chacha20_implementations @needs_test_with_all_chacha20_implementations
def test_multipart_payment_with_timeout(self): def test_payment_multipart_with_timeout(self):
graph = self.prepare_chans_and_peers_in_square() graph = self.prepare_chans_and_peers_in_square()
self._run_mpp(graph, {'bob_forwarding':False}, {'bob_forwarding':True}) self._run_mpp(graph, {'bob_forwarding': False}, {'bob_forwarding': True})
@needs_test_with_all_chacha20_implementations @needs_test_with_all_chacha20_implementations
def test_multipart_payment(self): def test_payment_multipart(self):
graph = self.prepare_chans_and_peers_in_square() graph = self.prepare_chans_and_peers_in_square()
self._run_mpp(graph, {'mpp_invoice':False}, {'mpp_invoice':True}) self._run_mpp(graph, {'mpp_invoice': False}, {'mpp_invoice': True})
@needs_test_with_all_chacha20_implementations @needs_test_with_all_chacha20_implementations
def test_multipart_payment_with_trampoline(self): def test_payment_multipart_trampoline(self):
# single attempt will fail with insufficient trampoline fee # single attempt will fail with insufficient trampoline fee
graph = self.prepare_chans_and_peers_in_square() graph = self.prepare_chans_and_peers_in_square()
electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = { electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = {
@ -953,7 +956,10 @@ class TestPeer(TestCaseForTestnet):
graph.w_c.name: LNPeerAddr(host="127.0.0.1", port=9735, pubkey=graph.w_c.node_keypair.pubkey), graph.w_c.name: LNPeerAddr(host="127.0.0.1", port=9735, pubkey=graph.w_c.node_keypair.pubkey),
} }
try: try:
self._run_mpp(graph, {'alice_uses_trampoline':True, 'attempts':1}, {'alice_uses_trampoline':True, 'attempts':30}) self._run_mpp(
graph,
{'alice_uses_trampoline': True, 'attempts': 1},
{'alice_uses_trampoline': True, 'attempts': 30})
finally: finally:
electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = {} electrum.trampoline._TRAMPOLINE_NODES_UNITTESTS = {}

71
electrum/tests/test_mpp_split.py

@ -28,43 +28,58 @@ class TestMppSplit(ElectrumTestCase):
def test_suggest_splits(self): def test_suggest_splits(self):
with self.subTest(msg="do a payment with the maximal amount spendable over a single channel"): with self.subTest(msg="do a payment with the maximal amount spendable over a single channel"):
splits = mpp_split.suggest_splits(1_000_000_000, self.channels_with_funds, exclude_single_parts=True) splits = mpp_split.suggest_splits(1_000_000_000, self.channels_with_funds, exclude_single_part_payments=True)
self.assertEqual({(0, 0): 660_000_000, (1, 1): 340_000_000, (2, 0): 0, (3, 2): 0}, splits[0][0]) self.assertEqual({
(0, 0): [671_020_676],
(1, 1): [328_979_324],
(2, 0): [],
(3, 2): []},
splits[0].config
)
with self.subTest(msg="do a payment with a larger amount than what is supported by a single channel"): with self.subTest(msg="do a payment with a larger amount than what is supported by a single channel"):
splits = mpp_split.suggest_splits(1_100_000_000, self.channels_with_funds, exclude_single_parts=True) splits = mpp_split.suggest_splits(1_100_000_000, self.channels_with_funds, exclude_single_part_payments=False)
self.assertEqual(2, mpp_split.number_nonzero_parts(splits[0][0])) self.assertEqual(2, mpp_split.number_parts(splits[0].config))
with self.subTest(msg="do a payment with the maximal amount spendable over all channels"): with self.subTest(msg="do a payment with the maximal amount spendable over all channels"):
splits = mpp_split.suggest_splits(sum(self.channels_with_funds.values()), self.channels_with_funds, exclude_single_parts=True) splits = mpp_split.suggest_splits(
self.assertEqual({(0, 0): 1_000_000_000, (1, 1): 500_000_000, (2, 0): 302_000_000, (3, 2): 101_000_000}, splits[0][0]) sum(self.channels_with_funds.values()), self.channels_with_funds, exclude_single_part_payments=True)
self.assertEqual({
(0, 0): [1_000_000_000],
(1, 1): [500_000_000],
(2, 0): [302_000_000],
(3, 2): [101_000_000]},
splits[0].config
)
with self.subTest(msg="do a payment with the amount supported by all channels"): with self.subTest(msg="do a payment with the amount supported by all channels"):
splits = mpp_split.suggest_splits(101_000_000, self.channels_with_funds, exclude_single_parts=False) splits = mpp_split.suggest_splits(101_000_000, self.channels_with_funds, exclude_single_part_payments=False)
for s in splits[:4]: for split in splits[:3]:
self.assertEqual(1, mpp_split.number_nonzero_parts(s[0])) self.assertEqual(1, mpp_split.number_nonzero_channels(split.config))
# due to exhaustion of the smallest channel, the algorithm favors
# a splitting of the parts into two
self.assertEqual(2, mpp_split.number_parts(splits[4].config))
def test_send_to_single_node(self): def test_send_to_single_node(self):
splits = mpp_split.suggest_splits(1_000_000_000, self.channels_with_funds, exclude_single_parts=True, single_node=True) splits = mpp_split.suggest_splits(1_000_000_000, self.channels_with_funds, exclude_single_part_payments=False, exclude_multinode_payments=True)
self.assertEqual({(0, 0): 738_000_000, (1, 1): 0, (2, 0): 262_000_000, (3, 2): 0}, splits[0][0])
for split in splits: for split in splits:
assert mpp_split.number_nonzero_nodes(split[0]) == 1 assert mpp_split.number_nonzero_nodes(split.config) == 1
def test_saturation(self): def test_saturation(self):
"""Split configurations which spend the full amount in a channel should be avoided.""" """Split configurations which spend the full amount in a channel should be avoided."""
channels_with_funds = {(0, 0): 159_799_733_076, (1, 1): 499_986_152_000} channels_with_funds = {(0, 0): 159_799_733_076, (1, 1): 499_986_152_000}
splits = mpp_split.suggest_splits(600_000_000_000, channels_with_funds, exclude_single_parts=True) splits = mpp_split.suggest_splits(600_000_000_000, channels_with_funds, exclude_single_part_payments=True)
uses_full_amount = False uses_full_amount = False
for c, a in splits[0][0].items(): for c, a in splits[0].config.items():
if a == channels_with_funds[c]: if a == channels_with_funds[c]:
uses_full_amount |= True uses_full_amount |= True
self.assertFalse(uses_full_amount) self.assertFalse(uses_full_amount)
def test_payment_below_min_part_size(self): def test_payment_below_min_part_size(self):
amount = mpp_split.MIN_PART_MSAT // 2 amount = mpp_split.MIN_PART_SIZE_MSAT // 2
splits = mpp_split.suggest_splits(amount, self.channels_with_funds, exclude_single_parts=False) splits = mpp_split.suggest_splits(amount, self.channels_with_funds, exclude_single_part_payments=False)
# we only get four configurations that end up spending the full amount # we only get four configurations that end up spending the full amount
# in a single channel # in a single channel
self.assertEqual(4, len(splits)) self.assertEqual(4, len(splits))
@ -77,25 +92,37 @@ class TestMppSplit(ElectrumTestCase):
with self.subTest(msg="split payments with intermediate part penalty"): with self.subTest(msg="split payments with intermediate part penalty"):
mpp_split.PART_PENALTY = 1.0 mpp_split.PART_PENALTY = 1.0
splits = mpp_split.suggest_splits(1_100_000_000, self.channels_with_funds) splits = mpp_split.suggest_splits(1_100_000_000, self.channels_with_funds)
self.assertEqual(2, mpp_split.number_nonzero_parts(splits[0][0])) self.assertEqual(2, mpp_split.number_parts(splits[0].config))
with self.subTest(msg="split payments with intermediate part penalty"): with self.subTest(msg="split payments with intermediate part penalty"):
mpp_split.PART_PENALTY = 0.3 mpp_split.PART_PENALTY = 0.3
splits = mpp_split.suggest_splits(1_100_000_000, self.channels_with_funds) splits = mpp_split.suggest_splits(1_100_000_000, self.channels_with_funds)
self.assertEqual(3, mpp_split.number_nonzero_parts(splits[0][0])) self.assertEqual(4, mpp_split.number_parts(splits[0].config))
with self.subTest(msg="split payments with no part penalty"): with self.subTest(msg="split payments with no part penalty"):
mpp_split.PART_PENALTY = 0.0 mpp_split.PART_PENALTY = 0.0
splits = mpp_split.suggest_splits(1_100_000_000, self.channels_with_funds) splits = mpp_split.suggest_splits(1_100_000_000, self.channels_with_funds)
self.assertEqual(4, mpp_split.number_nonzero_parts(splits[0][0])) self.assertEqual(5, mpp_split.number_parts(splits[0].config))
def test_suggest_splits_single_channel(self): def test_suggest_splits_single_channel(self):
channels_with_funds = { channels_with_funds = {
0: 1_000_000_000, (0, 0): 1_000_000_000,
} }
with self.subTest(msg="do a payment with the maximal amount spendable on a single channel"): with self.subTest(msg="do a payment with the maximal amount spendable on a single channel"):
splits = mpp_split.suggest_splits(1_000_000_000, channels_with_funds, exclude_single_parts=False) splits = mpp_split.suggest_splits(1_000_000_000, channels_with_funds, exclude_single_part_payments=False)
self.assertEqual({0: 1_000_000_000}, splits[0][0]) self.assertEqual({(0, 0): [1_000_000_000]}, splits[0].config)
with self.subTest(msg="test sending an amount greater than what we have available"): with self.subTest(msg="test sending an amount greater than what we have available"):
self.assertRaises(NoPathFound, mpp_split.suggest_splits, *(1_100_000_000, channels_with_funds)) self.assertRaises(NoPathFound, mpp_split.suggest_splits, *(1_100_000_000, channels_with_funds))
with self.subTest(msg="test sending a large amount over a single channel in chunks"):
mpp_split.PART_PENALTY = 0.5
splits = mpp_split.suggest_splits(1_000_000_000, channels_with_funds, exclude_single_part_payments=False)
self.assertEqual(2, len(splits[0].config[(0, 0)]))
with self.subTest(msg="test sending a large amount over a single channel in chunks"):
mpp_split.PART_PENALTY = 0.3
splits = mpp_split.suggest_splits(1_000_000_000, channels_with_funds, exclude_single_part_payments=False)
self.assertEqual(3, len(splits[0].config[(0, 0)]))
with self.subTest(msg="exclude all single channel splits"):
mpp_split.PART_PENALTY = 0.3
splits = mpp_split.suggest_splits(1_000_000_000, channels_with_funds, exclude_single_channel_splits=True)
self.assertEqual(1, len(splits[0].config[(0, 0)]))

11
electrum/trampoline.py

@ -2,7 +2,7 @@ import os
import bitstring import bitstring
import random import random
from typing import Mapping from typing import Mapping, DefaultDict
from .logging import get_logger, Logger from .logging import get_logger, Logger
from .lnutil import LnFeatures from .lnutil import LnFeatures
@ -108,7 +108,7 @@ def create_trampoline_route(
my_pubkey: bytes, my_pubkey: bytes,
trampoline_node_id: bytes, # the first trampoline in the path; which we are directly connected to trampoline_node_id: bytes, # the first trampoline in the path; which we are directly connected to
r_tags, r_tags,
trampoline_fee_level: int, trampoline_fee_levels: DefaultDict[bytes, int],
use_two_trampolines: bool) -> LNPaymentRoute: use_two_trampolines: bool) -> LNPaymentRoute:
# figure out whether we can use end-to-end trampoline, or fallback to pay-to-legacy # figure out whether we can use end-to-end trampoline, or fallback to pay-to-legacy
@ -141,7 +141,8 @@ def create_trampoline_route(
if pubkey == TRAMPOLINE_NODES_MAINNET['ACINQ'].pubkey: if pubkey == TRAMPOLINE_NODES_MAINNET['ACINQ'].pubkey:
is_legacy = True is_legacy = True
use_two_trampolines = False use_two_trampolines = False
# fee level. the same fee is used for all trampolines # fee level
trampoline_fee_level = trampoline_fee_levels[trampoline_node_id]
if trampoline_fee_level < len(TRAMPOLINE_FEES): if trampoline_fee_level < len(TRAMPOLINE_FEES):
params = TRAMPOLINE_FEES[trampoline_fee_level] params = TRAMPOLINE_FEES[trampoline_fee_level]
else: else:
@ -270,7 +271,7 @@ def create_trampoline_route_and_onion(
payment_hash, payment_hash,
payment_secret, payment_secret,
local_height:int, local_height:int,
trampoline_fee_level: int, trampoline_fee_levels: DefaultDict[bytes, int],
use_two_trampolines: bool): use_two_trampolines: bool):
# create route for the trampoline_onion # create route for the trampoline_onion
trampoline_route = create_trampoline_route( trampoline_route = create_trampoline_route(
@ -281,7 +282,7 @@ def create_trampoline_route_and_onion(
invoice_features=invoice_features, invoice_features=invoice_features,
trampoline_node_id=node_id, trampoline_node_id=node_id,
r_tags=r_tags, r_tags=r_tags,
trampoline_fee_level=trampoline_fee_level, trampoline_fee_levels=trampoline_fee_levels,
use_two_trampolines=use_two_trampolines) use_two_trampolines=use_two_trampolines)
# compute onion and fees # compute onion and fees
final_cltv = local_height + min_cltv_expiry final_cltv = local_height + min_cltv_expiry

Loading…
Cancel
Save