# Unit tests for electrum.lntransport: transport handshake, message exchange, and connection-string parsing.
import asyncio |
|
|
|
import electrum_ecc as ecc |
|
|
|
from electrum import util |
|
from electrum.lntransport import LNPeerAddr, LNResponderTransport, LNTransport, extract_nodeid, split_host_port, ConnStringFormatError |
|
from electrum.util import OldTaskGroup |
|
|
|
from . import ElectrumTestCase |
|
from .test_bitcoin import needs_test_with_all_chacha20_implementations |
|
|
|
|
|
class TestLNTransport(ElectrumTestCase):
    """Tests for the LN transport handshake and the connection-string helpers."""

    @needs_test_with_all_chacha20_implementations
    async def test_responder(self):
        """Drive the responder side of the handshake against canned peer messages.

        The stub reader hands the transport a 50-byte message followed by a
        66-byte message; the stub writer insists on exactly one 50-byte write.
        The test passes when ``handshake()`` completes without any stub
        assertion firing.
        """
        # NOTE(review): the fixed keys and messages look like the BOLT #8
        # responder test vectors - confirm against the spec.
        # local static
        static_privkey = bytes.fromhex('2121212121212121212121212121212121212121212121212121212121212121')
        # ephemeral
        ephemeral_privkey = bytes.fromhex('2222222222222222222222222222222222222222222222222222222222222222')

        class Writer:
            """Stub stream writer: accepts a single write, which must be 50 bytes."""

            def __init__(self):
                self.state = 0

            def write(self, data):
                writes_so_far = self.state
                assert writes_so_far == 0
                self.state = writes_so_far + 1
                assert len(data) == 50

        class Reader:
            """Stub stream reader that replays two canned inbound messages in order."""

            # One (requested size, payload) entry per successive read() call.
            script = (
                (50, bytes.fromhex('00036360e856310ce5d294e8be33fc807077dc56ac80d95d9cd4ddbd21325eff73f70df6086551151f58b8afe6c195782c6a')),
                (66, bytes.fromhex('00b9e3a702e93e3a9948c2ed6e5fd7590a6e1c3a0344cfc9d5b57357049aa22355361aa02e55a8fc28fef5bd6d71ad0c38228dc68b1c466263b47fdf31e560e139ba')),
            )

            def __init__(self):
                self.state = 0

            async def read(self, num_bytes):
                step = self.state
                # Only two reads are scripted; a third one is a test failure.
                assert step in (0, 1)
                self.state = step + 1
                expected_size, payload = self.script[step]
                assert num_bytes == expected_size
                return payload

        transport = LNResponderTransport(static_privkey, Reader(), Writer())
        await transport.handshake(epriv=ephemeral_privkey)

    @needs_test_with_all_chacha20_implementations
    async def test_loop(self):
        """Handshake and exchange messages both ways over a loopback TCP connection.

        A local asyncio server plays the responder and an ``LNTransport``
        plays the initiator. Each side sends three messages (one of them
        carrying 25600 bytes of extra payload) and checks that it receives
        the other side's messages unchanged and in order.
        """
        # Set once the respective side has finished its whole exchange.
        responder_done = asyncio.Event()
        initiator_done = asyncio.Event()
        responder_key = ecc.ECPrivkey.generate_random_key()
        initiator_key = ecc.ECPrivkey.generate_random_key()
        messages_sent_by_client = [
            b'hello from client',
            b'long data from client ' + bytes(range(256)) * 100 + b'... client done',
            b'client is running out of things to say',
        ]
        messages_sent_by_server = [
            b'hello from server',
            b'hello2 from server',
            b'long data from server ' + bytes(range(256)) * 100 + b'... server done',
        ]

        async def read_messages(transport, expected_messages):
            """Check incoming messages against the expected list, stopping once all are seen."""
            total = len(expected_messages)
            received = 0
            async for incoming in transport.read_messages():
                self.assertEqual(expected_messages[received], incoming)
                received += 1
                if received == total:
                    break

        async def write_messages(transport, outgoing_messages):
            """Send each message, pausing briefly after every send."""
            for outgoing in outgoing_messages:
                transport.send_bytes(outgoing)
                await asyncio.sleep(0.01)

        async def exchange(transport, incoming, outgoing):
            """Run the reading and writing halves concurrently until both finish."""
            async with OldTaskGroup() as group:
                await group.spawn(read_messages(transport, incoming))
                await group.spawn(write_messages(transport, outgoing))

        async def cb(reader, writer):
            """Server-side connection handler: responder handshake, then message exchange."""
            responder = LNResponderTransport(responder_key.get_secret_bytes(), reader, writer)
            remote_pubkey = await responder.handshake()
            # The handshake result must be the initiator's public key.
            self.assertEqual(remote_pubkey, initiator_key.get_public_key_bytes())
            await exchange(responder, messages_sent_by_client, messages_sent_by_server)
            responder_done.set()

        async def connect(port: int):
            """Client side: connect to the local server, handshake, then exchange messages."""
            peer_addr = LNPeerAddr('127.0.0.1', port, responder_key.get_public_key_bytes())
            initiator = LNTransport(initiator_key.get_secret_bytes(), peer_addr, e_proxy=None)
            await initiator.handshake()
            await exchange(initiator, messages_sent_by_server, messages_sent_by_client)
            initiator_done.set()

        # port=None: the listening port is read back from the bound socket below.
        server = await asyncio.start_server(cb, '127.0.0.1', port=None)
        server_port = server.sockets[0].getsockname()[1]
        try:
            async with OldTaskGroup() as group:
                await group.spawn(connect(port=server_port))
                await group.spawn(responder_done.wait())
                await group.spawn(initiator_done.wait())
        finally:
            server.close()

    def test_split_host_port(self):
        """Check ``split_host_port`` on well-formed and malformed host[:port] strings."""
        well_formed = (
            ("[::1]:8000", ("::1", "8000")),
            ("[::1]", ("::1", "9735")),
            ("[2601:602:8800:9a:dc59:a4ff:fede:24a9]:9735", ("2601:602:8800:9a:dc59:a4ff:fede:24a9", "9735")),
            ("[2601:602:8800::a4ff:fede:24a9]:9735", ("2601:602:8800::a4ff:fede:24a9", "9735")),
            ("kæn.guru:8000", ("kæn.guru", "8000")),
            ("kæn.guru", ("kæn.guru", "9735")),
            ("127.0.0.1:8000", ("127.0.0.1", "8000")),
            ("127.0.0.1", ("127.0.0.1", "9735")),
            # accepted by getaddrinfo but not ipaddress.ip_address
            ("127.0.0:8000", ("127.0.0", "8000")),
            ("127.0.0", ("127.0.0", "9735")),
            ("electrum.org:8000", ("electrum.org", "8000")),
            ("electrum.org", ("electrum.org", "9735")),
        )
        for host_port, expected in well_formed:
            self.assertEqual(split_host_port(host_port), expected)

        malformed = ("electrum.org:8000:", "electrum.org:")
        for host_port in malformed:
            with self.assertRaises(ConnStringFormatError):
                split_host_port(host_port)

    def test_extract_nodeid(self):
        """Check ``extract_nodeid`` on pubkey, pubkey@host and pubkey@host:port strings."""
        pubkey1 = ecc.GENERATOR.get_public_key_bytes(compressed=True)
        pubkey_hex = pubkey1.hex()

        rejected = (
            "00" * 32 + "@localhost",  # only 32 bytes of hex before the '@'
            "00" * 33 + "@",           # nothing after the '@'
        )
        for conn_str in rejected:
            with self.assertRaises(ConnStringFormatError):
                extract_nodeid(conn_str)

        # pubkey + host
        self.assertEqual(extract_nodeid("00" * 33 + "@localhost"), (b"\x00" * 33, "localhost"))
        for host in ("11.22.33.44", "[2001:41d0:e:734::1]"):
            self.assertEqual(extract_nodeid(f"{pubkey_hex}@{host}"), (pubkey1, host))

        # pubkey + host + port
        for host_port in ("11.22.33.44:5555", "[2001:41d0:e:734::1]:8888"):
            self.assertEqual(extract_nodeid(f"{pubkey_hex}@{host_port}"), (pubkey1, host_port))

        # just pubkey
        self.assertEqual(extract_nodeid(pubkey_hex), (pubkey1, None))
|
|
|