diff --git a/electrum/lnrouter.py b/electrum/lnrouter.py index 3400101a2..22d6beed3 100644 --- a/electrum/lnrouter.py +++ b/electrum/lnrouter.py @@ -275,7 +275,7 @@ class LiquidityHint: else: self.cannot_send_backward = amount - def inflight_htlcs(self, is_forward_direction: bool): + def num_inflight_htlcs(self, is_forward_direction: bool) -> int: if is_forward_direction: return self._inflight_htlcs_forward else: @@ -365,18 +365,18 @@ class LiquidityHintMgr: # we only evaluate hints here, so use dict get (to not create many hints with self.get_hint) hint = self._liquidity_hints.get(channel_id) if not hint: - can_send, cannot_send, inflight_htlcs = None, None, 0 + can_send, cannot_send, num_inflight_htlcs = None, None, 0 else: can_send = hint.can_send(node_from < node_to) cannot_send = hint.cannot_send(node_from < node_to) - inflight_htlcs = hint.inflight_htlcs(node_from < node_to) + num_inflight_htlcs = hint.num_inflight_htlcs(node_from < node_to) if cannot_send is not None and amount >= cannot_send: return inf if can_send is not None and amount <= can_send: return 0 success_fee = fee_for_edge_msat(amount, DEFAULT_PENALTY_BASE_MSAT, DEFAULT_PENALTY_PROPORTIONAL_MILLIONTH) - inflight_htlc_fee = inflight_htlcs * success_fee + inflight_htlc_fee = num_inflight_htlcs * success_fee return success_fee + inflight_htlc_fee @with_lock @@ -434,8 +434,8 @@ class LNPathFinder(Logger): self.liquidity_hints.update_cannot_send(r.start_node, r.end_node, r.short_channel_id, amount_msat) break - def update_htlcs_liquidity_hints(self, route: LNPaymentRoute, add_htlcs: bool): - self.logger.info(f"{'Adding' if add_htlcs else 'Removing'} htlcs in liquidity hints.") + def update_inflight_htlcs(self, route: LNPaymentRoute, add_htlcs: bool): + self.logger.info(f"{'Adding' if add_htlcs else 'Removing'} inflight htlcs to graph (liquidity hints).") for r in route: if add_htlcs: self.liquidity_hints.add_htlc(r.start_node, r.end_node, r.short_channel_id) diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 6f9c2c843..fd7c5e03c 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -1203,7 +1203,7 @@ class LNWallet(LNWorker): # as we assume to not know the capacity self.network.path_finder.update_liquidity_hints(htlc_log.route, htlc_log.amount_msat) # remove inflight htlcs from liquidity hints - self.network.path_finder.update_htlcs_liquidity_hints(htlc_log.route, add_htlcs=False) + self.network.path_finder.update_inflight_htlcs(htlc_log.route, add_htlcs=False) return # htlc failed if len(log) >= attempts: @@ -1270,8 +1270,9 @@ class LNWallet(LNWorker): amount_sent, amount_failed = self.sent_buckets[payment_secret] amount_sent += amount_receiver_msat self.sent_buckets[payment_secret] = amount_sent, amount_failed - # add inflight htlcs to liquidity hints - self.network.path_finder.update_htlcs_liquidity_hints(route, add_htlcs=True) + if self.network.path_finder: + # add inflight htlcs to liquidity hints + self.network.path_finder.update_inflight_htlcs(route, add_htlcs=True) util.trigger_callback('htlc_added', chan, htlc, SENT) def handle_error_code_from_failed_htlc( @@ -1281,6 +1282,13 @@ class LNWallet(LNWorker): sender_idx: int, failure_msg: OnionRoutingFailure, amount: int) -> None: + + assert self.channel_db # cannot be in trampoline mode + assert self.network.path_finder + + # remove inflight htlcs from liquidity hints + self.network.path_finder.update_inflight_htlcs(route, add_htlcs=False) + code, data = failure_msg.code, failure_msg.data # TODO can we use lnmsg.OnionWireSerializer here? # TODO update onion_wire.csv @@ -1337,8 +1345,6 @@ class LNWallet(LNWorker): # for errors that do not include a channel update else: self.network.path_finder.liquidity_hints.add_to_blacklist(fallback_channel) - # remove inflight htlcs from liquidity hints - self.network.path_finder.update_htlcs_liquidity_hints(route, add_htlcs=False) def _handle_chanupd_from_failed_htlc(self, payload, *, route, sender_idx) -> Tuple[bool, bool]: blacklist = False diff --git a/electrum/tests/test_lnrouter.py b/electrum/tests/test_lnrouter.py index a292aadda..09b045f6e 100644 --- a/electrum/tests/test_lnrouter.py +++ b/electrum/tests/test_lnrouter.py @@ -28,12 +28,18 @@ def node(character: str) -> bytes: class Test_LNRouter(TestCaseForTestnet): + cdb = None + def setUp(self): super().setUp() self.asyncio_loop, self._stop_loop, self._loop_thread = create_and_start_event_loop() self.config = SimpleConfig({'electrum_path': self.electrum_path}) def tearDown(self): + # if the test called prepare_graph(), channeldb needs to be cleaned up + if self.cdb: + self.cdb.stop() + asyncio.run_coroutine_threadsafe(self.cdb.stopped_event.wait(), self.asyncio_loop).result() self.asyncio_loop.call_soon_threadsafe(self._stop_loop.set_result, 1) self._loop_thread.join(timeout=1) super().tearDown() @@ -151,9 +157,6 @@ class Test_LNRouter(TestCaseForTestnet): self.assertEqual(node('b'), route[0].node_id) self.assertEqual(channel(3), route[0].short_channel_id) - self.cdb.stop() - asyncio.run_coroutine_threadsafe(self.cdb.stopped_event.wait(), self.asyncio_loop).result() - def test_find_path_liquidity_hints(self): self.prepare_graph() amount_to_send = 100000 @@ -249,9 +252,6 @@ class Test_LNRouter(TestCaseForTestnet): self.assertEqual(channel(3), path[0].short_channel_id) self.assertEqual(channel(2), path[1].short_channel_id) - self.cdb.stop() - asyncio.run_coroutine_threadsafe(self.cdb.stopped_event.wait(), self.asyncio_loop).result() - def test_liquidity_hints(self): liquidity_hints = LiquidityHintMgr() node_from = bytes(0)