Browse Source

lnworker: add/fix some type hints, add some comments

follow-up recent refactor
master
SomberNight 2 years ago
parent
commit
4cdd199f5b
No known key found for this signature in database
GPG Key ID: B33B5F232C6271E9
  1. 34
      electrum/lnpeer.py
  2. 6
      electrum/lnutil.py
  3. 33
      electrum/lnworker.py

34
electrum/lnpeer.py

@ -9,7 +9,7 @@ from collections import OrderedDict, defaultdict
import asyncio import asyncio
import os import os
import time import time
from typing import Tuple, Dict, TYPE_CHECKING, Optional, Union, Set, Callable from typing import Tuple, Dict, TYPE_CHECKING, Optional, Union, Set, Callable, Awaitable
from datetime import datetime from datetime import datetime
import functools import functools
@ -1693,7 +1693,8 @@ class Peer(Logger):
self, *, self, *,
incoming_chan: Channel, incoming_chan: Channel,
htlc: UpdateAddHtlc, htlc: UpdateAddHtlc,
processed_onion: ProcessedOnionPacket) -> Tuple[bytes, int]: processed_onion: ProcessedOnionPacket,
) -> str:
# Forward HTLC # Forward HTLC
# FIXME: there are critical safety checks MISSING here # FIXME: there are critical safety checks MISSING here
@ -1744,11 +1745,11 @@ class Peer(Logger):
break break
else: else:
return await self.lnworker.open_channel_just_in_time( return await self.lnworker.open_channel_just_in_time(
next_peer, next_peer=next_peer,
next_amount_msat_htlc, next_amount_msat_htlc=next_amount_msat_htlc,
next_cltv_abs, next_cltv_abs=next_cltv_abs,
htlc.payment_hash, payment_hash=htlc.payment_hash,
processed_onion.next_packet) next_onion=processed_onion.next_packet)
local_height = chain.height() local_height = chain.height()
if next_chan is None: if next_chan is None:
@ -1815,7 +1816,8 @@ class Peer(Logger):
inc_cltv_abs: int, inc_cltv_abs: int,
outer_onion: ProcessedOnionPacket, outer_onion: ProcessedOnionPacket,
trampoline_onion: ProcessedOnionPacket, trampoline_onion: ProcessedOnionPacket,
fw_payment_key: str): fw_payment_key: str,
) -> None:
forwarding_enabled = self.network.config.EXPERIMENTAL_LN_FORWARD_PAYMENTS forwarding_enabled = self.network.config.EXPERIMENTAL_LN_FORWARD_PAYMENTS
forwarding_trampoline_enabled = self.network.config.EXPERIMENTAL_LN_FORWARD_TRAMPOLINE_PAYMENTS forwarding_trampoline_enabled = self.network.config.EXPERIMENTAL_LN_FORWARD_TRAMPOLINE_PAYMENTS
@ -1905,11 +1907,11 @@ class Peer(Logger):
trampoline_onion=next_trampoline_onion, trampoline_onion=next_trampoline_onion,
) )
await self.lnworker.open_channel_just_in_time( await self.lnworker.open_channel_just_in_time(
next_peer, next_peer=next_peer,
amt_to_forward, next_amount_msat_htlc=amt_to_forward,
cltv_abs, next_cltv_abs=cltv_abs,
payment_hash, payment_hash=payment_hash,
next_onion) next_onion=next_onion)
return return
try: try:
@ -1957,8 +1959,8 @@ class Peer(Logger):
htlc: UpdateAddHtlc, htlc: UpdateAddHtlc,
processed_onion: ProcessedOnionPacket, processed_onion: ProcessedOnionPacket,
onion_packet_bytes: bytes, onion_packet_bytes: bytes,
already_forwarded = False, already_forwarded: bool = False,
) -> Tuple[Optional[bytes], Optional[Callable]]: ) -> Tuple[Optional[str], Optional[bytes], Optional[Callable[[], Awaitable[Optional[str]]]]]:
""" """
Decide what to do with an HTLC: return preimage if it can be fulfilled, forwarding callback if it can be forwarded. Decide what to do with an HTLC: return preimage if it can be fulfilled, forwarding callback if it can be forwarded.
Return (payment_key, preimage, callback) with at most a single element of the last two not None Return (payment_key, preimage, callback) with at most a single element of the last two not None
@ -2637,6 +2639,7 @@ class Peer(Logger):
# HTLC we are supposed to forward, but haven't forwarded yet # HTLC we are supposed to forward, but haven't forwarded yet
if not self.lnworker.enable_htlc_forwarding: if not self.lnworker.enable_htlc_forwarding:
return None, None, None return None, None, None
assert payment_key
if payment_key not in self.lnworker.active_forwardings: if payment_key not in self.lnworker.active_forwardings:
async def wrapped_callback(): async def wrapped_callback():
forwarding_coro = forwarding_callback() forwarding_coro = forwarding_callback()
@ -2649,6 +2652,7 @@ class Peer(Logger):
assert len(self.lnworker.active_forwardings[payment_key]) == 0 assert len(self.lnworker.active_forwardings[payment_key]) == 0
self.lnworker.save_forwarding_failure(payment_key, failure_message=e) self.lnworker.save_forwarding_failure(payment_key, failure_message=e)
# add to list # add to list
assert len(self.lnworker.active_forwardings.get(payment_key, [])) == 0
self.lnworker.active_forwardings[payment_key] = [] self.lnworker.active_forwardings[payment_key] = []
fut = asyncio.ensure_future(wrapped_callback()) fut = asyncio.ensure_future(wrapped_callback())
# return payment_key so this branch will not be executed again # return payment_key so this branch will not be executed again

