You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

99 lines
4.2 KiB

import asyncio
from electrum import util
from electrum.ecc import ECPrivkey
from electrum.lnutil import LNPeerAddr
from electrum.lntransport import LNResponderTransport, LNTransport
from electrum.util import OldTaskGroup
from . import ElectrumTestCase
from .test_bitcoin import needs_test_with_all_chacha20_implementations
class TestLNTransport(ElectrumTestCase):
    """Handshake and message round-trip tests for the Lightning transports."""

    @needs_test_with_all_chacha20_implementations
    async def test_responder(self):
        """Run the responder handshake against scripted reader/writer stubs.

        The stubs feed the handshake two fixed incoming messages and assert
        the size of every read and write it performs.
        NOTE(review): the keys and messages look like the BOLT #8 responder
        test vectors -- confirm against the spec.
        """
        # local static
        ls_priv = bytes.fromhex('2121212121212121212121212121212121212121212121212121212121212121')
        # ephemeral
        e_priv = bytes.fromhex('2222222222222222222222222222222222222222222222222222222222222222')

        class Writer:
            """Writer stub that accepts exactly one write of 50 bytes."""

            def __init__(self):
                self.state = 0  # number of writes accepted so far

            def write(self, data):
                assert self.state == 0
                self.state += 1
                assert len(data) == 50

        class Reader:
            """Reader stub that serves two canned messages, in order."""

            def __init__(self):
                self.state = 0  # number of reads served so far

            async def read(self, num_bytes):
                assert self.state in (0, 1)
                step = self.state
                self.state += 1
                if step == 0:
                    # first read: 50 bytes requested
                    assert num_bytes == 50
                    return bytes.fromhex('00036360e856310ce5d294e8be33fc807077dc56ac80d95d9cd4ddbd21325eff73f70df6086551151f58b8afe6c195782c6a')
                # second read: 66 bytes requested
                assert num_bytes == 66
                return bytes.fromhex('00b9e3a702e93e3a9948c2ed6e5fd7590a6e1c3a0344cfc9d5b57357049aa22355361aa02e55a8fc28fef5bd6d71ad0c38228dc68b1c466263b47fdf31e560e139ba')

        transport = LNResponderTransport(ls_priv, Reader(), Writer())
        await transport.handshake(epriv=e_priv)

    @needs_test_with_all_chacha20_implementations
    async def test_loop(self):
        """Exchange messages in both directions over a local TCP connection.

        A responder transport is served on 127.0.0.1 and an initiator
        transport connects to it.  After the handshake each side sends its
        own list of messages while checking that it receives the other
        side's list, in order.  The test finishes once both sides are done.
        """
        # Set when the server-side callback has finished (or failed).
        responder_shaked = asyncio.Event()
        # Set when the connecting (initiator) side has finished.
        server_shaked = asyncio.Event()
        responder_key = ECPrivkey.generate_random_key()
        initiator_key = ECPrivkey.generate_random_key()
        # The server callback is scheduled by the stream server as its own
        # task, so an exception raised there would not reach this coroutine.
        # Failures are collected here and re-raised from within the test.
        responder_errors = []
        messages_sent_by_client = [
            b'hello from client',
            b'long data from client ' + bytes(range(256)) * 100 + b'... client done',
            b'client is running out of things to say',
        ]
        messages_sent_by_server = [
            b'hello from server',
            b'hello2 from server',
            b'long data from server ' + bytes(range(256)) * 100 + b'... server done',
        ]

        async def read_messages(transport, expected_messages):
            """Check that `transport` yields `expected_messages`, in order."""
            received = 0
            async for msg in transport.read_messages():
                self.assertEqual(expected_messages[received], msg)
                received += 1
                if received == len(expected_messages):
                    return
            # Only reached if the message stream finished early: do not let
            # a short read pass silently.
            self.fail(
                f"message stream ended after {received} of "
                f"{len(expected_messages)} expected messages")

        async def write_messages(transport, expected_messages):
            """Send every message over `transport`, pausing briefly after each."""
            for msg in expected_messages:
                transport.send_bytes(msg)
                await asyncio.sleep(0.01)

        async def cb(reader, writer):
            """Server side: handshake as responder, then exchange messages."""
            try:
                t = LNResponderTransport(responder_key.get_secret_bytes(), reader, writer)
                self.assertEqual(await t.handshake(), initiator_key.get_public_key_bytes())
                async with OldTaskGroup() as group:
                    await group.spawn(read_messages(t, messages_sent_by_client))
                    await group.spawn(write_messages(t, messages_sent_by_server))
            except Exception as exc:
                # Not swallowed: wait_for_responder() re-raises it in the test.
                responder_errors.append(exc)
            # Always signal completion so the test never waits forever on a
            # failed server side.
            responder_shaked.set()

        async def wait_for_responder():
            """Wait for the server side and surface any failure it recorded."""
            await responder_shaked.wait()
            if responder_errors:
                raise responder_errors[0]

        async def connect(port: int):
            """Client side: handshake as initiator, then exchange messages."""
            peer_addr = LNPeerAddr('127.0.0.1', port, responder_key.get_public_key_bytes())
            t = LNTransport(initiator_key.get_secret_bytes(), peer_addr, proxy=None)
            await t.handshake()
            async with OldTaskGroup() as group:
                await group.spawn(read_messages(t, messages_sent_by_server))
                await group.spawn(write_messages(t, messages_sent_by_client))
            server_shaked.set()

        server = await asyncio.start_server(cb, '127.0.0.1', port=None)
        # No fixed port was requested; read the actual one back from the
        # listening socket.
        server_port = server.sockets[0].getsockname()[1]
        try:
            async with OldTaskGroup() as group:
                await group.spawn(connect(port=server_port))
                await group.spawn(wait_for_responder())
                await group.spawn(server_shaked.wait())
        finally:
            server.close()