diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 9c6c87efd..61773aae6 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -1841,7 +1841,7 @@ class LNWallet(LNWorker): ) for sc in split_configurations: is_multichan_mpp = len(sc.config.items()) > 1 - is_mpp = sum(len(x) for x in list(sc.config.values())) > 1 + is_mpp = sc.config.number_parts() > 1 if is_mpp and not paysession.invoice_features.supports(LnFeatures.BASIC_MPP_OPT): continue if not is_mpp and self.config.TEST_FORCE_MPP: diff --git a/electrum/mpp_split.py b/electrum/mpp_split.py index c3eec558f..520d7f0a1 100644 --- a/electrum/mpp_split.py +++ b/electrum/mpp_split.py @@ -15,12 +15,33 @@ CANDIDATES_PER_LEVEL = 20 MAX_PARTS = 5 # maximum number of parts for splitting -# maps a channel (channel_id, node_id) to a list of amounts -SplitConfig = Dict[Tuple[bytes, bytes], List[int]] # maps a channel (channel_id, node_id) to the funds it has available ChannelsFundsInfo = Dict[Tuple[bytes, bytes], int] +class SplitConfig(dict, Dict[Tuple[bytes, bytes], List[int]]): + """maps a channel (channel_id, node_id) to a list of amounts""" + def number_parts(self) -> int: + return sum([len(v) for v in self.values() if sum(v)]) + + def number_nonzero_channels(self) -> int: + return len([v for v in self.values() if sum(v)]) + + def number_nonzero_nodes(self) -> int: + # using a set comprehension + return len({nodeid for (_, nodeid), amounts in self.items() if sum(amounts)}) + + def total_config_amount(self) -> int: + return sum([sum(c) for c in self.values()]) + + def is_any_amount_smaller_than_min_part_size(self) -> bool: + smaller = False + for amounts in self.values(): + if any([amount < MIN_PART_SIZE_MSAT for amount in amounts]): + smaller |= True + return smaller + + class SplitConfigRating(NamedTuple): config: SplitConfig rating: float @@ -41,31 +62,6 @@ def split_amount_normal(total_amount: int, num_parts: int) -> List[int]: return parts -def number_parts(config: SplitConfig) -> int: - return sum([len(v) for v in config.values() if sum(v)]) - - -def number_nonzero_channels(config: SplitConfig) -> int: - return len([v for v in config.values() if sum(v)]) - - -def number_nonzero_nodes(config: SplitConfig) -> int: - # using a set comprehension - return len({nodeid for (_, nodeid), amounts in config.items() if sum(amounts)}) - - -def total_config_amount(config: SplitConfig) -> int: - return sum([sum(c) for c in config.values()]) - - -def is_any_amount_smaller_than_min_part_size(config: SplitConfig) -> bool: - smaller = False - for amounts in config.values(): - if any([amount < MIN_PART_SIZE_MSAT for amount in amounts]): - smaller |= True - return smaller - - def remove_duplicates(configs: List[SplitConfig]) -> List[SplitConfig]: unique_configs = set() for config in configs: @@ -74,16 +70,16 @@ def remove_duplicates(configs: List[SplitConfig]) -> List[SplitConfig]: config_sorted_keys = {k: config_sorted_values[k] for k in sorted(config_sorted_values.keys())} hashable_config = tuple((c, tuple(sorted(config[c]))) for c in config_sorted_keys) unique_configs.add(hashable_config) - unique_configs = [{c[0]: list(c[1]) for c in config} for config in unique_configs] + unique_configs = [SplitConfig({c[0]: list(c[1]) for c in config}) for config in unique_configs] return unique_configs def remove_multiple_nodes(configs: List[SplitConfig]) -> List[SplitConfig]: - return [config for config in configs if number_nonzero_nodes(config) == 1] + return [config for config in configs if config.number_nonzero_nodes() == 1] def remove_single_part_configs(configs: List[SplitConfig]) -> List[SplitConfig]: - return [config for config in configs if number_parts(config) != 1] + return [config for config in configs if config.number_parts() != 1] def remove_single_channel_splits(configs: List[SplitConfig]) -> List[SplitConfig]: @@ -107,7 +103,7 @@ def rate_config( lowest (best). A penalty depending on the total amount sent over a channel counteracts channel exhaustion.""" rating = 0 - total_amount = total_config_amount(config) + total_amount = config.total_config_amount() for channel, amounts in config.items(): funds = channels_with_funds[channel] @@ -143,7 +139,7 @@ def suggest_splits( for _ in range(CANDIDATES_PER_LEVEL): # we want to have configurations with no splitting to many splittings for target_parts in range(1, MAX_PARTS): - config = defaultdict(list) # type: SplitConfig + config = SplitConfig() # randomly split amount into target_parts chunks split_amounts = split_amount_normal(amount_msat, target_parts) @@ -152,6 +148,8 @@ def suggest_splits( random.shuffle(channels_order) # we check each channel and try to put the funds inside, break if we succeed for c in channels_order: + if c not in config: + config[c] = [] if sum(config[c]) + amount <= channels_with_funds[c]: config[c].append(amount) break @@ -167,11 +165,11 @@ def suggest_splits( distribute_amount -= add_amount if distribute_amount == 0: break - if total_config_amount(config) != amount_msat: + if config.total_config_amount() != amount_msat: raise NoPathFound('Cannot distribute payment over channels.') - if target_parts > 1 and is_any_amount_smaller_than_min_part_size(config): + if target_parts > 1 and config.is_any_amount_smaller_than_min_part_size(): continue - assert total_config_amount(config) == amount_msat + assert config.total_config_amount() == amount_msat configs.append(config) configs = remove_duplicates(configs) diff --git a/electrum/tests/test_mpp_split.py b/electrum/tests/test_mpp_split.py index efe127d0a..77c5e08b8 100644 --- a/electrum/tests/test_mpp_split.py +++ b/electrum/tests/test_mpp_split.py @@ -15,10 +15,10 @@ class TestMppSplit(ElectrumTestCase): random.seed(0) # key tuple denotes (channel_id, node_id) self.channels_with_funds = { - (0, 0): 1_000_000_000, - (1, 1): 500_000_000, - (2, 0): 302_000_000, - (3, 2): 101_000_000, + (b"0", b"0"): 1_000_000_000, + (b"1", b"1"): 500_000_000, + (b"2", b"0"): 302_000_000, + (b"3", b"2"): 101_000_000, } def tearDown(self): @@ -30,52 +30,52 @@ class TestMppSplit(ElectrumTestCase): with self.subTest(msg="do a payment with the maximal amount spendable over a single channel"): splits = mpp_split.suggest_splits(1_000_000_000, self.channels_with_funds, exclude_single_part_payments=True) self.assertEqual({ - (0, 0): [671_020_676], - (1, 1): [328_979_324], - (2, 0): [], - (3, 2): []}, + (b"0", b"0"): [671_020_676], + (b"1", b"1"): [328_979_324], + (b"2", b"0"): [], + (b"3", b"2"): []}, splits[0].config ) with self.subTest(msg="payment amount that does not require to be split"): splits = mpp_split.suggest_splits(50_000_000, self.channels_with_funds, exclude_single_part_payments=False) - self.assertEqual({(0, 0): [50_000_000]}, splits[0].config) - self.assertEqual({(1, 1): [50_000_000]}, splits[1].config) - self.assertEqual({(2, 0): [50_000_000]}, splits[2].config) - self.assertEqual({(3, 2): [50_000_000]}, splits[3].config) - self.assertEqual(2, mpp_split.number_parts(splits[4].config)) + self.assertEqual({(b"0", b"0"): [50_000_000]}, splits[0].config) + self.assertEqual({(b"1", b"1"): [50_000_000]}, splits[1].config) + self.assertEqual({(b"2", b"0"): [50_000_000]}, splits[2].config) + self.assertEqual({(b"3", b"2"): [50_000_000]}, splits[3].config) + self.assertEqual(2, splits[4].config.number_parts()) with self.subTest(msg="do a payment with a larger amount than what is supported by a single channel"): splits = mpp_split.suggest_splits(1_100_000_000, self.channels_with_funds, exclude_single_part_payments=False) - self.assertEqual(2, mpp_split.number_parts(splits[0].config)) + self.assertEqual(2, splits[0].config.number_parts()) with self.subTest(msg="do a payment with the maximal amount spendable over all channels"): splits = mpp_split.suggest_splits( sum(self.channels_with_funds.values()), self.channels_with_funds, exclude_single_part_payments=True) self.assertEqual({ - (0, 0): [1_000_000_000], - (1, 1): [500_000_000], - (2, 0): [302_000_000], - (3, 2): [101_000_000]}, + (b"0", b"0"): [1_000_000_000], + (b"1", b"1"): [500_000_000], + (b"2", b"0"): [302_000_000], + (b"3", b"2"): [101_000_000]}, splits[0].config ) with self.subTest(msg="do a payment with the amount supported by all channels"): splits = mpp_split.suggest_splits(101_000_000, self.channels_with_funds, exclude_single_part_payments=False) for split in splits[:3]: - self.assertEqual(1, mpp_split.number_nonzero_channels(split.config)) + self.assertEqual(1, split.config.number_nonzero_channels()) # due to exhaustion of the smallest channel, the algorithm favors # a splitting of the parts into two - self.assertEqual(2, mpp_split.number_parts(splits[4].config)) + self.assertEqual(2, splits[4].config.number_parts()) def test_send_to_single_node(self): splits = mpp_split.suggest_splits(1_000_000_000, self.channels_with_funds, exclude_single_part_payments=False, exclude_multinode_payments=True) for split in splits: - assert mpp_split.number_nonzero_nodes(split.config) == 1 + assert split.config.number_nonzero_nodes() == 1 def test_saturation(self): """Split configurations which spend the full amount in a channel should be avoided.""" - channels_with_funds = {(0, 0): 159_799_733_076, (1, 1): 499_986_152_000} + channels_with_funds = {(b"0", b"0"): 159_799_733_076, (b"1", b"1"): 499_986_152_000} splits = mpp_split.suggest_splits(600_000_000_000, channels_with_funds, exclude_single_part_payments=True) uses_full_amount = False @@ -100,37 +100,37 @@ class TestMppSplit(ElectrumTestCase): with self.subTest(msg="split payments with intermediate part penalty"): mpp_split.PART_PENALTY = 1.0 splits = mpp_split.suggest_splits(1_100_000_000, self.channels_with_funds) - self.assertEqual(2, mpp_split.number_parts(splits[0].config)) + self.assertEqual(2, splits[0].config.number_parts()) with self.subTest(msg="split payments with intermediate part penalty"): mpp_split.PART_PENALTY = 0.3 splits = mpp_split.suggest_splits(1_100_000_000, self.channels_with_funds) - self.assertEqual(4, mpp_split.number_parts(splits[0].config)) + self.assertEqual(4, splits[0].config.number_parts()) with self.subTest(msg="split payments with no part penalty"): mpp_split.PART_PENALTY = 0.0 splits = mpp_split.suggest_splits(1_100_000_000, self.channels_with_funds) - self.assertEqual(5, mpp_split.number_parts(splits[0].config)) + self.assertEqual(5, splits[0].config.number_parts()) def test_suggest_splits_single_channel(self): channels_with_funds = { - (0, 0): 1_000_000_000, + (b"0", b"0"): 1_000_000_000, } with self.subTest(msg="do a payment with the maximal amount spendable on a single channel"): splits = mpp_split.suggest_splits(1_000_000_000, channels_with_funds, exclude_single_part_payments=False) - self.assertEqual({(0, 0): [1_000_000_000]}, splits[0].config) + self.assertEqual({(b"0", b"0"): [1_000_000_000]}, splits[0].config) with self.subTest(msg="test sending an amount greater than what we have available"): self.assertRaises(NoPathFound, mpp_split.suggest_splits, *(1_100_000_000, channels_with_funds)) with self.subTest(msg="test sending a large amount over a single channel in chunks"): mpp_split.PART_PENALTY = 0.5 splits = mpp_split.suggest_splits(1_000_000_000, channels_with_funds, exclude_single_part_payments=False) - self.assertEqual(2, len(splits[0].config[(0, 0)])) + self.assertEqual(2, len(splits[0].config[(b"0", b"0")])) with self.subTest(msg="test sending a large amount over a single channel in chunks"): mpp_split.PART_PENALTY = 0.3 splits = mpp_split.suggest_splits(1_000_000_000, channels_with_funds, exclude_single_part_payments=False) - self.assertEqual(3, len(splits[0].config[(0, 0)])) + self.assertEqual(3, len(splits[0].config[(b"0", b"0")])) with self.subTest(msg="exclude all single channel splits"): mpp_split.PART_PENALTY = 0.3 splits = mpp_split.suggest_splits(1_000_000_000, channels_with_funds, exclude_single_channel_splits=True) - self.assertEqual(1, len(splits[0].config[(0, 0)])) + self.assertEqual(1, len(splits[0].config[(b"0", b"0")]))