You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
235 lines
9.3 KiB
235 lines
9.3 KiB
import os |
|
import csv |
|
import io |
|
from typing import Callable, Tuple, Any, Dict, List, Sequence, Union, Optional |
|
|
|
|
|
class MalformedMsg(Exception): |
|
pass |
|
|
|
|
|
class UnknownMsgFieldType(MalformedMsg): |
|
pass |
|
|
|
|
|
class UnexpectedEndOfStream(MalformedMsg): |
|
pass |
|
|
|
|
|
def _assert_can_read_at_least_n_bytes(fd: io.BytesIO, n: int) -> None: |
|
cur_pos = fd.tell() |
|
end_pos = fd.seek(0, io.SEEK_END) |
|
fd.seek(cur_pos) |
|
if end_pos - cur_pos < n: |
|
raise UnexpectedEndOfStream(f"cur_pos={cur_pos}. end_pos={end_pos}. wants to read: {n}") |
|
|
|
|
|
def _read_field(*, fd: io.BytesIO, field_type: str, count: int) -> Union[bytes, int]: |
|
if not fd: raise Exception() |
|
assert isinstance(count, int) and count >= 0, f"{count!r} must be non-neg int" |
|
if count == 0: |
|
return b"" |
|
type_len = None |
|
if field_type == 'byte': |
|
type_len = 1 |
|
elif field_type == 'u16': |
|
type_len = 2 |
|
assert count == 1, count |
|
_assert_can_read_at_least_n_bytes(fd, type_len) |
|
return int.from_bytes(fd.read(type_len), byteorder="big", signed=False) |
|
elif field_type == 'u32': |
|
type_len = 4 |
|
assert count == 1, count |
|
_assert_can_read_at_least_n_bytes(fd, type_len) |
|
return int.from_bytes(fd.read(type_len), byteorder="big", signed=False) |
|
elif field_type == 'u64': |
|
type_len = 8 |
|
assert count == 1, count |
|
_assert_can_read_at_least_n_bytes(fd, type_len) |
|
return int.from_bytes(fd.read(type_len), byteorder="big", signed=False) |
|
# TODO tu16/tu32/tu64 |
|
elif field_type == 'chain_hash': |
|
type_len = 32 |
|
elif field_type == 'channel_id': |
|
type_len = 32 |
|
elif field_type == 'sha256': |
|
type_len = 32 |
|
elif field_type == 'signature': |
|
type_len = 64 |
|
elif field_type == 'point': |
|
type_len = 33 |
|
elif field_type == 'short_channel_id': |
|
type_len = 8 |
|
if type_len is None: |
|
raise UnknownMsgFieldType(f"unexpected field type: {field_type!r}") |
|
total_len = count * type_len |
|
_assert_can_read_at_least_n_bytes(fd, total_len) |
|
return fd.read(total_len) |
|
|
|
|
|
def _write_field(*, fd: io.BytesIO, field_type: str, count: int, |
|
value: Union[bytes, int]) -> None: |
|
if not fd: raise Exception() |
|
assert isinstance(count, int) and count >= 0, f"{count!r} must be non-neg int" |
|
if count == 0: |
|
return |
|
type_len = None |
|
if field_type == 'byte': |
|
type_len = 1 |
|
elif field_type == 'u16': |
|
type_len = 2 |
|
elif field_type == 'u32': |
|
type_len = 4 |
|
elif field_type == 'u64': |
|
type_len = 8 |
|
# TODO tu16/tu32/tu64 |
|
elif field_type == 'chain_hash': |
|
type_len = 32 |
|
elif field_type == 'channel_id': |
|
type_len = 32 |
|
elif field_type == 'sha256': |
|
type_len = 32 |
|
elif field_type == 'signature': |
|
type_len = 64 |
|
elif field_type == 'point': |
|
type_len = 33 |
|
elif field_type == 'short_channel_id': |
|
type_len = 8 |
|
if type_len is None: |
|
raise UnknownMsgFieldType(f"unexpected fundamental type: {field_type!r}") |
|
total_len = count * type_len |
|
if isinstance(value, int) and (count == 1 or field_type == 'byte'): |
|
value = int.to_bytes(value, length=total_len, byteorder="big", signed=False) |
|
if not isinstance(value, (bytes, bytearray)): |
|
raise Exception(f"can only write bytes into fd. got: {value!r}") |
|
if total_len != len(value): |
|
raise Exception(f"unexpected field size. expected: {total_len}, got {len(value)}") |
|
nbytes_written = fd.write(value) |
|
if nbytes_written != len(value): |
|
raise Exception(f"tried to write {len(value)} bytes, but only wrote {nbytes_written}!?") |
|
|
|
|
|
class LNSerializer: |
|
def __init__(self): |
|
self.msg_scheme_from_type = {} # type: Dict[bytes, List[Sequence[str]]] |
|
self.msg_type_from_name = {} # type: Dict[str, bytes] |
|
path = os.path.join(os.path.dirname(__file__), "lnwire", "peer_wire.csv") |
|
with open(path, newline='') as f: |
|
csvreader = csv.reader(f) |
|
for row in csvreader: |
|
#print(f">>> {row!r}") |
|
if row[0] == "msgtype": |
|
msg_type_name = row[1] |
|
msg_type_int = int(row[2]) |
|
msg_type_bytes = msg_type_int.to_bytes(2, 'big') |
|
assert msg_type_bytes not in self.msg_scheme_from_type, f"type collision? for {msg_type_name}" |
|
assert msg_type_name not in self.msg_type_from_name, f"type collision? for {msg_type_name}" |
|
row[2] = msg_type_int |
|
self.msg_scheme_from_type[msg_type_bytes] = [tuple(row)] |
|
self.msg_type_from_name[msg_type_name] = msg_type_bytes |
|
elif row[0] == "msgdata": |
|
assert msg_type_name == row[1] |
|
self.msg_scheme_from_type[msg_type_bytes].append(tuple(row)) |
|
else: |
|
pass # TODO |
|
|
|
def encode_msg(self, msg_type: str, **kwargs) -> bytes: |
|
""" |
|
Encode kwargs into a Lightning message (bytes) |
|
of the type given in the msg_type string |
|
""" |
|
#print(f">>> encode_msg. msg_type={msg_type}, payload={kwargs!r}") |
|
msg_type_bytes = self.msg_type_from_name[msg_type] |
|
scheme = self.msg_scheme_from_type[msg_type_bytes] |
|
with io.BytesIO() as fd: |
|
fd.write(msg_type_bytes) |
|
for row in scheme: |
|
if row[0] == "msgtype": |
|
pass |
|
elif row[0] == "msgdata": |
|
field_name = row[2] |
|
field_type = row[3] |
|
field_count_str = row[4] |
|
#print(f">>> encode_msg. msgdata. field_name={field_name!r}. field_type={field_type!r}. field_count_str={field_count_str!r}") |
|
if field_count_str == "": |
|
field_count = 1 |
|
else: |
|
try: |
|
field_count = int(field_count_str) |
|
except ValueError: |
|
field_count = kwargs[field_count_str] |
|
if isinstance(field_count, (bytes, bytearray)): |
|
field_count = int.from_bytes(field_count, byteorder="big") |
|
assert isinstance(field_count, int) |
|
try: |
|
field_value = kwargs[field_name] |
|
except KeyError: |
|
if len(row) > 5: |
|
break # optional feature field not present |
|
else: |
|
field_value = 0 # default mandatory fields to zero |
|
#print(f">>> encode_msg. writing field: {field_name}. value={field_value!r}. field_type={field_type!r}. count={field_count!r}") |
|
try: |
|
_write_field(fd=fd, |
|
field_type=field_type, |
|
count=field_count, |
|
value=field_value) |
|
#print(f">>> encode_msg. so far: {fd.getvalue().hex()}") |
|
except UnknownMsgFieldType as e: |
|
pass # TODO |
|
else: |
|
pass # TODO |
|
return fd.getvalue() |
|
|
|
def decode_msg(self, data: bytes) -> Tuple[str, dict]: |
|
""" |
|
Decode Lightning message by reading the first |
|
two bytes to determine message type. |
|
|
|
Returns message type string and parsed message contents dict |
|
""" |
|
#print(f"decode_msg >>> {data.hex()}") |
|
assert len(data) >= 2 |
|
msg_type_bytes = data[:2] |
|
msg_type_int = int.from_bytes(msg_type_bytes, byteorder="big", signed=False) |
|
scheme = self.msg_scheme_from_type[msg_type_bytes] |
|
assert scheme[0][2] == msg_type_int |
|
msg_type_name = scheme[0][1] |
|
parsed = {} |
|
with io.BytesIO(data[2:]) as fd: |
|
for row in scheme: |
|
#print(f"row: {row!r}") |
|
if row[0] == "msgtype": |
|
pass |
|
elif row[0] == "msgdata": |
|
field_name = row[2] |
|
field_type = row[3] |
|
field_count_str = row[4] |
|
if field_count_str == "": |
|
field_count = 1 |
|
else: |
|
try: |
|
field_count = int(field_count_str) |
|
except ValueError: |
|
field_count = parsed[field_count_str] |
|
assert isinstance(field_count, int) |
|
#print(f">> count={field_count}. parsed={parsed}") |
|
try: |
|
parsed[field_name] = _read_field(fd=fd, |
|
field_type=field_type, |
|
count=field_count) |
|
except UnknownMsgFieldType as e: |
|
pass # TODO |
|
except UnexpectedEndOfStream as e: |
|
if len(row) > 5: |
|
break # optional feature field not present |
|
else: |
|
raise |
|
else: |
|
pass # TODO |
|
return msg_type_name, parsed |
|
|
|
|
|
_inst = LNSerializer() |
|
encode_msg = _inst.encode_msg |
|
decode_msg = _inst.decode_msg
|
|
|