You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
302 lines
11 KiB
302 lines
11 KiB
import asyncio |
|
import tempfile |
|
from decimal import Decimal |
|
import os |
|
from contextlib import contextmanager |
|
from collections import defaultdict |
|
import logging |
|
import concurrent |
|
from concurrent import futures |
|
|
|
from electrum.network import Network |
|
from electrum.ecc import ECPrivkey |
|
from electrum import simple_config, lnutil |
|
from electrum.lnaddr import lnencode, LnAddr, lndecode |
|
from electrum.bitcoin import COIN, sha256 |
|
from electrum.util import bh2u, create_and_start_event_loop |
|
from electrum.lnpeer import Peer |
|
from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey |
|
from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving |
|
from electrum.lnutil import PaymentFailure, LnLocalFeatures |
|
from electrum.lnchannel import channel_states, peer_states, Channel |
|
from electrum.lnrouter import LNPathFinder |
|
from electrum.channel_db import ChannelDB |
|
from electrum.lnworker import LNWallet, NoPathFound |
|
from electrum.lnmsg import encode_msg, decode_msg |
|
from electrum.logging import console_stderr_handler |
|
from electrum.lnworker import PaymentInfo, RECEIVED, PR_UNPAID |
|
|
|
from .test_lnchannel import create_test_channels |
|
from . import ElectrumTestCase |
|
|
|
def keypair(): |
|
priv = ECPrivkey.generate_random_key().get_secret_bytes() |
|
k1 = Keypair( |
|
pubkey=privkey_to_pubkey(priv), |
|
privkey=priv) |
|
return k1 |
|
|
|
@contextmanager |
|
def noop_lock(): |
|
yield |
|
|
|
class MockNetwork: |
|
def __init__(self, tx_queue): |
|
self.callbacks = defaultdict(list) |
|
self.lnwatcher = None |
|
self.interface = None |
|
user_config = {} |
|
user_dir = tempfile.mkdtemp(prefix="electrum-lnpeer-test-") |
|
self.config = simple_config.SimpleConfig(user_config, read_user_dir_function=lambda: user_dir) |
|
self.asyncio_loop = asyncio.get_event_loop() |
|
self.channel_db = ChannelDB(self) |
|
self.path_finder = LNPathFinder(self.channel_db) |
|
self.tx_queue = tx_queue |
|
|
|
@property |
|
def callback_lock(self): |
|
return noop_lock() |
|
|
|
register_callback = Network.register_callback |
|
unregister_callback = Network.unregister_callback |
|
trigger_callback = Network.trigger_callback |
|
|
|
def get_local_height(self): |
|
return 0 |
|
|
|
async def try_broadcasting(self, tx, name): |
|
if self.tx_queue: |
|
await self.tx_queue.put(tx) |
|
|
|
class MockWallet: |
|
def set_label(self, x, y): |
|
pass |
|
def save_db(self): |
|
pass |
|
def is_lightning_backup(self): |
|
return False |
|
|
|
class MockLNWallet: |
|
def __init__(self, remote_keypair, local_keypair, chan: 'Channel', tx_queue): |
|
self.remote_keypair = remote_keypair |
|
self.node_keypair = local_keypair |
|
self.network = MockNetwork(tx_queue) |
|
self.channels = {chan.channel_id: chan} |
|
self.payments = {} |
|
self.logs = defaultdict(list) |
|
self.wallet = MockWallet() |
|
self.localfeatures = LnLocalFeatures(0) |
|
self.localfeatures |= LnLocalFeatures.OPTION_DATA_LOSS_PROTECT_OPT |
|
self.pending_payments = defaultdict(asyncio.Future) |
|
chan.lnworker = self |
|
chan.node_id = remote_keypair.pubkey |
|
|
|
def get_invoice_status(self, key): |
|
pass |
|
|
|
@property |
|
def lock(self): |
|
return noop_lock() |
|
|
|
@property |
|
def peers(self): |
|
return {self.remote_keypair.pubkey: self.peer} |
|
|
|
def channels_for_peer(self, pubkey): |
|
return self.channels |
|
|
|
def get_channel_by_short_id(self, short_channel_id): |
|
with self.lock: |
|
for chan in self.channels.values(): |
|
if chan.short_channel_id == short_channel_id: |
|
return chan |
|
|
|
def save_channel(self, chan): |
|
print("Ignoring channel save") |
|
|
|
preimages = {} |
|
get_payment_info = LNWallet.get_payment_info |
|
save_payment_info = LNWallet.save_payment_info |
|
set_payment_status = LNWallet.set_payment_status |
|
get_payment_status = LNWallet.get_payment_status |
|
await_payment = LNWallet.await_payment |
|
payment_received = LNWallet.payment_received |
|
payment_sent = LNWallet.payment_sent |
|
save_preimage = LNWallet.save_preimage |
|
get_preimage = LNWallet.get_preimage |
|
_create_route_from_invoice = LNWallet._create_route_from_invoice |
|
_check_invoice = staticmethod(LNWallet._check_invoice) |
|
_pay_to_route = LNWallet._pay_to_route |
|
force_close_channel = LNWallet.force_close_channel |
|
get_first_timestamp = lambda self: 0 |
|
payment_completed = LNWallet.payment_completed |
|
|
|
class MockTransport: |
|
def __init__(self, name): |
|
self.queue = asyncio.Queue() |
|
self._name = name |
|
|
|
def name(self): |
|
return self._name |
|
|
|
async def read_messages(self): |
|
while True: |
|
yield await self.queue.get() |
|
|
|
class NoFeaturesTransport(MockTransport): |
|
""" |
|
This answers the init message with a init that doesn't signal any features. |
|
Used for testing that we require DATA_LOSS_PROTECT. |
|
""" |
|
def send_bytes(self, data): |
|
decoded = decode_msg(data) |
|
print(decoded) |
|
if decoded[0] == 'init': |
|
self.queue.put_nowait(encode_msg('init', lflen=1, gflen=1, localfeatures=b"\x00", globalfeatures=b"\x00")) |
|
|
|
class PutIntoOthersQueueTransport(MockTransport): |
|
def __init__(self, name): |
|
super().__init__(name) |
|
self.other_mock_transport = None |
|
|
|
def send_bytes(self, data): |
|
self.other_mock_transport.queue.put_nowait(data) |
|
|
|
def transport_pair(name1, name2): |
|
t1 = PutIntoOthersQueueTransport(name1) |
|
t2 = PutIntoOthersQueueTransport(name2) |
|
t1.other_mock_transport = t2 |
|
t2.other_mock_transport = t1 |
|
return t1, t2 |
|
|
|
class TestPeer(ElectrumTestCase): |
|
|
|
@classmethod |
|
def setUpClass(cls): |
|
super().setUpClass() |
|
console_stderr_handler.setLevel(logging.DEBUG) |
|
|
|
def setUp(self): |
|
super().setUp() |
|
self.asyncio_loop, self._stop_loop, self._loop_thread = create_and_start_event_loop() |
|
|
|
def tearDown(self): |
|
super().tearDown() |
|
self.asyncio_loop.call_soon_threadsafe(self._stop_loop.set_result, 1) |
|
self._loop_thread.join(timeout=1) |
|
|
|
def prepare_peers(self, alice_channel, bob_channel): |
|
k1, k2 = keypair(), keypair() |
|
t1, t2 = transport_pair(alice_channel.name, bob_channel.name) |
|
q1, q2 = asyncio.Queue(), asyncio.Queue() |
|
w1 = MockLNWallet(k1, k2, alice_channel, tx_queue=q1) |
|
w2 = MockLNWallet(k2, k1, bob_channel, tx_queue=q2) |
|
p1 = Peer(w1, k1.pubkey, t1) |
|
p2 = Peer(w2, k2.pubkey, t2) |
|
w1.peer = p1 |
|
w2.peer = p2 |
|
# mark_open won't work if state is already OPEN. |
|
# so set it to FUNDED |
|
alice_channel._state = channel_states.FUNDED |
|
bob_channel._state = channel_states.FUNDED |
|
# this populates the channel graph: |
|
p1.mark_open(alice_channel) |
|
p2.mark_open(bob_channel) |
|
return p1, p2, w1, w2, q1, q2 |
|
|
|
@staticmethod |
|
def prepare_invoice(w2 # receiver |
|
): |
|
amount_sat = 100000 |
|
amount_btc = amount_sat/Decimal(COIN) |
|
payment_preimage = os.urandom(32) |
|
RHASH = sha256(payment_preimage) |
|
info = PaymentInfo(RHASH, amount_sat, RECEIVED, PR_UNPAID) |
|
w2.save_preimage(RHASH, payment_preimage) |
|
w2.save_payment_info(info) |
|
lnaddr = LnAddr( |
|
RHASH, |
|
amount_btc, |
|
tags=[('c', lnutil.MIN_FINAL_CLTV_EXPIRY_FOR_INVOICE), |
|
('d', 'coffee') |
|
]) |
|
return lnencode(lnaddr, w2.node_keypair.privkey) |
|
|
|
def test_reestablish(self): |
|
alice_channel, bob_channel = create_test_channels() |
|
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) |
|
async def reestablish(): |
|
await asyncio.gather( |
|
p1.reestablish_channel(alice_channel), |
|
p2.reestablish_channel(bob_channel)) |
|
self.assertEqual(alice_channel.peer_state, peer_states.GOOD) |
|
self.assertEqual(bob_channel.peer_state, peer_states.GOOD) |
|
gath.cancel() |
|
gath = asyncio.gather(reestablish(), p1._message_loop(), p2._message_loop()) |
|
async def f(): |
|
await gath |
|
with self.assertRaises(concurrent.futures.CancelledError): |
|
run(f()) |
|
|
|
def test_reestablish_with_old_state(self): |
|
alice_channel, bob_channel = create_test_channels() |
|
alice_channel_0, bob_channel_0 = create_test_channels() # these are identical |
|
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) |
|
pay_req = self.prepare_invoice(w2) |
|
async def reestablish(): |
|
result = await LNWallet._pay(w1, pay_req) |
|
self.assertEqual(result, True) |
|
w1.channels = {alice_channel_0.channel_id: alice_channel_0} |
|
await asyncio.gather( |
|
p1.reestablish_channel(alice_channel_0), |
|
p2.reestablish_channel(bob_channel)) |
|
self.assertEqual(alice_channel_0.peer_state, peer_states.BAD) |
|
self.assertEqual(bob_channel._state, channel_states.FORCE_CLOSING) |
|
# wait so that pending messages are processed |
|
#await asyncio.sleep(1) |
|
gath.cancel() |
|
gath = asyncio.gather(reestablish(), p1._message_loop(), p2._message_loop()) |
|
async def f(): |
|
await gath |
|
with self.assertRaises(concurrent.futures.CancelledError): |
|
run(f()) |
|
|
|
def test_payment(self): |
|
alice_channel, bob_channel = create_test_channels() |
|
p1, p2, w1, w2, _q1, _q2 = self.prepare_peers(alice_channel, bob_channel) |
|
pay_req = self.prepare_invoice(w2) |
|
async def pay(): |
|
result = await LNWallet._pay(w1, pay_req) |
|
self.assertTrue(result) |
|
gath.cancel() |
|
gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop()) |
|
async def f(): |
|
await gath |
|
with self.assertRaises(concurrent.futures.CancelledError): |
|
run(f()) |
|
|
|
def test_channel_usage_after_closing(self): |
|
alice_channel, bob_channel = create_test_channels() |
|
p1, p2, w1, w2, q1, q2 = self.prepare_peers(alice_channel, bob_channel) |
|
pay_req = self.prepare_invoice(w2) |
|
|
|
addr = w1._check_invoice(pay_req) |
|
route = run(w1._create_route_from_invoice(decoded_invoice=addr)) |
|
|
|
run(w1.force_close_channel(alice_channel.channel_id)) |
|
# check if a tx (commitment transaction) was broadcasted: |
|
assert q1.qsize() == 1 |
|
|
|
with self.assertRaises(NoPathFound) as e: |
|
run(w1._create_route_from_invoice(decoded_invoice=addr)) |
|
|
|
peer = w1.peers[route[0].node_id] |
|
# AssertionError is ok since we shouldn't use old routes, and the |
|
# route finding should fail when channel is closed |
|
async def f(): |
|
await asyncio.gather(w1._pay_to_route(route, addr), p1._message_loop(), p2._message_loop()) |
|
with self.assertRaises(PaymentFailure): |
|
run(f()) |
|
|
|
def run(coro): |
|
return asyncio.run_coroutine_threadsafe(coro, loop=asyncio.get_event_loop()).result()
|
|
|