Browse Source

lnworker: add/fix some type hints, add some comments

follow-up recent refactor
master
SomberNight 2 years ago
parent
commit
4cdd199f5b
No known key found for this signature in database
GPG Key ID: B33B5F232C6271E9
  1. 34
      electrum/lnpeer.py
  2. 6
      electrum/lnutil.py
  3. 33
      electrum/lnworker.py

34
electrum/lnpeer.py

@ -9,7 +9,7 @@ from collections import OrderedDict, defaultdict
import asyncio
import os
import time
from typing import Tuple, Dict, TYPE_CHECKING, Optional, Union, Set, Callable
from typing import Tuple, Dict, TYPE_CHECKING, Optional, Union, Set, Callable, Awaitable
from datetime import datetime
import functools
@ -1693,7 +1693,8 @@ class Peer(Logger):
self, *,
incoming_chan: Channel,
htlc: UpdateAddHtlc,
processed_onion: ProcessedOnionPacket) -> Tuple[bytes, int]:
processed_onion: ProcessedOnionPacket,
) -> str:
# Forward HTLC
# FIXME: there are critical safety checks MISSING here
@ -1744,11 +1745,11 @@ class Peer(Logger):
break
else:
return await self.lnworker.open_channel_just_in_time(
next_peer,
next_amount_msat_htlc,
next_cltv_abs,
htlc.payment_hash,
processed_onion.next_packet)
next_peer=next_peer,
next_amount_msat_htlc=next_amount_msat_htlc,
next_cltv_abs=next_cltv_abs,
payment_hash=htlc.payment_hash,
next_onion=processed_onion.next_packet)
local_height = chain.height()
if next_chan is None:
@ -1815,7 +1816,8 @@ class Peer(Logger):
inc_cltv_abs: int,
outer_onion: ProcessedOnionPacket,
trampoline_onion: ProcessedOnionPacket,
fw_payment_key: str):
fw_payment_key: str,
) -> None:
forwarding_enabled = self.network.config.EXPERIMENTAL_LN_FORWARD_PAYMENTS
forwarding_trampoline_enabled = self.network.config.EXPERIMENTAL_LN_FORWARD_TRAMPOLINE_PAYMENTS
@ -1905,11 +1907,11 @@ class Peer(Logger):
trampoline_onion=next_trampoline_onion,
)
await self.lnworker.open_channel_just_in_time(
next_peer,
amt_to_forward,
cltv_abs,
payment_hash,
next_onion)
next_peer=next_peer,
next_amount_msat_htlc=amt_to_forward,
next_cltv_abs=cltv_abs,
payment_hash=payment_hash,
next_onion=next_onion)
return
try:
@ -1957,8 +1959,8 @@ class Peer(Logger):
htlc: UpdateAddHtlc,
processed_onion: ProcessedOnionPacket,
onion_packet_bytes: bytes,
already_forwarded = False,
) -> Tuple[Optional[bytes], Optional[Callable]]:
already_forwarded: bool = False,
) -> Tuple[Optional[str], Optional[bytes], Optional[Callable[[], Awaitable[Optional[str]]]]]:
"""
Decide what to do with an HTLC: return preimage if it can be fulfilled, forwarding callback if it can be forwarded.
Return (payment_key, preimage, callback) with at most a single element of the last two not None
@ -2637,6 +2639,7 @@ class Peer(Logger):
# HTLC we are supposed to forward, but haven't forwarded yet
if not self.lnworker.enable_htlc_forwarding:
return None, None, None
assert payment_key
if payment_key not in self.lnworker.active_forwardings:
async def wrapped_callback():
forwarding_coro = forwarding_callback()
@ -2649,6 +2652,7 @@ class Peer(Logger):
assert len(self.lnworker.active_forwardings[payment_key]) == 0
self.lnworker.save_forwarding_failure(payment_key, failure_message=e)
# add to list
assert len(self.lnworker.active_forwardings.get(payment_key, [])) == 0
self.lnworker.active_forwardings[payment_key] = []
fut = asyncio.ensure_future(wrapped_callback())
# return payment_key so this branch will not be executed again

6
electrum/lnutil.py

@ -66,13 +66,15 @@ hex_to_bytes = lambda v: v if isinstance(v, bytes) else bytes.fromhex(v) if v is
json_to_keypair = lambda v: v if isinstance(v, OnlyPubkeyKeypair) else Keypair(**v) if len(v)==2 else OnlyPubkeyKeypair(**v)
def serialize_htlc_key(scid:bytes, htlc_id: int):
def serialize_htlc_key(scid: bytes, htlc_id: int) -> str:
return scid.hex() + ':%d'%htlc_id
def deserialize_htlc_key(htlc_key:str):
def deserialize_htlc_key(htlc_key: str) -> Tuple[bytes, int]:
scid, htlc_id = htlc_key.split(':')
return bytes.fromhex(scid), int(htlc_id)
@attr.s
class OnlyPubkeyKeypair(StoredObject):
pubkey = attr.ib(type=bytes, converter=hex_to_bytes)

33
electrum/lnworker.py

@ -691,7 +691,7 @@ class PaySession(Logger):
self._amount_inflight = 0 # what we sent in htlcs (that receiver gets, without fees)
self._nhtlcs_inflight = 0
self.is_active = True
self.is_active = True # is still trying to send new htlcs?
def diagnostic_name(self):
pkey = sha256(self.payment_key)
@ -779,6 +779,7 @@ class PaySession(Logger):
return self.amount_to_pay - self._amount_inflight
def can_be_deleted(self) -> bool:
"""Returns True iff finished sending htlcs AND all pending htlcs have resolved."""
if self.is_active:
return False
# note: no one is consuming from sent_htlcs_q anymore
@ -842,9 +843,9 @@ class LNWallet(LNWorker):
self.set_invoice_status(payment_hash.hex(), PR_INFLIGHT)
# payment forwarding
self.active_forwardings = self.db.get_dict('active_forwardings') # Dict: payment_key -> list of htlc_keys
self.forwarding_failures = self.db.get_dict('forwarding_failures') # Dict: payment_key -> (error_bytes, error_message)
self.downstream_to_upstream_htlc = {} # Dict: htlc_key -> htlc_key (not persisted)
self.active_forwardings = self.db.get_dict('active_forwardings') # type: Dict[str, List[str]] # Dict: payment_key -> list of htlc_keys
self.forwarding_failures = self.db.get_dict('forwarding_failures') # type: Dict[str, Tuple[str, str]] # Dict: payment_key -> (error_bytes, error_message)
self.downstream_to_upstream_htlc = {} # type: Dict[str, str] # Dict: htlc_key -> htlc_key (not persisted)
# payment_hash -> callback:
self.hold_invoice_callbacks = {} # type: Dict[bytes, Callable[[bytes], Awaitable[None]]]
@ -1222,20 +1223,28 @@ class LNWallet(LNWorker):
self.logger.info('REBROADCASTING CLOSING TX')
await self.network.try_broadcasting(force_close_tx, 'force-close')
def get_peer_by_scid_alias(self, scid_alias):
def get_peer_by_scid_alias(self, scid_alias: bytes) -> Optional[Peer]:
for nodeid, peer in self.peers.items():
if scid_alias == self._scid_alias_of_node(nodeid):
return peer
def _scid_alias_of_node(self, nodeid):
def _scid_alias_of_node(self, nodeid: bytes) -> bytes:
# scid alias for just-in-time channels
return sha256(b'Electrum' + nodeid)[0:8]
def get_scid_alias(self):
def get_scid_alias(self) -> bytes:
return self._scid_alias_of_node(self.node_keypair.pubkey)
@log_exceptions
async def open_channel_just_in_time(self, next_peer, next_amount_msat_htlc, next_cltv_abs, payment_hash, next_onion):
async def open_channel_just_in_time( # FIXME xxxxx kwargs
self,
*,
next_peer: Peer,
next_amount_msat_htlc: int,
next_cltv_abs: int,
payment_hash: bytes,
next_onion: OnionPacket,
) -> str:
# if an exception is raised during negotiation, we raise an OnionRoutingFailure.
# this will cancel the incoming HTLC
try:
@ -2396,7 +2405,7 @@ class LNWallet(LNWorker):
if htlc_key in htlcs:
return payment_key
def notify_upstream_peer(self, htlc_key):
def notify_upstream_peer(self, htlc_key: str) -> None:
"""Called when an HTLC we offered on chan gets irrevocably fulfilled or failed.
If we find this was a forwarded HTLC, the upstream peer is notified.
"""
@ -2510,7 +2519,7 @@ class LNWallet(LNWorker):
if fw_key:
paysession_active = False
else:
self.logger.info(f"received unknown htlc_failed, probably from previous session")
self.logger.info(f"received unknown htlc_failed, probably from previous session (phash={payment_hash.hex()})")
key = payment_hash.hex()
self.set_invoice_status(key, PR_UNPAID)
util.trigger_callback('payment_failed', self.wallet, key, '')
@ -2522,7 +2531,7 @@ class LNWallet(LNWorker):
self.save_forwarding_failure(fw_key, error_bytes=error_bytes, failure_message=failure_message)
self.notify_upstream_peer(htlc_key)
else:
self.logger.info(f"waiting for other htlcs to fail")
self.logger.info(f"waiting for other htlcs to fail (phash={payment_hash.hex()})")
def calc_routing_hints_for_invoice(self, amount_msat: Optional[int], channels=None):
"""calculate routing hints (BOLT-11 'r' field)"""
@ -3094,7 +3103,7 @@ class LNWallet(LNWorker):
failure_hex = failure_message.to_bytes().hex() if failure_message else None
self.forwarding_failures[payment_key] = (error_hex, failure_hex)
def get_forwarding_failure(self, payment_key: str):
def get_forwarding_failure(self, payment_key: str) -> Tuple[Optional[bytes], Optional['OnionRoutingFailure']]:
error_hex, failure_hex = self.forwarding_failures.get(payment_key, (None, None))
error_bytes = bytes.fromhex(error_hex) if error_hex else None
failure_message = OnionRoutingFailure.from_bytes(bytes.fromhex(failure_hex)) if failure_hex else None

Loading…
Cancel
Save