6
electrum/lnutil.py

@ -66,13 +66,15 @@ hex_to_bytes = lambda v: v if isinstance(v, bytes) else bytes.fromhex(v) if v is
json_to_keypair = lambda v: v if isinstance(v, OnlyPubkeyKeypair) else Keypair(**v) if len(v)==2 else OnlyPubkeyKeypair(**v) json_to_keypair = lambda v: v if isinstance(v, OnlyPubkeyKeypair) else Keypair(**v) if len(v)==2 else OnlyPubkeyKeypair(**v)
def serialize_htlc_key(scid:bytes, htlc_id: int): def serialize_htlc_key(scid: bytes, htlc_id: int) -> str:
return scid.hex() + ':%d'%htlc_id return scid.hex() + ':%d'%htlc_id
def deserialize_htlc_key(htlc_key:str):
def deserialize_htlc_key(htlc_key: str) -> Tuple[bytes, int]:
scid, htlc_id = htlc_key.split(':') scid, htlc_id = htlc_key.split(':')
return bytes.fromhex(scid), int(htlc_id) return bytes.fromhex(scid), int(htlc_id)
@attr.s @attr.s
class OnlyPubkeyKeypair(StoredObject): class OnlyPubkeyKeypair(StoredObject):
pubkey = attr.ib(type=bytes, converter=hex_to_bytes) pubkey = attr.ib(type=bytes, converter=hex_to_bytes)

33
electrum/lnworker.py

