diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py index 0cb73b264..22d6beed3 100644 --- a/electrum/lnrouter.py +++ b/electrum/lnrouter.py @@ -183,6 +183,8 @@ class LiquidityHint: self._cannot_send_backward = None self.blacklist_timestamp = 0 self.hint_timestamp = 0 + self._inflight_htlcs_forward = 0 + self._inflight_htlcs_backward = 0 def is_hint_invalid(self) -> bool: now = int(time.time()) @@ -273,10 +275,28 @@ class LiquidityHint: else: self.cannot_send_backward = amount + def num_inflight_htlcs(self, is_forward_direction: bool) -> int: + if is_forward_direction: + return self._inflight_htlcs_forward + else: + return self._inflight_htlcs_backward + + def add_htlc(self, is_forward_direction: bool): + if is_forward_direction: + self._inflight_htlcs_forward += 1 + else: + self._inflight_htlcs_backward += 1 + + def remove_htlc(self, is_forward_direction: bool): + if is_forward_direction: + self._inflight_htlcs_forward = max(0, self._inflight_htlcs_forward - 1) + else: + self._inflight_htlcs_backward = max(0, self._inflight_htlcs_forward - 1) + def __repr__(self): is_blacklisted = False if not self.blacklist_timestamp else int(time.time()) - self.blacklist_timestamp < BLACKLIST_DURATION - return f"forward: can send: {self._can_send_forward} msat, cannot send: {self._cannot_send_forward} msat, \n" \ - f"backward: can send: {self._can_send_backward} msat, cannot send: {self._cannot_send_backward} msat, \n" \ + return f"forward: can send: {self._can_send_forward} msat, cannot send: {self._cannot_send_forward} msat, htlcs: {self._inflight_htlcs_forward}\n" \ + f"backward: can send: {self._can_send_backward} msat, cannot send: {self._cannot_send_backward} msat, htlcs: {self._inflight_htlcs_backward}\n" \ f"blacklisted: {is_blacklisted}" @@ -288,15 +308,13 @@ class LiquidityHintMgr: algorithm that favors channels which can route payments and penalizes channels that cannot. """ - # TODO: incorporate in-flight htlcs - # TODO: use timestamps for can/not_send to make them None after some time? # TODO: hints based on node pairs only (shadow channels, non-strict forwarding)? def __init__(self): self.lock = RLock() self._liquidity_hints: Dict[ShortChannelID, LiquidityHint] = {} @with_lock - def get_hint(self, channel_id: ShortChannelID): + def get_hint(self, channel_id: ShortChannelID) -> LiquidityHint: hint = self._liquidity_hints.get(channel_id) if not hint: hint = LiquidityHint() @@ -313,6 +331,16 @@ class LiquidityHintMgr: hint = self.get_hint(channel_id) hint.update_cannot_send(node_from < node_to, amount) + @with_lock + def add_htlc(self, node_from: bytes, node_to: bytes, channel_id: ShortChannelID): + hint = self.get_hint(channel_id) + hint.add_htlc(node_from < node_to) + + @with_lock + def remove_htlc(self, node_from: bytes, node_to: bytes, channel_id: ShortChannelID): + hint = self.get_hint(channel_id) + hint.remove_htlc(node_from < node_to) + def penalty(self, node_from: bytes, node_to: bytes, channel_id: ShortChannelID, amount: int) -> float: """Gives a penalty when sending from node1 to node2 over channel_id with an amount in units of millisatoshi. @@ -337,16 +365,19 @@ class LiquidityHintMgr: # we only evaluate hints here, so use dict get (to not create many hints with self.get_hint) hint = self._liquidity_hints.get(channel_id) if not hint: - can_send, cannot_send = None, None + can_send, cannot_send, num_inflight_htlcs = None, None, 0 else: can_send = hint.can_send(node_from < node_to) cannot_send = hint.cannot_send(node_from < node_to) + num_inflight_htlcs = hint.num_inflight_htlcs(node_from < node_to) if cannot_send is not None and amount >= cannot_send: return inf if can_send is not None and amount <= can_send: return 0 - return fee_for_edge_msat(amount, DEFAULT_PENALTY_BASE_MSAT, DEFAULT_PENALTY_PROPORTIONAL_MILLIONTH) + success_fee = fee_for_edge_msat(amount, DEFAULT_PENALTY_BASE_MSAT, DEFAULT_PENALTY_PROPORTIONAL_MILLIONTH) + inflight_htlc_fee = num_inflight_htlcs * success_fee + return success_fee + inflight_htlc_fee @with_lock def add_to_blacklist(self, channel_id: ShortChannelID): @@ -403,6 +434,14 @@ class LNPathFinder(Logger): self.liquidity_hints.update_cannot_send(r.start_node, r.end_node, r.short_channel_id, amount_msat) break + def update_inflight_htlcs(self, route: LNPaymentRoute, add_htlcs: bool): + self.logger.info(f"{'Adding' if add_htlcs else 'Removing'} inflight htlcs to graph (liquidity hints).") + for r in route: + if add_htlcs: + self.liquidity_hints.add_htlc(r.start_node, r.end_node, r.short_channel_id) + else: + self.liquidity_hints.remove_htlc(r.start_node, r.end_node, r.short_channel_id) + def _edge_cost( self, *, diff --git a/electrum/lnworker.py b/electrum/lnworker.py index e96bb5da1..fd7c5e03c 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -8,7 +8,7 @@ from decimal import Decimal import random import time from typing import (Optional, Sequence, Tuple, List, Set, Dict, TYPE_CHECKING, - NamedTuple, Union, Mapping, Any, Iterable) + NamedTuple, Union, Mapping, Any, Iterable, AsyncGenerator) import threading import socket import aiohttp @@ -1073,20 +1073,6 @@ class LNWallet(LNWorker): if chan.short_channel_id == short_channel_id: return chan - def create_routes_from_invoice(self, amount_msat: int, decoded_invoice: LnAddr, *, full_path=None): - return self.create_routes_for_payment( - amount_msat=amount_msat, - final_total_msat=amount_msat, - invoice_pubkey=decoded_invoice.pubkey.serialize(), - min_cltv_expiry=decoded_invoice.get_min_final_cltv_expiry(), - r_tags=decoded_invoice.get_routing_info('r'), - invoice_features=decoded_invoice.get_features(), - trampoline_fee_level=0, - use_two_trampolines=False, - payment_hash=decoded_invoice.paymenthash, - payment_secret=decoded_invoice.payment_secret, - full_path=full_path) - @log_exceptions async def pay_invoice( self, invoice: str, *, @@ -1173,8 +1159,7 @@ class LNWallet(LNWorker): # 1. create a set of routes for remaining amount. # note: path-finding runs in a separate thread so that we don't block the asyncio loop # graph updates might occur during the computation - routes = await run_in_thread(partial( - self.create_routes_for_payment, + routes = self.create_routes_for_payment( amount_msat=amount_to_send, final_total_msat=amount_to_pay, invoice_pubkey=node_pubkey, @@ -1186,9 +1171,10 @@ class LNWallet(LNWorker): payment_secret=payment_secret, trampoline_fee_level=trampoline_fee_level, use_two_trampolines=use_two_trampolines, - fwd_trampoline_onion=fwd_trampoline_onion)) + fwd_trampoline_onion=fwd_trampoline_onion + ) # 2. send htlcs - for route, amount_msat, total_msat, amount_receiver_msat, cltv_delta, bucket_payment_secret, trampoline_onion in routes: + async for route, amount_msat, total_msat, amount_receiver_msat, cltv_delta, bucket_payment_secret, trampoline_onion in routes: amount_inflight += amount_receiver_msat if amount_inflight > amount_to_pay: # safety belts raise Exception(f"amount_inflight={amount_inflight} > amount_to_pay={amount_to_pay}") @@ -1210,12 +1196,14 @@ class LNWallet(LNWorker): raise Exception(f"amount_inflight={amount_inflight} < 0") log.append(htlc_log) if htlc_log.success: - # TODO: report every route to liquidity hints for mpp - # even in the case of success, we report channels of the - # route as being able to send the same amount in the future, - # as we assume to not know the capacity if self.network.path_finder: + # TODO: report every route to liquidity hints for mpp + # in the case of success, we report channels of the + # route as being able to send the same amount in the future, + # as we assume to not know the capacity self.network.path_finder.update_liquidity_hints(htlc_log.route, htlc_log.amount_msat) + # remove inflight htlcs from liquidity hints + self.network.path_finder.update_inflight_htlcs(htlc_log.route, add_htlcs=False) return # htlc failed if len(log) >= attempts: @@ -1282,6 +1270,9 @@ class LNWallet(LNWorker): amount_sent, amount_failed = self.sent_buckets[payment_secret] amount_sent += amount_receiver_msat self.sent_buckets[payment_secret] = amount_sent, amount_failed + if self.network.path_finder: + # add inflight htlcs to liquidity hints + self.network.path_finder.update_inflight_htlcs(route, add_htlcs=True) util.trigger_callback('htlc_added', chan, htlc, SENT) def handle_error_code_from_failed_htlc( @@ -1291,6 +1282,13 @@ class LNWallet(LNWorker): sender_idx: int, failure_msg: OnionRoutingFailure, amount: int) -> None: + + assert self.channel_db # cannot be in trampoline mode + assert self.network.path_finder + + # remove inflight htlcs from liquidity hints + self.network.path_finder.update_inflight_htlcs(route, add_htlcs=False) + code, data = failure_msg.code, failure_msg.data # TODO can we use lnmsg.OnionWireSerializer here? # TODO update onion_wire.csv @@ -1432,8 +1430,7 @@ class LNWallet(LNWorker): else: return random.choice(list(hardcoded_trampoline_nodes().values())).pubkey - @profiler - def create_routes_for_payment( + async def create_routes_for_payment( self, *, amount_msat: int, # part of payment amount we want routes for now final_total_msat: int, # total payment amount final receiver will get @@ -1446,7 +1443,7 @@ class LNWallet(LNWorker): trampoline_fee_level: int, use_two_trampolines: bool, fwd_trampoline_onion = None, - full_path: LNPaymentPath = None) -> Sequence[Tuple[LNPaymentRoute, int]]: + full_path: LNPaymentPath = None) -> AsyncGenerator[Tuple[LNPaymentRoute, int], None]: """Creates multiple routes for splitting a payment over the available private channels. @@ -1502,20 +1499,24 @@ class LNWallet(LNWorker): cltv_expiry_delta=0, node_features=trampoline_features) ] - routes = [(route, amount_with_fees, trampoline_total_msat, amount_msat, cltv_delta, trampoline_payment_secret, trampoline_onion)] + yield route, amount_with_fees, trampoline_total_msat, amount_msat, cltv_delta, trampoline_payment_secret, trampoline_onion break else: raise NoPathFound() else: - route = self.create_route_for_payment( - amount_msat=amount_msat, - invoice_pubkey=invoice_pubkey, - min_cltv_expiry=min_cltv_expiry, - r_tags=r_tags, - invoice_features=invoice_features, - channels=active_channels, - full_path=full_path) - routes = [(route, amount_msat, final_total_msat, amount_msat, min_cltv_expiry, payment_secret, fwd_trampoline_onion)] + route = await run_in_thread( + partial( + self.create_route_for_payment, + amount_msat=amount_msat, + invoice_pubkey=invoice_pubkey, + min_cltv_expiry=min_cltv_expiry, + r_tags=r_tags, + invoice_features=invoice_features, + channels=active_channels, + full_path=full_path + ) + ) + yield route, amount_msat, final_total_msat, amount_msat, min_cltv_expiry, payment_secret, fwd_trampoline_onion except NoPathFound: if not invoice_features.supports(LnFeatures.BASIC_MPP_OPT): raise @@ -1532,7 +1533,6 @@ class LNWallet(LNWorker): for s in split_configurations: self.logger.info(f"trying split configuration: {s[0].values()} rating: {s[1]}") - routes = [] try: if not self.channel_db: buckets = defaultdict(list) @@ -1577,7 +1577,7 @@ class LNWallet(LNWorker): node_features=trampoline_features) ] self.logger.info(f'adding route {part_amount_msat} {delta_fee} {margin}') - routes.append((route, part_amount_msat_with_fees, bucket_amount_with_fees, part_amount_msat, bucket_cltv_delta, bucket_payment_secret, trampoline_onion)) + yield route, part_amount_msat_with_fees, bucket_amount_with_fees, part_amount_msat, bucket_cltv_delta, bucket_payment_secret, trampoline_onion if bucket_fees != 0: self.logger.info('not enough margin to pay trampoline fee') raise NoPathFound() @@ -1585,23 +1585,27 @@ class LNWallet(LNWorker): for (chan_id, _), part_amount_msat in s[0].items(): if part_amount_msat: channel = self.channels[chan_id] - route = self.create_route_for_payment( - amount_msat=part_amount_msat, - invoice_pubkey=invoice_pubkey, - min_cltv_expiry=min_cltv_expiry, - r_tags=r_tags, - invoice_features=invoice_features, - channels=[channel], - full_path=None) - routes.append((route, part_amount_msat, final_total_msat, part_amount_msat, min_cltv_expiry, payment_secret, fwd_trampoline_onion)) + route = await run_in_thread( + partial( + self.create_route_for_payment, + amount_msat=part_amount_msat, + invoice_pubkey=invoice_pubkey, + min_cltv_expiry=min_cltv_expiry, + r_tags=r_tags, + invoice_features=invoice_features, + channels=[channel], + full_path=None + ) + ) + yield route, part_amount_msat, final_total_msat, part_amount_msat, min_cltv_expiry, payment_secret, fwd_trampoline_onion self.logger.info(f"found acceptable split configuration: {list(s[0].values())} rating: {s[1]}") break except NoPathFound: continue else: raise NoPathFound() - return routes + @profiler def create_route_for_payment( self, *, amount_msat: int, @@ -1610,7 +1614,7 @@ class LNWallet(LNWorker): r_tags, invoice_features: int, channels: List[Channel], - full_path: Optional[LNPaymentPath]) -> Tuple[LNPaymentRoute, int]: + full_path: Optional[LNPaymentPath]) -> LNPaymentRoute: scid_to_my_channels = { chan.short_channel_id: chan for chan in channels diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index 20cd37bb2..81a17b7cc 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -192,6 +192,20 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]): self.channel_db.stop() await self.channel_db.stopped_event.wait() + async def create_routes_from_invoice(self, amount_msat: int, decoded_invoice: LnAddr, *, full_path=None): + return [r async for r in self.create_routes_for_payment( + amount_msat=amount_msat, + final_total_msat=amount_msat, + invoice_pubkey=decoded_invoice.pubkey.serialize(), + min_cltv_expiry=decoded_invoice.get_min_final_cltv_expiry(), + r_tags=decoded_invoice.get_routing_info('r'), + invoice_features=decoded_invoice.get_features(), + trampoline_fee_level=0, + use_two_trampolines=False, + payment_hash=decoded_invoice.paymenthash, + payment_secret=decoded_invoice.payment_secret, + full_path=full_path)] + get_payments = LNWallet.get_payments get_payment_info = LNWallet.get_payment_info save_payment_info = LNWallet.save_payment_info @@ -206,7 +220,6 @@ class MockLNWallet(Logger, NetworkRetryManager[LNPeerAddr]): get_preimage = LNWallet.get_preimage create_route_for_payment = LNWallet.create_route_for_payment create_routes_for_payment = LNWallet.create_routes_for_payment - create_routes_from_invoice = LNWallet.create_routes_from_invoice _check_invoice = staticmethod(LNWallet._check_invoice) pay_to_route = LNWallet.pay_to_route pay_to_node = LNWallet.pay_to_node @@ -598,7 +611,7 @@ class TestPeer(TestCaseForTestnet): q2 = w2.sent_htlcs[lnaddr1.paymenthash] # alice sends htlc BUT NOT COMMITMENT_SIGNED p1.maybe_send_commitment = lambda x: None - route1 = w1.create_routes_from_invoice(lnaddr2.get_amount_msat(), decoded_invoice=lnaddr2)[0][0] + route1 = (await w1.create_routes_from_invoice(lnaddr2.get_amount_msat(), decoded_invoice=lnaddr2))[0][0] amount_msat = lnaddr2.get_amount_msat() await w1.pay_to_route( route=route1, @@ -612,7 +625,7 @@ class TestPeer(TestCaseForTestnet): p1.maybe_send_commitment = _maybe_send_commitment1 # bob sends htlc BUT NOT COMMITMENT_SIGNED p2.maybe_send_commitment = lambda x: None - route2 = w2.create_routes_from_invoice(lnaddr1.get_amount_msat(), decoded_invoice=lnaddr1)[0][0] + route2 = (await w2.create_routes_from_invoice(lnaddr1.get_amount_msat(), decoded_invoice=lnaddr1))[0][0] amount_msat = lnaddr1.get_amount_msat() await w2.pay_to_route( route=route2, @@ -982,14 +995,14 @@ class TestPeer(TestCaseForTestnet): await asyncio.wait_for(p1.initialized, 1) await asyncio.wait_for(p2.initialized, 1) # alice sends htlc - route, amount_msat = w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr)[0][0:2] - htlc = p1.pay(route=route, - chan=alice_channel, - amount_msat=lnaddr.get_amount_msat(), - total_msat=lnaddr.get_amount_msat(), - payment_hash=lnaddr.paymenthash, - min_final_cltv_expiry=lnaddr.get_min_final_cltv_expiry(), - payment_secret=lnaddr.payment_secret) + route, amount_msat = (await w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr))[0][0:2] + p1.pay(route=route, + chan=alice_channel, + amount_msat=lnaddr.get_amount_msat(), + total_msat=lnaddr.get_amount_msat(), + payment_hash=lnaddr.paymenthash, + min_final_cltv_expiry=lnaddr.get_min_final_cltv_expiry(), + payment_secret=lnaddr.payment_secret) # alice closes await p1.close_channel(alice_channel.channel_id) gath.cancel() @@ -1078,7 +1091,7 @@ class TestPeer(TestCaseForTestnet): lnaddr, pay_req = run(self.prepare_invoice(w2)) lnaddr = w1._check_invoice(pay_req) - route, amount_msat = w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr)[0][0:2] + route, amount_msat = run(w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr))[0][0:2] assert amount_msat == lnaddr.get_amount_msat() run(w1.force_close_channel(alice_channel.channel_id)) @@ -1086,7 +1099,7 @@ class TestPeer(TestCaseForTestnet): assert q1.qsize() == 1 with self.assertRaises(NoPathFound) as e: - w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr) + run(w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr)) peer = w1.peers[route[0].node_id] # AssertionError is ok since we shouldn't use old routes, and the diff --git a/electrum/tests/test_lnrouter.py b/electrum/tests/test_lnrouter.py index 5382fe502..09b045f6e 100644 --- a/electrum/tests/test_lnrouter.py +++ b/electrum/tests/test_lnrouter.py @@ -28,12 +28,18 @@ def node(character: str) -> bytes: class Test_LNRouter(TestCaseForTestnet): + cdb = None + def setUp(self): super().setUp() self.asyncio_loop, self._stop_loop, self._loop_thread = create_and_start_event_loop() self.config = SimpleConfig({'electrum_path': self.electrum_path}) def tearDown(self): + # if the test called prepare_graph(), channeldb needs to be cleaned up + if self.cdb: + self.cdb.stop() + asyncio.run_coroutine_threadsafe(self.cdb.stopped_event.wait(), self.asyncio_loop).result() self.asyncio_loop.call_soon_threadsafe(self._stop_loop.set_result, 1) self._loop_thread.join(timeout=1) super().tearDown() @@ -151,10 +157,7 @@ class Test_LNRouter(TestCaseForTestnet): self.assertEqual(node('b'), route[0].node_id) self.assertEqual(channel(3), route[0].short_channel_id) - self.cdb.stop() - asyncio.run_coroutine_threadsafe(self.cdb.stopped_event.wait(), self.asyncio_loop).result() - - def test_find_path_liquidity_hints_failure(self): + def test_find_path_liquidity_hints(self): self.prepare_graph() amount_to_send = 100000 @@ -197,7 +200,7 @@ class Test_LNRouter(TestCaseForTestnet): assume success over channel 4, D -> C A -3-> B |-2-> E A -6-> D |-5-> E - A -6-> D -4-> C -7-> E <= chosen path + A -6-> D -4-> C -7-> E <= smaller penalty: chosen path A -3-> B -1-> C -7-> E A -6-> D -4-> C -1-> B |-2-> E A -3-> B -1-> C -4-> D |-5-> E @@ -211,8 +214,43 @@ class Test_LNRouter(TestCaseForTestnet): self.assertEqual(channel(4), path[1].short_channel_id) self.assertEqual(channel(7), path[2].short_channel_id) - self.cdb.stop() - asyncio.run_coroutine_threadsafe(self.cdb.stopped_event.wait(), self.asyncio_loop).result() + def test_find_path_liquidity_hints_inflight_htlcs(self): + self.prepare_graph() + amount_to_send = 100000 + + """ + add inflight htlc to channel 2, B -> E + A -3-> B -2(1)-> E + A -6-> D -5-> E <= chosen path + A -6-> D -4-> C -7-> E + A -3-> B -1-> C -7-> E + A -6-> D -4-> C -1-> B -2-> E + A -3-> B -1-> C -4-> D -5-> E + """ + self.path_finder.liquidity_hints.add_htlc(node('b'), node('e'), channel(2)) + path = self.path_finder.find_path_for_payment( + nodeA=node('a'), + nodeB=node('e'), + invoice_amount_msat=amount_to_send) + self.assertEqual(channel(6), path[0].short_channel_id) + self.assertEqual(channel(5), path[1].short_channel_id) + + """ + remove inflight htlc from channel 2, B -> E + A -3-> B -2(0)-> E <= chosen path + A -6-> D -5-> E + A -6-> D -4-> C -7-> E + A -3-> B -1-> C -7-> E + A -6-> D -4-> C -1-> B -2-> E + A -3-> B -1-> C -4-> D -5-> E + """ + self.path_finder.liquidity_hints.remove_htlc(node('b'), node('e'), channel(2)) + path = self.path_finder.find_path_for_payment( + nodeA=node('a'), + nodeB=node('e'), + invoice_amount_msat=amount_to_send) + self.assertEqual(channel(3), path[0].short_channel_id) + self.assertEqual(channel(2), path[1].short_channel_id) def test_liquidity_hints(self): liquidity_hints = LiquidityHintMgr() @@ -251,6 +289,12 @@ class Test_LNRouter(TestCaseForTestnet): self.assertEqual(3_000_000, hint.can_send(node_from < node_to)) self.assertEqual(None, hint.cannot_send(node_from < node_to)) + # test inflight htlc + liquidity_hints.reset_liquidity_hints() + liquidity_hints.add_htlc(node_from, node_to, channel_id) + liquidity_hints.get_hint(channel_id) + # we have got 600 (attempt) + 600 (inflight) penalty + self.assertEqual(1200, liquidity_hints.penalty(node_from, node_to, channel_id, 1_000_000)) @needs_test_with_all_chacha20_implementations def test_new_onion_packet_legacy(self):