Browse Source

mpp_split: make SplitConfig a subclass of dict, not just a type-hint

master
SomberNight 2 years ago
parent
commit
39ef1cd4b9
No known key found for this signature in database
GPG Key ID: B33B5F232C6271E9
  1. 2
      electrum/lnworker.py
  2. 68
      electrum/mpp_split.py
  3. 60
      electrum/tests/test_mpp_split.py

2
electrum/lnworker.py

@ -1841,7 +1841,7 @@ class LNWallet(LNWorker):
)
for sc in split_configurations:
is_multichan_mpp = len(sc.config.items()) > 1
is_mpp = sum(len(x) for x in list(sc.config.values())) > 1
is_mpp = sc.config.number_parts() > 1
if is_mpp and not paysession.invoice_features.supports(LnFeatures.BASIC_MPP_OPT):
continue
if not is_mpp and self.config.TEST_FORCE_MPP:

68
electrum/mpp_split.py

@ -15,12 +15,33 @@ CANDIDATES_PER_LEVEL = 20
MAX_PARTS = 5 # maximum number of parts for splitting
# maps a channel (channel_id, node_id) to a list of amounts
SplitConfig = Dict[Tuple[bytes, bytes], List[int]]
# maps a channel (channel_id, node_id) to the funds it has available
ChannelsFundsInfo = Dict[Tuple[bytes, bytes], int]
class SplitConfig(dict, Dict[Tuple[bytes, bytes], List[int]]):
"""maps a channel (channel_id, node_id) to a list of amounts"""
def number_parts(self) -> int:
return sum([len(v) for v in self.values() if sum(v)])
def number_nonzero_channels(self) -> int:
return len([v for v in self.values() if sum(v)])
def number_nonzero_nodes(self) -> int:
# using a set comprehension
return len({nodeid for (_, nodeid), amounts in self.items() if sum(amounts)})
def total_config_amount(self) -> int:
return sum([sum(c) for c in self.values()])
def is_any_amount_smaller_than_min_part_size(self) -> bool:
smaller = False
for amounts in self.values():
if any([amount < MIN_PART_SIZE_MSAT for amount in amounts]):
smaller |= True
return smaller
class SplitConfigRating(NamedTuple):
config: SplitConfig
rating: float
@ -41,31 +62,6 @@ def split_amount_normal(total_amount: int, num_parts: int) -> List[int]:
return parts
def number_parts(config: SplitConfig) -> int:
return sum([len(v) for v in config.values() if sum(v)])
def number_nonzero_channels(config: SplitConfig) -> int:
return len([v for v in config.values() if sum(v)])
def number_nonzero_nodes(config: SplitConfig) -> int:
# using a set comprehension
return len({nodeid for (_, nodeid), amounts in config.items() if sum(amounts)})
def total_config_amount(config: SplitConfig) -> int:
return sum([sum(c) for c in config.values()])
def is_any_amount_smaller_than_min_part_size(config: SplitConfig) -> bool:
smaller = False
for amounts in config.values():
if any([amount < MIN_PART_SIZE_MSAT for amount in amounts]):
smaller |= True
return smaller
def remove_duplicates(configs: List[SplitConfig]) -> List[SplitConfig]:
unique_configs = set()
for config in configs:
@ -74,16 +70,16 @@ def remove_duplicates(configs: List[SplitConfig]) -> List[SplitConfig]:
config_sorted_keys = {k: config_sorted_values[k] for k in sorted(config_sorted_values.keys())}
hashable_config = tuple((c, tuple(sorted(config[c]))) for c in config_sorted_keys)
unique_configs.add(hashable_config)
unique_configs = [{c[0]: list(c[1]) for c in config} for config in unique_configs]
unique_configs = [SplitConfig({c[0]: list(c[1]) for c in config}) for config in unique_configs]
return unique_configs
def remove_multiple_nodes(configs: List[SplitConfig]) -> List[SplitConfig]:
return [config for config in configs if number_nonzero_nodes(config) == 1]
return [config for config in configs if config.number_nonzero_nodes() == 1]
def remove_single_part_configs(configs: List[SplitConfig]) -> List[SplitConfig]:
return [config for config in configs if number_parts(config) != 1]
return [config for config in configs if config.number_parts() != 1]
def remove_single_channel_splits(configs: List[SplitConfig]) -> List[SplitConfig]:
@ -107,7 +103,7 @@ def rate_config(
lowest (best). A penalty depending on the total amount sent over a channel
counteracts channel exhaustion."""
rating = 0
total_amount = total_config_amount(config)
total_amount = config.total_config_amount()
for channel, amounts in config.items():
funds = channels_with_funds[channel]
@ -143,7 +139,7 @@ def suggest_splits(
for _ in range(CANDIDATES_PER_LEVEL):
# we want to have configurations with no splitting to many splittings
for target_parts in range(1, MAX_PARTS):
config = defaultdict(list) # type: SplitConfig
config = SplitConfig()
# randomly split amount into target_parts chunks
split_amounts = split_amount_normal(amount_msat, target_parts)
@ -152,6 +148,8 @@ def suggest_splits(
random.shuffle(channels_order)
# we check each channel and try to put the funds inside, break if we succeed
for c in channels_order:
if c not in config:
config[c] = []
if sum(config[c]) + amount <= channels_with_funds[c]:
config[c].append(amount)
break
@ -167,11 +165,11 @@ def suggest_splits(
distribute_amount -= add_amount
if distribute_amount == 0:
break
if total_config_amount(config) != amount_msat:
if config.total_config_amount() != amount_msat:
raise NoPathFound('Cannot distribute payment over channels.')
if target_parts > 1 and is_any_amount_smaller_than_min_part_size(config):
if target_parts > 1 and config.is_any_amount_smaller_than_min_part_size():
continue
assert total_config_amount(config) == amount_msat
assert config.total_config_amount() == amount_msat
configs.append(config)
configs = remove_duplicates(configs)

60
electrum/tests/test_mpp_split.py

@ -15,10 +15,10 @@ class TestMppSplit(ElectrumTestCase):
random.seed(0)
# key tuple denotes (channel_id, node_id)
self.channels_with_funds = {
(0, 0): 1_000_000_000,
(1, 1): 500_000_000,
(2, 0): 302_000_000,
(3, 2): 101_000_000,
(b"0", b"0"): 1_000_000_000,
(b"1", b"1"): 500_000_000,
(b"2", b"0"): 302_000_000,
(b"3", b"2"): 101_000_000,
}
def tearDown(self):
@ -30,52 +30,52 @@ class TestMppSplit(ElectrumTestCase):
with self.subTest(msg="do a payment with the maximal amount spendable over a single channel"):
splits = mpp_split.suggest_splits(1_000_000_000, self.channels_with_funds, exclude_single_part_payments=True)
self.assertEqual({
(0, 0): [671_020_676],
(1, 1): [328_979_324],
(2, 0): [],
(3, 2): []},
(b"0", b"0"): [671_020_676],
(b"1", b"1"): [328_979_324],
(b"2", b"0"): [],
(b"3", b"2"): []},
splits[0].config
)
with self.subTest(msg="payment amount that does not require to be split"):
splits = mpp_split.suggest_splits(50_000_000, self.channels_with_funds, exclude_single_part_payments=False)
self.assertEqual({(0, 0): [50_000_000]}, splits[0].config)
self.assertEqual({(1, 1): [50_000_000]}, splits[1].config)
self.assertEqual({(2, 0): [50_000_000]}, splits[2].config)
self.assertEqual({(3, 2): [50_000_000]}, splits[3].config)
self.assertEqual(2, mpp_split.number_parts(splits[4].config))
self.assertEqual({(b"0", b"0"): [50_000_000]}, splits[0].config)
self.assertEqual({(b"1", b"1"): [50_000_000]}, splits[1].config)
self.assertEqual({(b"2", b"0"): [50_000_000]}, splits[2].config)
self.assertEqual({(b"3", b"2"): [50_000_000]}, splits[3].config)
self.assertEqual(2, splits[4].config.number_parts())
with self.subTest(msg="do a payment with a larger amount than what is supported by a single channel"):
splits = mpp_split.suggest_splits(1_100_000_000, self.channels_with_funds, exclude_single_part_payments=False)
self.assertEqual(2, mpp_split.number_parts(splits[0].config))
self.assertEqual(2, splits[0].config.number_parts())
with self.subTest(msg="do a payment with the maximal amount spendable over all channels"):
splits = mpp_split.suggest_splits(
sum(self.channels_with_funds.values()), self.channels_with_funds, exclude_single_part_payments=True)
self.assertEqual({
(0, 0): [1_000_000_000],
(1, 1): [500_000_000],
(2, 0): [302_000_000],
(3, 2): [101_000_000]},
(b"0", b"0"): [1_000_000_000],
(b"1", b"1"): [500_000_000],
(b"2", b"0"): [302_000_000],
(b"3", b"2"): [101_000_000]},
splits[0].config
)
with self.subTest(msg="do a payment with the amount supported by all channels"):
splits = mpp_split.suggest_splits(101_000_000, self.channels_with_funds, exclude_single_part_payments=False)
for split in splits[:3]:
self.assertEqual(1, mpp_split.number_nonzero_channels(split.config))
self.assertEqual(1, split.config.number_nonzero_channels())
# due to exhaustion of the smallest channel, the algorithm favors
# a splitting of the parts into two
self.assertEqual(2, mpp_split.number_parts(splits[4].config))
self.assertEqual(2, splits[4].config.number_parts())
def test_send_to_single_node(self):
splits = mpp_split.suggest_splits(1_000_000_000, self.channels_with_funds, exclude_single_part_payments=False, exclude_multinode_payments=True)
for split in splits:
assert mpp_split.number_nonzero_nodes(split.config) == 1
assert split.config.number_nonzero_nodes() == 1
def test_saturation(self):
"""Split configurations which spend the full amount in a channel should be avoided."""
channels_with_funds = {(0, 0): 159_799_733_076, (1, 1): 499_986_152_000}
channels_with_funds = {(b"0", b"0"): 159_799_733_076, (b"1", b"1"): 499_986_152_000}
splits = mpp_split.suggest_splits(600_000_000_000, channels_with_funds, exclude_single_part_payments=True)
uses_full_amount = False
@ -100,37 +100,37 @@ class TestMppSplit(ElectrumTestCase):
with self.subTest(msg="split payments with intermediate part penalty"):
mpp_split.PART_PENALTY = 1.0
splits = mpp_split.suggest_splits(1_100_000_000, self.channels_with_funds)
self.assertEqual(2, mpp_split.number_parts(splits[0].config))
self.assertEqual(2, splits[0].config.number_parts())
with self.subTest(msg="split payments with intermediate part penalty"):
mpp_split.PART_PENALTY = 0.3
splits = mpp_split.suggest_splits(1_100_000_000, self.channels_with_funds)
self.assertEqual(4, mpp_split.number_parts(splits[0].config))
self.assertEqual(4, splits[0].config.number_parts())
with self.subTest(msg="split payments with no part penalty"):
mpp_split.PART_PENALTY = 0.0
splits = mpp_split.suggest_splits(1_100_000_000, self.channels_with_funds)
self.assertEqual(5, mpp_split.number_parts(splits[0].config))
self.assertEqual(5, splits[0].config.number_parts())
def test_suggest_splits_single_channel(self):
channels_with_funds = {
(0, 0): 1_000_000_000,
(b"0", b"0"): 1_000_000_000,
}
with self.subTest(msg="do a payment with the maximal amount spendable on a single channel"):
splits = mpp_split.suggest_splits(1_000_000_000, channels_with_funds, exclude_single_part_payments=False)
self.assertEqual({(0, 0): [1_000_000_000]}, splits[0].config)
self.assertEqual({(b"0", b"0"): [1_000_000_000]}, splits[0].config)
with self.subTest(msg="test sending an amount greater than what we have available"):
self.assertRaises(NoPathFound, mpp_split.suggest_splits, *(1_100_000_000, channels_with_funds))
with self.subTest(msg="test sending a large amount over a single channel in chunks"):
mpp_split.PART_PENALTY = 0.5
splits = mpp_split.suggest_splits(1_000_000_000, channels_with_funds, exclude_single_part_payments=False)
self.assertEqual(2, len(splits[0].config[(0, 0)]))
self.assertEqual(2, len(splits[0].config[(b"0", b"0")]))
with self.subTest(msg="test sending a large amount over a single channel in chunks"):
mpp_split.PART_PENALTY = 0.3
splits = mpp_split.suggest_splits(1_000_000_000, channels_with_funds, exclude_single_part_payments=False)
self.assertEqual(3, len(splits[0].config[(0, 0)]))
self.assertEqual(3, len(splits[0].config[(b"0", b"0")]))
with self.subTest(msg="exclude all single channel splits"):
mpp_split.PART_PENALTY = 0.3
splits = mpp_split.suggest_splits(1_000_000_000, channels_with_funds, exclude_single_channel_splits=True)
self.assertEqual(1, len(splits[0].config[(0, 0)]))
self.assertEqual(1, len(splits[0].config[(b"0", b"0")]))

Loading…
Cancel
Save