Browse Source

lnworker: clean-up sent_htlcs_q and sent_htlcs_info

- introduce SentHtlcInfo named tuple
  - some previously unnamed tuples are now much shorter:
    create_routes_for_payment no longer returns an 8-tuple!
- sent_htlcs_q (renamed from sent_htlcs), is now keyed on payment_hash+payment_secret
  (needed for proper trampoline forwarding)
master
SomberNight 2 years ago
parent
commit
afac158c80
No known key found for this signature in database
GPG Key ID: B33B5F232C6271E9
  1. 2
      electrum/lnpeer.py
  2. 137
      electrum/lnworker.py
  3. 85
      electrum/tests/test_lnpeer.py

2
electrum/lnpeer.py

@ -1742,6 +1742,8 @@ class Peer(Logger):
except OnionRoutingFailure as e: except OnionRoutingFailure as e:
raise raise
except PaymentFailure as e: except PaymentFailure as e:
self.logger.debug(
f"maybe_forward_trampoline. PaymentFailure for {payment_hash.hex()=}, {payment_secret.hex()=}: {e!r}")
# FIXME: adapt the error code # FIXME: adapt the error code
raise OnionRoutingFailure(code=OnionFailureCode.UNKNOWN_NEXT_PEER, data=b'') raise OnionRoutingFailure(code=OnionFailureCode.UNKNOWN_NEXT_PEER, data=b'')

137
electrum/lnworker.py

