From 2af59e32b2c03961aa57d6d6872f6099ed8890f5 Mon Sep 17 00:00:00 2001 From: ThomasV Date: Mon, 19 Sep 2022 17:43:13 +0200 Subject: [PATCH] lnworker: define use_trampoline() for code clarity --- electrum/gui/qt/channels_list.py | 4 ++-- electrum/lnchannel.py | 2 +- electrum/lnpeer.py | 8 ++++---- electrum/lnworker.py | 33 +++++++++++++++++--------------- electrum/tests/test_lnpeer.py | 3 +++ 5 files changed, 28 insertions(+), 22 deletions(-) diff --git a/electrum/gui/qt/channels_list.py b/electrum/gui/qt/channels_list.py index 73ed65f36..2bf94b042 100644 --- a/electrum/gui/qt/channels_list.py +++ b/electrum/gui/qt/channels_list.py @@ -183,7 +183,7 @@ class ChannelsList(MyTreeView): WaitingDialog(self, 'please wait..', task, self.on_request_sent, self.on_failure) def freeze_channel_for_sending(self, chan, b): - if self.lnworker.channel_db or self.lnworker.is_trampoline_peer(chan.node_id): + if not self.lnworker.uses_trampoline() or self.lnworker.is_trampoline_peer(chan.node_id): chan.set_frozen_for_sending(b) else: msg = messages.MSG_NON_TRAMPOLINE_CHANNEL_FROZEN_WITHOUT_GOSSIP @@ -198,7 +198,7 @@ class ChannelsList(MyTreeView): channel_id2 = idx2.sibling(idx2.row(), self.Columns.NODE_ALIAS).data(ROLE_CHANNEL_ID) chan1 = self.lnworker.channels.get(channel_id1) chan2 = self.lnworker.channels.get(channel_id2) - if chan1 and chan2 and (self.lnworker.channel_db or chan1.node_id != chan2.node_id): + if chan1 and chan2 and (not self.lnworker.uses_trampoline() or chan1.node_id != chan2.node_id): return chan1, chan2 return None, None diff --git a/electrum/lnchannel.py b/electrum/lnchannel.py index 4e084cf3a..322b7d2c1 100644 --- a/electrum/lnchannel.py +++ b/electrum/lnchannel.py @@ -824,7 +824,7 @@ class Channel(AbstractChannel): return self.can_send_ctx_updates() and self.is_open() def is_frozen_for_sending(self) -> bool: - if self.lnworker and self.lnworker.channel_db is None and not self.lnworker.is_trampoline_peer(self.node_id): + if self.lnworker and self.lnworker.uses_trampoline() and not self.lnworker.is_trampoline_peer(self.node_id): return True return self.storage.get('frozen_for_sending', False) diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index a5ff93716..67676a3c7 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -366,16 +366,16 @@ class Peer(Logger): self.maybe_set_initialized() def on_node_announcement(self, payload): - if self.lnworker.channel_db: + if not self.lnworker.uses_trampoline(): self.gossip_queue.put_nowait(('node_announcement', payload)) def on_channel_announcement(self, payload): - if self.lnworker.channel_db: + if not self.lnworker.uses_trampoline(): self.gossip_queue.put_nowait(('channel_announcement', payload)) def on_channel_update(self, payload): self.maybe_save_remote_update(payload) - if self.lnworker.channel_db: + if not self.lnworker.uses_trampoline(): self.gossip_queue.put_nowait(('channel_update', payload)) def maybe_save_remote_update(self, payload): @@ -702,7 +702,7 @@ class Peer(Logger): # will raise if init fails await asyncio.wait_for(self.initialized, LN_P2P_NETWORK_TIMEOUT) # trampoline is not yet in features - if not self.lnworker.channel_db and not self.lnworker.is_trampoline_peer(self.pubkey): + if self.lnworker.uses_trampoline() and not self.lnworker.is_trampoline_peer(self.pubkey): raise Exception('Not a trampoline node: ' + str(self.their_features)) feerate = self.lnworker.current_feerate_per_kw() diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 14cda39ba..56204f0a7 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -223,6 +223,9 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): def channel_db(self): return self.network.channel_db if self.network else None + def uses_trampoline(self): + return not bool(self.channel_db) + @property def peers(self) -> Mapping[bytes, Peer]: """Returns a read-only copy of peers.""" @@ -235,7 +238,7 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): def get_node_alias(self, node_id: bytes) -> Optional[str]: """Returns the alias of the node, or None if unknown.""" node_alias = None - if self.channel_db: + if not self.uses_trampoline(): node_info = self.channel_db.get_node_info_for_node_id(node_id) if node_info: node_alias = node_info.alias @@ -371,8 +374,8 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): peer_addr = peer.transport.peer_addr # reset connection attempt count self._on_connection_successfully_established(peer_addr) - # add into channel db - if self.channel_db: + if not self.uses_trampoline(): + # add into channel db self.channel_db.add_recent_peer(peer_addr) # save network address into channels we might have with peer for chan in peer.channels.values(): @@ -492,7 +495,7 @@ class LNWorker(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): if rest is not None: host, port = split_host_port(rest) else: - if not self.channel_db: + if self.uses_trampoline(): addr = trampolines_by_id().get(node_id) if not addr: raise ConnStringFormatError(_('Address unknown for node:') + ' ' + bh2u(node_id)) @@ -1299,7 +1302,7 @@ class LNWallet(LNWorker): if code == OnionFailureCode.MPP_TIMEOUT: raise PaymentFailure(failure_msg.code_name()) # trampoline - if not self.channel_db: + if self.uses_trampoline(): def maybe_raise_trampoline_fee(htlc_log): if htlc_log.trampoline_fee_level == self.trampoline_fee_level: self.trampoline_fee_level += 1 @@ -1370,7 +1373,7 @@ class LNWallet(LNWorker): key = (payment_hash, short_channel_id, htlc.htlc_id) self.sent_htlcs_info[key] = route, payment_secret, amount_msat, total_msat, amount_receiver_msat, trampoline_fee_level, trampoline_route # if we sent MPP to a trampoline, add item to sent_buckets - if not self.channel_db and amount_msat != total_msat: + if self.uses_trampoline() and amount_msat != total_msat: if payment_secret not in self.sent_buckets: self.sent_buckets[payment_secret] = (0, 0) amount_sent, amount_failed = self.sent_buckets[payment_secret] @@ -1531,7 +1534,7 @@ class LNWallet(LNWorker): return False def suggest_peer(self) -> Optional[bytes]: - if self.channel_db: + if not self.uses_trampoline(): return self.lnrater.suggest_peer() else: return random.choice(list(hardcoded_trampoline_nodes().values())).pubkey @@ -1572,7 +1575,7 @@ class LNWallet(LNWorker): try: self.logger.info("trying single-part payment") # try to send over a single channel - if not self.channel_db: + if self.uses_trampoline(): for chan in my_active_channels: if not self.is_trampoline_peer(chan.node_id): continue @@ -1640,7 +1643,7 @@ class LNWallet(LNWorker): for chan in my_active_channels} self.logger.info(f"channels_with_funds: {channels_with_funds}") - if not self.channel_db: + if self.uses_trampoline(): # in the case of a legacy payment, we don't allow splitting via different # trampoline nodes, because of https://github.com/ACINQ/eclair/issues/2127 use_single_node, _ = is_legacy_relay(invoice_features, r_tags) @@ -2043,7 +2046,7 @@ class LNWallet(LNWorker): self.logger.info(f"htlc_failed {failure_message}") # check sent_buckets if we use trampoline - if not self.channel_db and payment_secret in self.sent_buckets: + if self.uses_trampoline() and payment_secret in self.sent_buckets: amount_sent, amount_failed = self.sent_buckets[payment_secret] amount_failed += amount_receiver_msat self.sent_buckets[payment_secret] = amount_sent, amount_failed @@ -2163,7 +2166,7 @@ class LNWallet(LNWorker): can_send_dict = defaultdict(int) with self.lock: for c in self.get_channels_for_sending(): - if self.channel_db: + if not self.uses_trampoline(): can_send_dict[0] += send_capacity(c) else: can_send_dict[c.node_id] += send_capacity(c) @@ -2271,7 +2274,7 @@ class LNWallet(LNWorker): continue if chan1 == chan2: continue - if not self.channel_db and chan1.node_id == chan2.node_id: + if self.uses_trampoline() and chan1.node_id == chan2.node_id: continue if direction == SENT: if chan1.can_pay(delta*1000): @@ -2326,7 +2329,7 @@ class LNWallet(LNWorker): async def rebalance_channels(self, chan1, chan2, amount_msat): if chan1 == chan2: raise Exception('Rebalance requires two different channels') - if not self.channel_db and chan1.node_id == chan2.node_id: + if self.uses_trampoline() and chan1.node_id == chan2.node_id: raise Exception('Rebalance requires channels from different trampolines') lnaddr, invoice = self.create_invoice( amount_msat=amount_msat, @@ -2408,7 +2411,7 @@ class LNWallet(LNWorker): async def reestablish_peer_for_given_channel(self, chan: Channel) -> None: now = time.time() peer_addresses = [] - if not self.channel_db: + if self.uses_trampoline(): addr = trampolines_by_id().get(chan.node_id) if addr: peer_addresses.append(addr) @@ -2590,7 +2593,7 @@ class LNWallet(LNWorker): if success: return # try with gossip db - if not self.channel_db: + if self.uses_trampoline(): raise Exception(_('Please enable gossip')) node_id = self.network.channel_db.get_node_by_prefix(cb.node_id_prefix) addresses_from_gossip = self.network.channel_db.get_node_addresses(node_id) diff --git a/electrum/tests/test_lnpeer.py b/electrum/tests/test_lnpeer.py index dbcfb6a38..c6b7befc3 100644 --- a/electrum/tests/test_lnpeer.py +++ b/electrum/tests/test_lnpeer.py @@ -185,6 +185,9 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]): def channel_db(self): return self.network.channel_db if self.network else None + def uses_trampoline(self): + return not bool(self.channel_db) + @property def channels(self): return self._channels