Browse Source

start asyncio loop in test_lnrouter and test_lnpeer

master
ThomasV 7 years ago
parent
commit
bfdf0a7e88
  1. 31
      electrum/tests/test_lnpeer.py
  2. 11
      electrum/tests/test_lnrouter.py

31
electrum/tests/test_lnpeer.py

@ -1,4 +1,3 @@
import unittest
import asyncio import asyncio
import tempfile import tempfile
from decimal import Decimal from decimal import Decimal
@ -11,7 +10,7 @@ from electrum.ecc import ECPrivkey
from electrum import simple_config, lnutil from electrum import simple_config, lnutil
from electrum.lnaddr import lnencode, LnAddr, lndecode from electrum.lnaddr import lnencode, LnAddr, lndecode
from electrum.bitcoin import COIN, sha256 from electrum.bitcoin import COIN, sha256
from electrum.util import bh2u, set_verbosity from electrum.util import bh2u, set_verbosity, create_and_start_event_loop
from electrum.lnpeer import Peer from electrum.lnpeer import Peer
from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey from electrum.lnutil import LNPeerAddr, Keypair, privkey_to_pubkey
from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving from electrum.lnutil import LightningPeerConnectionClosed, RemoteMisbehaving
@ -21,6 +20,7 @@ from electrum.lnworker import LNWorker
from electrum.lnmsg import encode_msg, decode_msg from electrum.lnmsg import encode_msg, decode_msg
from .test_lnchannel import create_test_channels from .test_lnchannel import create_test_channels
from . import SequentialTestCase
def keypair(): def keypair():
priv = ECPrivkey.generate_random_key().get_secret_bytes() priv = ECPrivkey.generate_random_key().get_secret_bytes()
@ -37,12 +37,12 @@ class MockNetwork:
def __init__(self, tx_queue): def __init__(self, tx_queue):
self.callbacks = defaultdict(list) self.callbacks = defaultdict(list)
self.lnwatcher = None self.lnwatcher = None
self.interface = None
user_config = {} user_config = {}
user_dir = tempfile.mkdtemp(prefix="electrum-lnpeer-test-") user_dir = tempfile.mkdtemp(prefix="electrum-lnpeer-test-")
self.config = simple_config.SimpleConfig(user_config, read_user_dir_function=lambda: user_dir) self.config = simple_config.SimpleConfig(user_config, read_user_dir_function=lambda: user_dir)
self.asyncio_loop = asyncio.get_event_loop() self.asyncio_loop = asyncio.get_event_loop()
self.channel_db = ChannelDB(self) self.channel_db = ChannelDB(self)
self.interface = None
self.path_finder = LNPathFinder(self.channel_db) self.path_finder = LNPathFinder(self.channel_db)
self.tx_queue = tx_queue self.tx_queue = tx_queue
@ -159,14 +159,23 @@ def transport_pair(name1, name2):
t2.other_mock_transport = t1 t2.other_mock_transport = t1
return t1, t2 return t1, t2
class TestPeer(unittest.TestCase): class TestPeer(SequentialTestCase):
@staticmethod
def setUpClass(): @classmethod
def setUpClass(cls):
super().setUpClass()
set_verbosity(True) set_verbosity(True)
def setUp(self): def setUp(self):
super().setUp()
self.asyncio_loop, self._stop_loop, self._loop_thread = create_and_start_event_loop()
self.alice_channel, self.bob_channel = create_test_channels() self.alice_channel, self.bob_channel = create_test_channels()
def tearDown(self):
super().tearDown()
self.asyncio_loop.call_soon_threadsafe(self._stop_loop.set_result, 1)
self._loop_thread.join(timeout=1)
def test_require_data_loss_protect(self): def test_require_data_loss_protect(self):
mock_lnworker = MockLNWorker(keypair(), keypair(), self.alice_channel, tx_queue=None) mock_lnworker = MockLNWorker(keypair(), keypair(), self.alice_channel, tx_queue=None)
mock_transport = NoFeaturesTransport('') mock_transport = NoFeaturesTransport('')
@ -232,8 +241,10 @@ class TestPeer(unittest.TestCase):
self.assertEqual(await fut, 'Payment received') self.assertEqual(await fut, 'Payment received')
gath.cancel() gath.cancel()
gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop()) gath = asyncio.gather(pay(), p1._message_loop(), p2._message_loop())
async def f():
await gath
with self.assertRaises(asyncio.CancelledError): with self.assertRaises(asyncio.CancelledError):
run(gath) run(f())
def test_channel_usage_after_closing(self): def test_channel_usage_after_closing(self):
p1, p2, w1, w2, q1, q2 = self.prepare_peers() p1, p2, w1, w2, q1, q2 = self.prepare_peers()
@ -253,8 +264,10 @@ class TestPeer(unittest.TestCase):
peer = w1.peers[route[0].node_id] peer = w1.peers[route[0].node_id]
# AssertionError is ok since we shouldn't use old routes, and the # AssertionError is ok since we shouldn't use old routes, and the
# route finding should fail when channel is closed # route finding should fail when channel is closed
async def f():
await asyncio.gather(w1._pay_to_route(route, addr, pay_req), p1._message_loop(), p2._message_loop())
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
run(asyncio.gather(w1._pay_to_route(route, addr, pay_req), p1._message_loop(), p2._message_loop())) run(f())
def run(coro): def run(coro):
return asyncio.get_event_loop().run_until_complete(coro) return asyncio.run_coroutine_threadsafe(coro, loop=asyncio.get_event_loop()).result()

11
electrum/tests/test_lnrouter.py

@ -3,7 +3,7 @@ import tempfile
import shutil import shutil
import asyncio import asyncio
from electrum.util import bh2u, bfh from electrum.util import bh2u, bfh, create_and_start_event_loop
from electrum.lnonion import (OnionHopsDataSingle, new_onion_packet, OnionPerHop, from electrum.lnonion import (OnionHopsDataSingle, new_onion_packet, OnionPerHop,
process_onion_packet, _decode_onion_error, decode_onion_error, process_onion_packet, _decode_onion_error, decode_onion_error,
OnionFailureCode) OnionFailureCode)
@ -34,11 +34,20 @@ class Test_LNRouter(TestCaseForTestnet):
cls.electrum_path = tempfile.mkdtemp() cls.electrum_path = tempfile.mkdtemp()
cls.config = SimpleConfig({'electrum_path': cls.electrum_path}) cls.config = SimpleConfig({'electrum_path': cls.electrum_path})
def setUp(self):
super().setUp()
self.asyncio_loop, self._stop_loop, self._loop_thread = create_and_start_event_loop()
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
super().tearDownClass() super().tearDownClass()
shutil.rmtree(cls.electrum_path) shutil.rmtree(cls.electrum_path)
def tearDown(self):
super().tearDown()
self.asyncio_loop.call_soon_threadsafe(self._stop_loop.set_result, 1)
self._loop_thread.join(timeout=1)
def test_find_path_for_payment(self): def test_find_path_for_payment(self):
class fake_network: class fake_network:
config = self.config config = self.config

Loading…
Cancel
Save