@ -69,7 +69,7 @@ from .lnutil import (Outpoint, LNPeerAddr,
NoPathFound, InvalidGossipMsg) NoPathFound, InvalidGossipMsg)
from .lnutil import ln_dummy_address, ln_compare_features, IncompatibleLightningFeatures from .lnutil import ln_dummy_address, ln_compare_features, IncompatibleLightningFeatures
from .transaction import PartialTxOutput, PartialTransaction, PartialTxInput from .transaction import PartialTxOutput, PartialTransaction, PartialTxInput
from .lnonion import OnionFailureCode, OnionRoutingFailure from .lnonion import OnionFailureCode, OnionRoutingFailure, OnionPacket
from .lnmsg import decode_msg from .lnmsg import decode_msg
from .i18n import _ from .i18n import _
from .lnrouter import (RouteEdge, LNPaymentRoute, LNPaymentPath, is_route_sane_to_use, from .lnrouter import (RouteEdge, LNPaymentRoute, LNPaymentPath, is_route_sane_to_use,
@ -181,6 +181,20 @@ class ReceivedMPPStatus(NamedTuple):
htlc_set: Set[Tuple[ShortChannelID, UpdateAddHtlc]] htlc_set: Set[Tuple[ShortChannelID, UpdateAddHtlc]]
SentHtlcKey = Tuple[bytes, ShortChannelID, int] # RHASH, scid, htlc_id
class SentHtlcInfo(NamedTuple):
route: LNPaymentRoute
payment_secret_orig: bytes
payment_secret_bucket: bytes
amount_msat: int
bucket_msat: int
amount_receiver_msat: int
trampoline_fee_level: Optional[int]
trampoline_route: Optional[LNPaymentRoute]
class ErrorAddingPeer(Exception): pass class ErrorAddingPeer(Exception): pass
@ -678,8 +692,8 @@ class LNWallet(LNWorker):
for channel_id, storage in channel_backups.items(): for channel_id, storage in channel_backups.items():
self._channel_backups[bfh(channel_id)] = ChannelBackup(storage, lnworker=self) self._channel_backups[bfh(channel_id)] = ChannelBackup(storage, lnworker=self)
self.sent_htlcs = defaultdict(asyncio.Queue) # type: Dict[bytes, asyncio.Queue[HtlcLog]] self.sent_htlcs_q = defaultdict(asyncio.Queue) # type: Dict[bytes, asyncio.Queue[HtlcLog]]
self.sent_htlcs_info = dict() # (RHASH, scid, htlc_id) -> route, payment_secret, amount_msat, bucket_msat, trampoline_fee_level self.sent_htlcs_info = dict() # type: Dict[SentHtlcKey, SentHtlcInfo]
self.sent_buckets = dict() # payment_key -> (amount_sent, amount_failed) self.sent_buckets = dict() # payment_key -> (amount_sent, amount_failed)
self.received_mpp_htlcs = dict() # type: Dict[bytes, ReceivedMPPStatus] # payment_key -> ReceivedMPPStatus self.received_mpp_htlcs = dict() # type: Dict[bytes, ReceivedMPPStatus] # payment_key -> ReceivedMPPStatus
@ -1268,7 +1282,8 @@ class LNWallet(LNWorker):
if fwd_trampoline_cltv_delta < 576: if fwd_trampoline_cltv_delta < 576:
raise OnionRoutingFailure(code=OnionFailureCode.TRAMPOLINE_EXPIRY_TOO_SOON, data=b'') raise OnionRoutingFailure(code=OnionFailureCode.TRAMPOLINE_EXPIRY_TOO_SOON, data=b'')
self.logs[payment_hash.hex()] = log = [] payment_key = payment_hash + payment_secret
self.logs[payment_hash.hex()] = log = [] # TODO incl payment_secret in key (re trampoline forwarding)
# when encountering trampoline forwarding difficulties in the legacy case, we # when encountering trampoline forwarding difficulties in the legacy case, we
# sometimes need to fall back to a single trampoline forwarder, at the expense # sometimes need to fall back to a single trampoline forwarder, at the expense
@ -1300,28 +1315,24 @@ class LNWallet(LNWorker):
channels=channels, channels=channels,
) )
# 2. send htlcs # 2. send htlcs
async for route, amount_msat, total_msat, amount_receiver_msat, cltv_delta, bucket_payment_secret, trampoline_onion, trampoline_route in routes: async for sent_htlc_info, cltv_delta, trampoline_onion in routes:
amount_inflight += amount_receiver_msat amount_inflight += sent_htlc_info.amount_receiver_msat
if amount_inflight > amount_to_pay: # safety belts if amount_inflight > amount_to_pay: # safety belts
raise Exception(f"amount_inflight={amount_inflight} > amount_to_pay={amount_to_pay}") raise Exception(f"amount_inflight={amount_inflight} > amount_to_pay={amount_to_pay}")
sent_htlc_info = sent_htlc_info._replace(trampoline_fee_level=self.trampoline_fee_level)
await self.pay_to_route( await self.pay_to_route(
route=route, sent_htlc_info=sent_htlc_info,
amount_msat=amount_msat,
total_msat=total_msat,
amount_receiver_msat=amount_receiver_msat,
payment_hash=payment_hash, payment_hash=payment_hash,
payment_secret=bucket_payment_secret,
min_cltv_expiry=cltv_delta, min_cltv_expiry=cltv_delta,
trampoline_onion=trampoline_onion, trampoline_onion=trampoline_onion,
trampoline_fee_level=self.trampoline_fee_level, )
trampoline_route=trampoline_route)
# invoice_status is triggered in self.set_invoice_status when it actally changes. # invoice_status is triggered in self.set_invoice_status when it actally changes.
# It is also triggered here to update progress for a lightning payment in the GUI # It is also triggered here to update progress for a lightning payment in the GUI
# (e.g. attempt counter) # (e.g. attempt counter)
util.trigger_callback('invoice_status', self.wallet, payment_hash.hex(), PR_INFLIGHT) util.trigger_callback('invoice_status', self.wallet, payment_hash.hex(), PR_INFLIGHT)
# 3. await a queue # 3. await a queue
self.logger.info(f"amount inflight {amount_inflight}") self.logger.info(f"amount inflight {amount_inflight}")
htlc_log = await self.sent_htlcs[payment_hash].get() htlc_log = await self.sent_htlcs_q[payment_key].get()
amount_inflight -= htlc_log.amount_msat amount_inflight -= htlc_log.amount_msat
if amount_inflight < 0: if amount_inflight < 0:
raise Exception(f"amount_inflight={amount_inflight} < 0") raise Exception(f"amount_inflight={amount_inflight} < 0")
@ -1394,48 +1405,44 @@ class LNWallet(LNWorker):
async def pay_to_route( async def pay_to_route(
self, *, self, *,
route: LNPaymentRoute, sent_htlc_info: SentHtlcInfo,
amount_msat: int,
total_msat: int,
amount_receiver_msat:int,
payment_hash: bytes, payment_hash: bytes,
payment_secret: bytes,
min_cltv_expiry: int, min_cltv_expiry: int,
trampoline_onion: bytes = None, trampoline_onion: bytes = None,
trampoline_fee_level: int, ) -> None:
trampoline_route: Optional[List]) -> None: """Sends a single HTLC."""
shi = sent_htlc_info
# send a single htlc del sent_htlc_info # just renamed
short_channel_id = route[0].short_channel_id short_channel_id = shi.route[0].short_channel_id
chan = self.get_channel_by_short_id(short_channel_id) chan = self.get_channel_by_short_id(short_channel_id)
assert chan, ShortChannelID(short_channel_id) assert chan, ShortChannelID(short_channel_id)
peer = self._peers.get(route[0].node_id) peer = self._peers.get(shi.route[0].node_id)
if not peer: if not peer:
raise PaymentFailure('Dropped peer') raise PaymentFailure('Dropped peer')
await peer.initialized await peer.initialized
htlc = peer.pay( htlc = peer.pay(
route=route, route=shi.route,
chan=chan, chan=chan,
amount_msat=amount_msat, amount_msat=shi.amount_msat,
total_msat=total_msat, total_msat=shi.bucket_msat,
payment_hash=payment_hash, payment_hash=payment_hash,
min_final_cltv_expiry=min_cltv_expiry, min_final_cltv_expiry=min_cltv_expiry,
payment_secret=payment_secret, payment_secret=shi.payment_secret_bucket,
trampoline_onion=trampoline_onion) trampoline_onion=trampoline_onion)
key = (payment_hash, short_channel_id, htlc.htlc_id) key = (payment_hash, short_channel_id, htlc.htlc_id)
self.sent_htlcs_info[key] = route, payment_secret, amount_msat, total_msat, amount_receiver_msat, trampoline_fee_level, trampoline_route self.sent_htlcs_info[key] = shi
payment_key = payment_hash + payment_secret payment_key = payment_hash + shi.payment_secret_bucket
# if we sent MPP to a trampoline, add item to sent_buckets # if we sent MPP to a trampoline, add item to sent_buckets
if self.uses_trampoline() and amount_msat != total_msat: if self.uses_trampoline() and shi.amount_msat != shi.bucket_msat:
if payment_key not in self.sent_buckets: if payment_key not in self.sent_buckets:
self.sent_buckets[payment_key] = (0, 0) self.sent_buckets[payment_key] = (0, 0)
amount_sent, amount_failed = self.sent_buckets[payment_key] amount_sent, amount_failed = self.sent_buckets[payment_key]
amount_sent += amount_receiver_msat amount_sent += shi.amount_receiver_msat
self.sent_buckets[payment_key] = amount_sent, amount_failed self.sent_buckets[payment_key] = amount_sent, amount_failed
if self.network.path_finder: if self.network.path_finder:
# add inflight htlcs to liquidity hints # add inflight htlcs to liquidity hints
self.network.path_finder.update_inflight_htlcs(route, add_htlcs=True) self.network.path_finder.update_inflight_htlcs(shi.route, add_htlcs=True)
util.trigger_callback('htlc_added', chan, htlc, SENT) util.trigger_callback('htlc_added', chan, htlc, SENT)
def handle_error_code_from_failed_htlc( def handle_error_code_from_failed_htlc(
@ -1633,7 +1640,7 @@ class LNWallet(LNWorker):
fwd_trampoline_onion=None, fwd_trampoline_onion=None,
full_path: LNPaymentPath = None, full_path: LNPaymentPath = None,
channels: Optional[Sequence[Channel]] = None, channels: Optional[Sequence[Channel]] = None,
) -> AsyncGenerator[Tuple[LNPaymentRoute, int], None]: ) -> AsyncGenerator[Tuple[SentHtlcInfo, int, Optional[OnionPacket]], None]:
"""Creates multiple routes for splitting a payment over the available """Creates multiple routes for splitting a payment over the available
private channels. private channels.
@ -1719,7 +1726,17 @@ class LNWallet(LNWorker):
node_features=trampoline_features) node_features=trampoline_features)
] ]
self.logger.info(f'adding route {part_amount_msat} {delta_fee} {margin}') self.logger.info(f'adding route {part_amount_msat} {delta_fee} {margin}')
routes.append((route, part_amount_msat_with_fees, per_trampoline_amount_with_fees, part_amount_msat, per_trampoline_cltv_delta, per_trampoline_secret, trampoline_onion, trampoline_route)) shi = SentHtlcInfo(
route=route,
payment_secret_orig=payment_secret,
payment_secret_bucket=per_trampoline_secret,
amount_msat=part_amount_msat_with_fees,
bucket_msat=per_trampoline_amount_with_fees,
amount_receiver_msat=part_amount_msat,
trampoline_fee_level=None,
trampoline_route=trampoline_route,
)
routes.append((shi, per_trampoline_cltv_delta, trampoline_onion))
if per_trampoline_fees != 0: if per_trampoline_fees != 0:
self.logger.info('not enough margin to pay trampoline fee') self.logger.info('not enough margin to pay trampoline fee')
raise NoPathFound() raise NoPathFound()
@ -1741,7 +1758,17 @@ class LNWallet(LNWorker):
full_path=full_path, full_path=full_path,
) )
) )
routes.append((route, part_amount_msat, final_total_msat, part_amount_msat, min_cltv_expiry, payment_secret, fwd_trampoline_onion, None)) shi = SentHtlcInfo(
route=route,
payment_secret_orig=payment_secret,
payment_secret_bucket=payment_secret,
amount_msat=part_amount_msat,
bucket_msat=final_total_msat,
amount_receiver_msat=part_amount_msat,
trampoline_fee_level=None,
trampoline_route=None,
)
routes.append((shi, min_cltv_expiry, fwd_trampoline_onion))
except NoPathFound: except NoPathFound:
continue continue
for route in routes: for route in routes:
@ -2096,14 +2123,16 @@ class LNWallet(LNWorker):
def htlc_fulfilled(self, chan: Channel, payment_hash: bytes, htlc_id: int): def htlc_fulfilled(self, chan: Channel, payment_hash: bytes, htlc_id: int):
util.trigger_callback('htlc_fulfilled', payment_hash, chan, htlc_id) util.trigger_callback('htlc_fulfilled', payment_hash, chan, htlc_id)
self._on_maybe_forwarded_htlc_resolved(chan=chan, htlc_id=htlc_id) self._on_maybe_forwarded_htlc_resolved(chan=chan, htlc_id=htlc_id)
q = self.sent_htlcs.get(payment_hash) q = None
if shi := self.sent_htlcs_info.get((payment_hash, chan.short_channel_id, htlc_id)):
payment_key = payment_hash + shi.payment_secret_orig
q = self.sent_htlcs_q.get(payment_key)
if q: if q:
route, payment_secret, amount_msat, bucket_msat, amount_receiver_msat, trampoline_fee_level, trampoline_route = self.sent_htlcs_info[(payment_hash, chan.short_channel_id, htlc_id)]
htlc_log = HtlcLog( htlc_log = HtlcLog(
success=True, success=True,
route=route, route=shi.route,
amount_msat=amount_receiver_msat, amount_msat=shi.amount_receiver_msat,
trampoline_fee_level=trampoline_fee_level) trampoline_fee_level=shi.trampoline_fee_level)
q.put_nowait(htlc_log) q.put_nowait(htlc_log)
else: else:
key = payment_hash.hex() key = payment_hash.hex()
@ -2120,12 +2149,16 @@ class LNWallet(LNWorker):
util.trigger_callback('htlc_failed', payment_hash, chan, htlc_id) util.trigger_callback('htlc_failed', payment_hash, chan, htlc_id)
self._on_maybe_forwarded_htlc_resolved(chan=chan, htlc_id=htlc_id) self._on_maybe_forwarded_htlc_resolved(chan=chan, htlc_id=htlc_id)
q = self.sent_htlcs.get(payment_hash) q = None
if shi := self.sent_htlcs_info.get((payment_hash, chan.short_channel_id, htlc_id)):
payment_okey = payment_hash + shi.payment_secret_orig
q = self.sent_htlcs_q.get(payment_okey)
if q: if q:
# detect if it is part of a bucket # detect if it is part of a bucket
# if yes, wait until the bucket completely failed # if yes, wait until the bucket completely failed
key = (payment_hash, chan.short_channel_id, htlc_id) shi = self.sent_htlcs_info[(payment_hash, chan.short_channel_id, htlc_id)]
route, payment_secret, amount_msat, bucket_msat, amount_receiver_msat, trampoline_fee_level, trampoline_route = self.sent_htlcs_info[key] amount_receiver_msat = shi.amount_receiver_msat
route = shi.route
if error_bytes: if error_bytes:
# TODO "decode_onion_error" might raise, catch and maybe blacklist/penalise someone? # TODO "decode_onion_error" might raise, catch and maybe blacklist/penalise someone?
try: try:
@ -2140,19 +2173,19 @@ class LNWallet(LNWorker):
self.logger.info(f"htlc_failed {failure_message}") self.logger.info(f"htlc_failed {failure_message}")
# check sent_buckets if we use trampoline # check sent_buckets if we use trampoline
payment_key = payment_hash + payment_secret payment_bkey = payment_hash + shi.payment_secret_bucket
if self.uses_trampoline() and payment_key in self.sent_buckets: if self.uses_trampoline() and payment_bkey in self.sent_buckets:
amount_sent, amount_failed = self.sent_buckets[payment_key] amount_sent, amount_failed = self.sent_buckets[payment_bkey]
amount_failed += amount_receiver_msat amount_failed += amount_receiver_msat
self.sent_buckets[payment_key] = amount_sent, amount_failed self.sent_buckets[payment_bkey] = amount_sent, amount_failed
if amount_sent != amount_failed: if amount_sent != amount_failed:
self.logger.info('bucket still active...') self.logger.info('bucket still active...')
return return
self.logger.info('bucket failed') self.logger.info('bucket failed')
amount_receiver_msat = amount_sent amount_receiver_msat = amount_sent
if trampoline_route: if shi.trampoline_route:
route = trampoline_route route = shi.trampoline_route
htlc_log = HtlcLog( htlc_log = HtlcLog(
success=False, success=False,
route=route, route=route,
@ -2160,7 +2193,7 @@ class LNWallet(LNWorker):
error_bytes=error_bytes, error_bytes=error_bytes,
failure_msg=failure_message, failure_msg=failure_message,
sender_idx=sender_idx, sender_idx=sender_idx,
trampoline_fee_level=trampoline_fee_level) trampoline_fee_level=shi.trampoline_fee_level)
q.put_nowait(htlc_log) q.put_nowait(htlc_log)
else: else:
self.logger.info(f"received unknown htlc_failed, probably from previous session") self.logger.info(f"received unknown htlc_failed, probably from previous session")