@ -691,7 +691,7 @@ class PaySession(Logger):
self._amount_inflight = 0 # what we sent in htlcs (that receiver gets, without fees) self._amount_inflight = 0 # what we sent in htlcs (that receiver gets, without fees)
self._nhtlcs_inflight = 0 self._nhtlcs_inflight = 0
self.is_active = True self.is_active = True # is still trying to send new htlcs?
def diagnostic_name(self): def diagnostic_name(self):
pkey = sha256(self.payment_key) pkey = sha256(self.payment_key)
@ -779,6 +779,7 @@ class PaySession(Logger):
return self.amount_to_pay - self._amount_inflight return self.amount_to_pay - self._amount_inflight
def can_be_deleted(self) -> bool: def can_be_deleted(self) -> bool:
"""Returns True iff finished sending htlcs AND all pending htlcs have resolved."""
if self.is_active: if self.is_active:
return False return False
# note: no one is consuming from sent_htlcs_q anymore # note: no one is consuming from sent_htlcs_q anymore
@ -842,9 +843,9 @@ class LNWallet(LNWorker):
self.set_invoice_status(payment_hash.hex(), PR_INFLIGHT) self.set_invoice_status(payment_hash.hex(), PR_INFLIGHT)
# payment forwarding # payment forwarding
self.active_forwardings = self.db.get_dict('active_forwardings') # Dict: payment_key -> list of htlc_keys self.active_forwardings = self.db.get_dict('active_forwardings') # type: Dict[str, List[str]] # Dict: payment_key -> list of htlc_keys
self.forwarding_failures = self.db.get_dict('forwarding_failures') # Dict: payment_key -> (error_bytes, error_message) self.forwarding_failures = self.db.get_dict('forwarding_failures') # type: Dict[str, Tuple[str, str]] # Dict: payment_key -> (error_bytes, error_message)
self.downstream_to_upstream_htlc = {} # Dict: htlc_key -> htlc_key (not persisted) self.downstream_to_upstream_htlc = {} # type: Dict[str, str] # Dict: htlc_key -> htlc_key (not persisted)
# payment_hash -> callback: # payment_hash -> callback:
self.hold_invoice_callbacks = {} # type: Dict[bytes, Callable[[bytes], Awaitable[None]]] self.hold_invoice_callbacks = {} # type: Dict[bytes, Callable[[bytes], Awaitable[None]]]
@ -1222,20 +1223,28 @@ class LNWallet(LNWorker):
self.logger.info('REBROADCASTING CLOSING TX') self.logger.info('REBROADCASTING CLOSING TX')
await self.network.try_broadcasting(force_close_tx, 'force-close') await self.network.try_broadcasting(force_close_tx, 'force-close')
def get_peer_by_scid_alias(self, scid_alias): def get_peer_by_scid_alias(self, scid_alias: bytes) -> Optional[Peer]:
for nodeid, peer in self.peers.items(): for nodeid, peer in self.peers.items():
if scid_alias == self._scid_alias_of_node(nodeid): if scid_alias == self._scid_alias_of_node(nodeid):
return peer return peer
def _scid_alias_of_node(self, nodeid): def _scid_alias_of_node(self, nodeid: bytes) -> bytes:
# scid alias for just-in-time channels # scid alias for just-in-time channels
return sha256(b'Electrum' + nodeid)[0:8] return sha256(b'Electrum' + nodeid)[0:8]
def get_scid_alias(self): def get_scid_alias(self) -> bytes:
return self._scid_alias_of_node(self.node_keypair.pubkey) return self._scid_alias_of_node(self.node_keypair.pubkey)
@log_exceptions @log_exceptions
async def open_channel_just_in_time(self, next_peer, next_amount_msat_htlc, next_cltv_abs, payment_hash, next_onion): async def open_channel_just_in_time( # FIXME xxxxx kwargs
self,
*,
next_peer: Peer,
next_amount_msat_htlc: int,
next_cltv_abs: int,
payment_hash: bytes,
next_onion: OnionPacket,
) -> str:
# if an exception is raised during negotiation, we raise an OnionRoutingFailure. # if an exception is raised during negotiation, we raise an OnionRoutingFailure.
# this will cancel the incoming HTLC # this will cancel the incoming HTLC
try: try:
@ -2396,7 +2405,7 @@ class LNWallet(LNWorker):
if htlc_key in htlcs: if htlc_key in htlcs:
return payment_key return payment_key
def notify_upstream_peer(self, htlc_key): def notify_upstream_peer(self, htlc_key: str) -> None:
"""Called when an HTLC we offered on chan gets irrevocably fulfilled or failed. """Called when an HTLC we offered on chan gets irrevocably fulfilled or failed.
If we find this was a forwarded HTLC, the upstream peer is notified. If we find this was a forwarded HTLC, the upstream peer is notified.
""" """
@ -2510,7 +2519,7 @@ class LNWallet(LNWorker):
if fw_key: if fw_key:
paysession_active = False paysession_active = False
else: else:
self.logger.info(f"received unknown htlc_failed, probably from previous session") self.logger.info(f"received unknown htlc_failed, probably from previous session (phash={payment_hash.hex()})")
key = payment_hash.hex() key = payment_hash.hex()
self.set_invoice_status(key, PR_UNPAID) self.set_invoice_status(key, PR_UNPAID)
util.trigger_callback('payment_failed', self.wallet, key, '') util.trigger_callback('payment_failed', self.wallet, key, '')
@ -2522,7 +2531,7 @@ class LNWallet(LNWorker):
self.save_forwarding_failure(fw_key, error_bytes=error_bytes, failure_message=failure_message) self.save_forwarding_failure(fw_key, error_bytes=error_bytes, failure_message=failure_message)
self.notify_upstream_peer(htlc_key) self.notify_upstream_peer(htlc_key)
else: else:
self.logger.info(f"waiting for other htlcs to fail") self.logger.info(f"waiting for other htlcs to fail (phash={payment_hash.hex()})")
def calc_routing_hints_for_invoice(self, amount_msat: Optional[int], channels=None): def calc_routing_hints_for_invoice(self, amount_msat: Optional[int], channels=None):
"""calculate routing hints (BOLT-11 'r' field)""" """calculate routing hints (BOLT-11 'r' field)"""
@ -3094,7 +3103,7 @@ class LNWallet(LNWorker):
failure_hex = failure_message.to_bytes().hex() if failure_message else None failure_hex = failure_message.to_bytes().hex() if failure_message else None
self.forwarding_failures[payment_key] = (error_hex, failure_hex) self.forwarding_failures[payment_key] = (error_hex, failure_hex)
def get_forwarding_failure(self, payment_key: str): def get_forwarding_failure(self, payment_key: str) -> Tuple[Optional[bytes], Optional['OnionRoutingFailure']]:
error_hex, failure_hex = self.forwarding_failures.get(payment_key, (None, None)) error_hex, failure_hex = self.forwarding_failures.get(payment_key, (None, None))
error_bytes = bytes.fromhex(error_hex) if error_hex else None error_bytes = bytes.fromhex(error_hex) if error_hex else None
failure_message = OnionRoutingFailure.from_bytes(bytes.fromhex(failure_hex)) if failure_hex else None failure_message = OnionRoutingFailure.from_bytes(bytes.fromhex(failure_hex)) if failure_hex else None

Loading…
Cancel
Save