From 4cdd199f5b068929ba6c0fe70a3977615d00ca66 Mon Sep 17 00:00:00 2001 From: SomberNight Date: Mon, 20 Nov 2023 11:34:56 +0000 Subject: [PATCH] lnworker: add/fix some type hints, add some comments follow-up recent refactor --- electrum/lnpeer.py | 34 +++++++++++++++++++--------------- electrum/lnutil.py | 6 ++++-- electrum/lnworker.py | 33 +++++++++++++++++++++------------ 3 files changed, 44 insertions(+), 29 deletions(-) diff --git a/electrum/lnpeer.py b/electrum/lnpeer.py index 91cb96a07..8985f5993 100644 --- a/electrum/lnpeer.py +++ b/electrum/lnpeer.py @@ -9,7 +9,7 @@ from collections import OrderedDict, defaultdict import asyncio import os import time -from typing import Tuple, Dict, TYPE_CHECKING, Optional, Union, Set, Callable +from typing import Tuple, Dict, TYPE_CHECKING, Optional, Union, Set, Callable, Awaitable from datetime import datetime import functools @@ -1693,7 +1693,8 @@ class Peer(Logger): self, *, incoming_chan: Channel, htlc: UpdateAddHtlc, - processed_onion: ProcessedOnionPacket) -> Tuple[bytes, int]: + processed_onion: ProcessedOnionPacket, + ) -> str: # Forward HTLC # FIXME: there are critical safety checks MISSING here @@ -1744,11 +1745,11 @@ class Peer(Logger): break else: return await self.lnworker.open_channel_just_in_time( - next_peer, - next_amount_msat_htlc, - next_cltv_abs, - htlc.payment_hash, - processed_onion.next_packet) + next_peer=next_peer, + next_amount_msat_htlc=next_amount_msat_htlc, + next_cltv_abs=next_cltv_abs, + payment_hash=htlc.payment_hash, + next_onion=processed_onion.next_packet) local_height = chain.height() if next_chan is None: @@ -1815,7 +1816,8 @@ class Peer(Logger): inc_cltv_abs: int, outer_onion: ProcessedOnionPacket, trampoline_onion: ProcessedOnionPacket, - fw_payment_key: str): + fw_payment_key: str, + ) -> None: forwarding_enabled = self.network.config.EXPERIMENTAL_LN_FORWARD_PAYMENTS forwarding_trampoline_enabled = self.network.config.EXPERIMENTAL_LN_FORWARD_TRAMPOLINE_PAYMENTS @@ -1905,11 +1907,11 @@ class Peer(Logger): trampoline_onion=next_trampoline_onion, ) await self.lnworker.open_channel_just_in_time( - next_peer, - amt_to_forward, - cltv_abs, - payment_hash, - next_onion) + next_peer=next_peer, + next_amount_msat_htlc=amt_to_forward, + next_cltv_abs=cltv_abs, + payment_hash=payment_hash, + next_onion=next_onion) return try: @@ -1957,8 +1959,8 @@ class Peer(Logger): htlc: UpdateAddHtlc, processed_onion: ProcessedOnionPacket, onion_packet_bytes: bytes, - already_forwarded = False, - ) -> Tuple[Optional[bytes], Optional[Callable]]: + already_forwarded: bool = False, + ) -> Tuple[Optional[str], Optional[bytes], Optional[Callable[[], Awaitable[Optional[str]]]]]: """ Decide what to do with an HTLC: return preimage if it can be fulfilled, forwarding callback if it can be forwarded. Return (payment_key, preimage, callback) with at most a single element of the last two not None @@ -2637,6 +2639,7 @@ class Peer(Logger): # HTLC we are supposed to forward, but haven't forwarded yet if not self.lnworker.enable_htlc_forwarding: return None, None, None + assert payment_key if payment_key not in self.lnworker.active_forwardings: async def wrapped_callback(): forwarding_coro = forwarding_callback() @@ -2649,6 +2652,7 @@ class Peer(Logger): assert len(self.lnworker.active_forwardings[payment_key]) == 0 self.lnworker.save_forwarding_failure(payment_key, failure_message=e) # add to list + assert len(self.lnworker.active_forwardings.get(payment_key, [])) == 0 self.lnworker.active_forwardings[payment_key] = [] fut = asyncio.ensure_future(wrapped_callback()) # return payment_key so this branch will not be executed again diff --git a/electrum/lnutil.py b/electrum/lnutil.py index 6e587089a..d7de3f9ed 100644 --- a/electrum/lnutil.py +++ b/electrum/lnutil.py @@ -66,13 +66,15 @@ hex_to_bytes = lambda v: v if isinstance(v, bytes) else bytes.fromhex(v) if v is json_to_keypair = lambda v: v if isinstance(v, OnlyPubkeyKeypair) else Keypair(**v) if len(v)==2 else OnlyPubkeyKeypair(**v) -def serialize_htlc_key(scid:bytes, htlc_id: int): +def serialize_htlc_key(scid: bytes, htlc_id: int) -> str: return scid.hex() + ':%d'%htlc_id -def deserialize_htlc_key(htlc_key:str): + +def deserialize_htlc_key(htlc_key: str) -> Tuple[bytes, int]: scid, htlc_id = htlc_key.split(':') return bytes.fromhex(scid), int(htlc_id) + @attr.s class OnlyPubkeyKeypair(StoredObject): pubkey = attr.ib(type=bytes, converter=hex_to_bytes) diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 4e08de160..1489cf2a3 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -691,7 +691,7 @@ class PaySession(Logger): self._amount_inflight = 0 # what we sent in htlcs (that receiver gets, without fees) self._nhtlcs_inflight = 0 - self.is_active = True + self.is_active = True # is still trying to send new htlcs? def diagnostic_name(self): pkey = sha256(self.payment_key) @@ -779,6 +779,7 @@ class PaySession(Logger): return self.amount_to_pay - self._amount_inflight def can_be_deleted(self) -> bool: + """Returns True iff finished sending htlcs AND all pending htlcs have resolved.""" if self.is_active: return False # note: no one is consuming from sent_htlcs_q anymore @@ -842,9 +843,9 @@ class LNWallet(LNWorker): self.set_invoice_status(payment_hash.hex(), PR_INFLIGHT) # payment forwarding - self.active_forwardings = self.db.get_dict('active_forwardings') # Dict: payment_key -> list of htlc_keys - self.forwarding_failures = self.db.get_dict('forwarding_failures') # Dict: payment_key -> (error_bytes, error_message) - self.downstream_to_upstream_htlc = {} # Dict: htlc_key -> htlc_key (not persisted) + self.active_forwardings = self.db.get_dict('active_forwardings') # type: Dict[str, List[str]] # Dict: payment_key -> list of htlc_keys + self.forwarding_failures = self.db.get_dict('forwarding_failures') # type: Dict[str, Tuple[str, str]] # Dict: payment_key -> (error_bytes, error_message) + self.downstream_to_upstream_htlc = {} # type: Dict[str, str] # Dict: htlc_key -> htlc_key (not persisted) # payment_hash -> callback: self.hold_invoice_callbacks = {} # type: Dict[bytes, Callable[[bytes], Awaitable[None]]] @@ -1222,20 +1223,28 @@ class LNWallet(LNWorker): self.logger.info('REBROADCASTING CLOSING TX') await self.network.try_broadcasting(force_close_tx, 'force-close') - def get_peer_by_scid_alias(self, scid_alias): + def get_peer_by_scid_alias(self, scid_alias: bytes) -> Optional[Peer]: for nodeid, peer in self.peers.items(): if scid_alias == self._scid_alias_of_node(nodeid): return peer - def _scid_alias_of_node(self, nodeid): + def _scid_alias_of_node(self, nodeid: bytes) -> bytes: # scid alias for just-in-time channels return sha256(b'Electrum' + nodeid)[0:8] - def get_scid_alias(self): + def get_scid_alias(self) -> bytes: return self._scid_alias_of_node(self.node_keypair.pubkey) @log_exceptions - async def open_channel_just_in_time(self, next_peer, next_amount_msat_htlc, next_cltv_abs, payment_hash, next_onion): + async def open_channel_just_in_time( # FIXME xxxxx kwargs + self, + *, + next_peer: Peer, + next_amount_msat_htlc: int, + next_cltv_abs: int, + payment_hash: bytes, + next_onion: OnionPacket, + ) -> str: # if an exception is raised during negotiation, we raise an OnionRoutingFailure. # this will cancel the incoming HTLC try: @@ -2396,7 +2405,7 @@ class LNWallet(LNWorker): if htlc_key in htlcs: return payment_key - def notify_upstream_peer(self, htlc_key): + def notify_upstream_peer(self, htlc_key: str) -> None: """Called when an HTLC we offered on chan gets irrevocably fulfilled or failed. If we find this was a forwarded HTLC, the upstream peer is notified. """ @@ -2510,7 +2519,7 @@ class LNWallet(LNWorker): if fw_key: paysession_active = False else: - self.logger.info(f"received unknown htlc_failed, probably from previous session") + self.logger.info(f"received unknown htlc_failed, probably from previous session (phash={payment_hash.hex()})") key = payment_hash.hex() self.set_invoice_status(key, PR_UNPAID) util.trigger_callback('payment_failed', self.wallet, key, '') @@ -2522,7 +2531,7 @@ class LNWallet(LNWorker): self.save_forwarding_failure(fw_key, error_bytes=error_bytes, failure_message=failure_message) self.notify_upstream_peer(htlc_key) else: - self.logger.info(f"waiting for other htlcs to fail") + self.logger.info(f"waiting for other htlcs to fail (phash={payment_hash.hex()})") def calc_routing_hints_for_invoice(self, amount_msat: Optional[int], channels=None): """calculate routing hints (BOLT-11 'r' field)""" @@ -3094,7 +3103,7 @@ class LNWallet(LNWorker): failure_hex = failure_message.to_bytes().hex() if failure_message else None self.forwarding_failures[payment_key] = (error_hex, failure_hex) - def get_forwarding_failure(self, payment_key: str): + def get_forwarding_failure(self, payment_key: str) -> Tuple[Optional[bytes], Optional['OnionRoutingFailure']]: error_hex, failure_hex = self.forwarding_failures.get(payment_key, (None, None)) error_bytes = bytes.fromhex(error_hex) if error_hex else None failure_message = OnionRoutingFailure.from_bytes(bytes.fromhex(failure_hex)) if failure_hex else None