85
electrum/tests/test_lnpeer.py

@ -31,7 +31,7 @@ from electrum.lnutil import PaymentFailure, LnFeatures, HTLCOwner
from electrum.lnchannel import ChannelState, PeerState, Channel from electrum.lnchannel import ChannelState, PeerState, Channel
from electrum.lnrouter import LNPathFinder, PathEdge, LNPathInconsistent from electrum.lnrouter import LNPathFinder, PathEdge, LNPathInconsistent
from electrum.channel_db import ChannelDB from electrum.channel_db import ChannelDB
from electrum.lnworker import LNWallet, NoPathFound from electrum.lnworker import LNWallet, NoPathFound, SentHtlcInfo
from electrum.lnmsg import encode_msg, decode_msg from electrum.lnmsg import encode_msg, decode_msg
from electrum import lnmsg from electrum import lnmsg
from electrum.logging import console_stderr_handler, Logger from electrum.logging import console_stderr_handler, Logger
@ -166,7 +166,7 @@ class MockLNWallet(Logger, EventListener, NetworkRetryManager[LNPeerAddr]):
self.enable_htlc_settle = True self.enable_htlc_settle = True
self.enable_htlc_forwarding = True self.enable_htlc_forwarding = True
self.received_mpp_htlcs = dict() self.received_mpp_htlcs = dict()
self.sent_htlcs = defaultdict(asyncio.Queue) self.sent_htlcs_q = defaultdict(asyncio.Queue)
self.sent_htlcs_info = dict() self.sent_htlcs_info = dict()
self.sent_buckets = defaultdict(set) self.sent_buckets = defaultdict(set)
self.trampoline_forwardings = set() self.trampoline_forwardings = set()
@ -740,7 +740,7 @@ class TestPeer(ElectrumTestCase):
with self.assertRaises(SuccessfulTest): with self.assertRaises(SuccessfulTest):
await f() await f()
async def _activate_trampoline(self, w): async def _activate_trampoline(self, w: MockLNWallet):
if w.network.channel_db: if w.network.channel_db:
w.network.channel_db.stop() w.network.channel_db.stop()
await w.network.channel_db.stopped_event.wait() await w.network.channel_db.stopped_event.wait()
@ -837,38 +837,44 @@ class TestPeer(ElectrumTestCase):
lnaddr2, pay_req2 = self.prepare_invoice(w2) lnaddr2, pay_req2 = self.prepare_invoice(w2)
lnaddr1, pay_req1 = self.prepare_invoice(w1) lnaddr1, pay_req1 = self.prepare_invoice(w1)
# create the htlc queues now (side-effecting defaultdict) # create the htlc queues now (side-effecting defaultdict)
q1 = w1.sent_htlcs[lnaddr2.paymenthash] q1 = w1.sent_htlcs_q[lnaddr2.paymenthash + lnaddr2.payment_secret]
q2 = w2.sent_htlcs[lnaddr1.paymenthash] q2 = w2.sent_htlcs_q[lnaddr1.paymenthash + lnaddr1.payment_secret]
# alice sends htlc BUT NOT COMMITMENT_SIGNED # alice sends htlc BUT NOT COMMITMENT_SIGNED
p1.maybe_send_commitment = lambda x: None p1.maybe_send_commitment = lambda x: None
route1 = (await w1.create_routes_from_invoice(lnaddr2.get_amount_msat(), decoded_invoice=lnaddr2))[0][0] route1 = (await w1.create_routes_from_invoice(lnaddr2.get_amount_msat(), decoded_invoice=lnaddr2))[0][0].route
amount_msat = lnaddr2.get_amount_msat() shi1 = SentHtlcInfo(
await w1.pay_to_route(
route=route1, route=route1,
amount_msat=amount_msat, payment_secret_orig=lnaddr2.payment_secret,
total_msat=amount_msat, payment_secret_bucket=lnaddr2.payment_secret,
amount_receiver_msat=amount_msat, amount_msat=lnaddr2.get_amount_msat(),
bucket_msat=lnaddr2.get_amount_msat(),
amount_receiver_msat=lnaddr2.get_amount_msat(),
trampoline_fee_level=None,
trampoline_route=None,
)
await w1.pay_to_route(
sent_htlc_info=shi1,
payment_hash=lnaddr2.paymenthash, payment_hash=lnaddr2.paymenthash,
min_cltv_expiry=lnaddr2.get_min_final_cltv_expiry(), min_cltv_expiry=lnaddr2.get_min_final_cltv_expiry(),
payment_secret=lnaddr2.payment_secret,
trampoline_fee_level=0,
trampoline_route=None,
) )
p1.maybe_send_commitment = _maybe_send_commitment1 p1.maybe_send_commitment = _maybe_send_commitment1
# bob sends htlc BUT NOT COMMITMENT_SIGNED # bob sends htlc BUT NOT COMMITMENT_SIGNED
p2.maybe_send_commitment = lambda x: None p2.maybe_send_commitment = lambda x: None
route2 = (await w2.create_routes_from_invoice(lnaddr1.get_amount_msat(), decoded_invoice=lnaddr1))[0][0] route2 = (await w2.create_routes_from_invoice(lnaddr1.get_amount_msat(), decoded_invoice=lnaddr1))[0][0].route
amount_msat = lnaddr1.get_amount_msat() shi2 = SentHtlcInfo(
await w2.pay_to_route(
route=route2, route=route2,
amount_msat=amount_msat, payment_secret_orig=lnaddr1.payment_secret,
total_msat=amount_msat, payment_secret_bucket=lnaddr1.payment_secret,
amount_receiver_msat=amount_msat, amount_msat=lnaddr1.get_amount_msat(),
bucket_msat=lnaddr1.get_amount_msat(),
amount_receiver_msat=lnaddr1.get_amount_msat(),
trampoline_fee_level=None,
trampoline_route=None,
)
await w2.pay_to_route(
sent_htlc_info=shi2,
payment_hash=lnaddr1.paymenthash, payment_hash=lnaddr1.paymenthash,
min_cltv_expiry=lnaddr1.get_min_final_cltv_expiry(), min_cltv_expiry=lnaddr1.get_min_final_cltv_expiry(),
payment_secret=lnaddr1.payment_secret,
trampoline_fee_level=0,
trampoline_route=None,
) )
p2.maybe_send_commitment = _maybe_send_commitment2 p2.maybe_send_commitment = _maybe_send_commitment2
# sleep a bit so that they both receive msgs sent so far # sleep a bit so that they both receive msgs sent so far
@ -878,9 +884,9 @@ class TestPeer(ElectrumTestCase):
p2.maybe_send_commitment(bob_channel) p2.maybe_send_commitment(bob_channel)
htlc_log1 = await q1.get() htlc_log1 = await q1.get()
assert htlc_log1.success self.assertTrue(htlc_log1.success)
htlc_log2 = await q2.get() htlc_log2 = await q2.get()
assert htlc_log2.success self.assertTrue(htlc_log2.success)
raise PaymentDone() raise PaymentDone()
async def f(): async def f():
@ -1184,10 +1190,7 @@ class TestPeer(ElectrumTestCase):
if not bob_forwarding: if not bob_forwarding:
graph.workers['bob'].enable_htlc_forwarding = False graph.workers['bob'].enable_htlc_forwarding = False
if alice_uses_trampoline: if alice_uses_trampoline:
if graph.workers['alice'].network.channel_db: await self._activate_trampoline(graph.workers['alice'])
graph.workers['alice'].network.channel_db.stop()
await graph.workers['alice'].network.channel_db.stopped_event.wait()
graph.workers['alice'].network.channel_db = None
else: else:
assert graph.workers['alice'].network.channel_db is not None assert graph.workers['alice'].network.channel_db is not None
lnaddr, pay_req = self.prepare_invoice(graph.workers['dave'], include_routing_hints=True, amount_msat=amount_to_pay) lnaddr, pay_req = self.prepare_invoice(graph.workers['dave'], include_routing_hints=True, amount_msat=amount_to_pay)
@ -1433,7 +1436,7 @@ class TestPeer(ElectrumTestCase):
await util.wait_for2(p1.initialized, 1) await util.wait_for2(p1.initialized, 1)
await util.wait_for2(p2.initialized, 1) await util.wait_for2(p2.initialized, 1)
# alice sends htlc # alice sends htlc
route, amount_msat = (await w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr))[0][0:2] route = (await w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr))[0][0].route
p1.pay(route=route, p1.pay(route=route,
chan=alice_channel, chan=alice_channel,
amount_msat=lnaddr.get_amount_msat(), amount_msat=lnaddr.get_amount_msat(),
@ -1556,7 +1559,8 @@ class TestPeer(ElectrumTestCase):
lnaddr, pay_req = self.prepare_invoice(w2) lnaddr, pay_req = self.prepare_invoice(w2)
lnaddr = w1._check_invoice(pay_req) lnaddr = w1._check_invoice(pay_req)
route, amount_msat = (await w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr))[0][0:2] shi = (await w1.create_routes_from_invoice(lnaddr.get_amount_msat(), decoded_invoice=lnaddr))[0][0]
route, amount_msat = shi.route, shi.amount_msat
assert amount_msat == lnaddr.get_amount_msat() assert amount_msat == lnaddr.get_amount_msat()
await w1.force_close_channel(alice_channel.channel_id) await w1.force_close_channel(alice_channel.channel_id)
@ -1570,20 +1574,21 @@ class TestPeer(ElectrumTestCase):
# AssertionError is ok since we shouldn't use old routes, and the # AssertionError is ok since we shouldn't use old routes, and the
# route finding should fail when channel is closed # route finding should fail when channel is closed
async def f(): async def f():
min_cltv_expiry = lnaddr.get_min_final_cltv_expiry() shi = SentHtlcInfo(
payment_hash = lnaddr.paymenthash
payment_secret = lnaddr.payment_secret
pay = w1.pay_to_route(
route=route, route=route,
payment_secret_orig=lnaddr.payment_secret,
payment_secret_bucket=lnaddr.payment_secret,
amount_msat=amount_msat, amount_msat=amount_msat,
total_msat=amount_msat, bucket_msat=amount_msat,
amount_receiver_msat=amount_msat, amount_receiver_msat=amount_msat,
payment_hash=payment_hash, trampoline_fee_level=None,
payment_secret=payment_secret,
min_cltv_expiry=min_cltv_expiry,
trampoline_fee_level=0,
trampoline_route=None, trampoline_route=None,
) )
pay = w1.pay_to_route(
sent_htlc_info=shi,
payment_hash=lnaddr.paymenthash,
min_cltv_expiry=lnaddr.get_min_final_cltv_expiry(),
)
await asyncio.gather(pay, p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch()) await asyncio.gather(pay, p1._message_loop(), p2._message_loop(), p1.htlc_switch(), p2.htlc_switch())
with self.assertRaises(PaymentFailure): with self.assertRaises(PaymentFailure):
await f() await f()

Loading…
Cancel
Save