From cf2ed509b4df1cf722004801543d804fc989696c Mon Sep 17 00:00:00 2001 From: SomberNight Date: Wed, 24 Apr 2024 14:10:01 +0000 Subject: [PATCH] dependencies: remove bitstring - `bitstring` started depending on `bitarray` in version 4.1 [0] - that would mean one additional dependency for us (from yet another maintainer), which is not even pure python - we only use bitstring for bolt11-parsing - hence this PR rewrites the bolt11-parsing and removes `bitstring` as dependency - note: I benchmarked lndecode using [1], and the new code performs better, taking around 80% time needed for old code (when using bitstring 3.1.9, pure python). Though the variance is quite large in both cases. [0]: https://github.com/scott-griffiths/bitstring/blob/95ee533ee4040b4480da1ead548eab2459e8e573/release_notes.txt#L108 [1]: https://github.com/spesmilo/electrum/commit/d7597d96d0c336838adb32e3e175d3ea6f9763e8 --- contrib/deterministic-build/requirements.txt | 2 - contrib/requirements/requirements.txt | 1 - electrum/lnaddr.py | 277 ++++++++++--------- electrum/segwit_addr.py | 8 +- electrum/trampoline.py | 56 ++-- tests/test_bolt11.py | 16 +- 6 files changed, 187 insertions(+), 173 deletions(-) diff --git a/contrib/deterministic-build/requirements.txt b/contrib/deterministic-build/requirements.txt index ba5f99249..686da50c8 100644 --- a/contrib/deterministic-build/requirements.txt +++ b/contrib/deterministic-build/requirements.txt @@ -10,8 +10,6 @@ async-timeout==4.0.3 \ --hash=sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f attrs==22.1.0 \ --hash=sha256:29adc2665447e5191d0e7c568fde78b21f9672d344281d0c6e1ab085429b22b6 -bitstring==3.1.9 \ - --hash=sha256:a5848a3f63111785224dca8bb4c0a75b62ecdef56a042c8d6be74b16f7e860e7 certifi==2024.2.2 \ --hash=sha256:0569859f95fc761b18b45ef421b1290a0f65f147e92a1e5eb3e635f9a5e4e66f dnspython==2.2.1 \ diff --git a/contrib/requirements/requirements.txt b/contrib/requirements/requirements.txt index db3abb134..484b2c592 100644 --- a/contrib/requirements/requirements.txt +++ b/contrib/requirements/requirements.txt @@ -5,7 +5,6 @@ aiorpcx>=0.22.0,<0.24 aiohttp>=3.3.0,<4.0.0 aiohttp_socks>=0.8.4 certifi -bitstring attrs>=20.1.0 jsonpatch diff --git a/electrum/lnaddr.py b/electrum/lnaddr.py index a27fc696b..e3f94f1c2 100644 --- a/electrum/lnaddr.py +++ b/electrum/lnaddr.py @@ -1,18 +1,17 @@ #! /usr/bin/env python3 # This was forked from https://github.com/rustyrussell/lightning-payencode/tree/acc16ec13a3fa1dc16c07af6ec67c261bd8aff23 +import io import re import time from hashlib import sha256 from binascii import hexlify from decimal import Decimal -from typing import Optional, TYPE_CHECKING, Type, Dict, Any - +from typing import Optional, TYPE_CHECKING, Type, Dict, Any, Union, Sequence, List, Tuple import random -import bitstring from .bitcoin import hash160_to_b58_address, b58_address_to_hash160, TOTAL_COIN_SUPPLY_LIMIT_IN_BTC -from .segwit_addr import bech32_encode, bech32_decode, CHARSET +from .segwit_addr import bech32_encode, bech32_decode, CHARSET, CHARSET_INVERSE, convertbits from . import segwit_addr from . import constants from .constants import AbstractNet @@ -75,25 +74,9 @@ def unshorten_amount(amount) -> Decimal: else: return Decimal(amount) -_INT_TO_BINSTR = {a: '0' * (5-len(bin(a)[2:])) + bin(a)[2:] for a in range(32)} - -# Bech32 spits out array of 5-bit values. Shim here. -def u5_to_bitarray(arr): - b = ''.join(_INT_TO_BINSTR[a] for a in arr) - return bitstring.BitArray(bin=b) - -def bitarray_to_u5(barr): - assert barr.len % 5 == 0 - ret = [] - s = bitstring.ConstBitStream(barr) - while s.pos != s.len: - ret.append(s.read(5).uint) - return ret - -def encode_fallback(fallback: str, net: Type[AbstractNet]): - """ Encode all supported fallback addresses. - """ +def encode_fallback_addr(fallback: str, net: Type[AbstractNet]) -> Sequence[int]: + """Encode all supported fallback addresses.""" wver, wprog_ints = segwit_addr.decode_segwit_address(net.SEGWIT_HRP, fallback) if wver is not None: wprog = bytes(wprog_ints) @@ -106,20 +89,20 @@ def encode_fallback(fallback: str, net: Type[AbstractNet]): else: raise LnEncodeException(f"Unknown address type {addrtype} for {net}") wprog = addr - return tagged('f', bitstring.pack("uint:5", wver) + wprog) + data5 = convertbits(wprog, 8, 5) + assert data5 is not None + return tagged5('f', [wver] + list(data5)) -def parse_fallback(fallback, net: Type[AbstractNet]): - wver = fallback[0:5].uint +def parse_fallback_addr(data5: Sequence[int], net: Type[AbstractNet]) -> Optional[str]: + wver = data5[0] + data8 = bytes(convertbits(data5[1:], 5, 8, False)) if wver == 17: - addr = hash160_to_b58_address(fallback[5:].tobytes(), net.ADDRTYPE_P2PKH) + addr = hash160_to_b58_address(data8, net.ADDRTYPE_P2PKH) elif wver == 18: - addr = hash160_to_b58_address(fallback[5:].tobytes(), net.ADDRTYPE_P2SH) + addr = hash160_to_b58_address(data8, net.ADDRTYPE_P2SH) elif wver <= 16: - witprog = fallback[5:] # cut witver - witprog = witprog[:len(witprog) // 8 * 8] # can only be full bytes - witprog = witprog.tobytes() - addr = segwit_addr.encode_segwit_address(net.SEGWIT_HRP, wver, witprog) + addr = segwit_addr.encode_segwit_address(net.SEGWIT_HRP, wver, data8) else: return None return addr @@ -128,47 +111,52 @@ def parse_fallback(fallback, net: Type[AbstractNet]): BOLT11_HRP_INV_DICT = {net.BOLT11_HRP: net for net in constants.NETS_LIST} -# Tagged field containing BitArray -def tagged(char, l): - # Tagged fields need to be zero-padded to 5 bits. - while l.len % 5 != 0: - l.append('0b0') - return bitstring.pack("uint:5, uint:5, uint:5", - CHARSET.find(char), - (l.len / 5) / 32, (l.len / 5) % 32) + l +def tagged5(char: str, data5: Sequence[int]) -> Sequence[int]: + assert len(data5) < (1 << 10) + return [CHARSET_INVERSE[char], len(data5) >> 5, len(data5) & 31] + data5 + + +def tagged8(char: str, data8: Sequence[int]) -> Sequence[int]: + return tagged5(char, convertbits(data8, 8, 5)) -# Tagged field containing bytes -def tagged_bytes(char, l): - return tagged(char, bitstring.BitArray(l)) -def trim_to_min_length(bits): - """Ensures 'bits' have min number of leading zeroes. - Assumes 'bits' is big-endian, and that it needs to be encoded in 5 bit blocks. +def int_to_data5(val: int, *, bit_len: int = None) -> Sequence[int]: + """Represent big-endian number with as many 0-31 values as it takes. + If `bit_len` is set, use exactly bit_len//5 values (left-padded with zeroes). """ - bits = bits[:] # copy - # make sure we can be split into 5 bit blocks - while bits.len % 5 != 0: - bits.prepend('0b0') - # Get minimal length by trimming leading 5 bits at a time. - while bits.startswith('0b00000'): - if len(bits) == 5: - break # v == 0 - bits = bits[5:] - return bits - -# Discard trailing bits, convert to bytes. -def trim_to_bytes(barr): - # Adds a byte if necessary. - b = barr.tobytes() - if barr.len % 8 != 0: - return b[:-1] - return b - -# Try to pull out tagged data: returns tag, tagged data and remainder. -def pull_tagged(stream): - tag = stream.read(5).uint - length = stream.read(5).uint * 32 + stream.read(5).uint - return (CHARSET[tag], stream.read(length * 5), stream) + if bit_len is not None: + assert bit_len % 5 == 0, bit_len + if val.bit_length() > bit_len: + raise ValueError(f"{val=} too big for {bit_len=!r}") + ret = [] + while val != 0: + ret.append(val % 32) + val //= 32 + if bit_len is not None: + ret.extend([0] * (len(ret) - bit_len // 5)) + ret.reverse() + return ret + + +def int_from_data5(data5: Sequence[int]) -> int: + total = 0 + for v in data5: + total = 32 * total + v + return total + + +def pull_tagged(data5: bytearray) -> Tuple[str, Sequence[int]]: + """Try to pull out tagged data: returns tag, tagged data. Mutates data in-place.""" + if len(data5) < 3: + raise ValueError("Truncated field") + length = data5[1] * 32 + data5[2] + if length > len(data5) - 3: + raise ValueError( + "Truncated {} field: expected {} values".format(CHARSET[data5[0]], length)) + ret = (CHARSET[data5[0]], data5[3:3+length]) + del data5[:3 + length] # much faster than: data5=data5[offset:] + return ret + def lnencode(addr: 'LnAddr', privkey) -> str: if addr.amount: @@ -179,17 +167,17 @@ def lnencode(addr: 'LnAddr', privkey) -> str: hrp = 'ln' + amount # Start with the timestamp - data = bitstring.pack('uint:35', addr.date) + data5 = int_to_data5(addr.date, bit_len=35) tags_set = set() # Payment hash assert addr.paymenthash is not None - data += tagged_bytes('p', addr.paymenthash) + data5 += tagged8('p', addr.paymenthash) tags_set.add('p') if addr.payment_secret is not None: - data += tagged_bytes('s', addr.payment_secret) + data5 += tagged8('s', addr.payment_secret) tags_set.add('s') for k, v in addr.tags: @@ -202,39 +190,44 @@ def lnencode(addr: 'LnAddr', privkey) -> str: raise LnEncodeException("Duplicate '{}' tag".format(k)) if k == 'r': - route = bitstring.BitArray() + route = bytearray() for step in v: - pubkey, channel, feebase, feerate, cltv = step - route.append(bitstring.BitArray(pubkey) + bitstring.BitArray(channel) + bitstring.pack('intbe:32', feebase) + bitstring.pack('intbe:32', feerate) + bitstring.pack('intbe:16', cltv)) - data += tagged('r', route) + pubkey, scid, feebase, feerate, cltv = step + route += pubkey + route += scid + route += int.to_bytes(feebase, length=4, byteorder="big", signed=False) + route += int.to_bytes(feerate, length=4, byteorder="big", signed=False) + route += int.to_bytes(cltv, length=2, byteorder="big", signed=False) + data5 += tagged8('r', route) elif k == 't': pubkey, feebase, feerate, cltv = v - route = bitstring.BitArray(pubkey) + bitstring.pack('intbe:32', feebase) + bitstring.pack('intbe:32', feerate) + bitstring.pack('intbe:16', cltv) - data += tagged('t', route) + route = bytearray() + route += pubkey + route += int.to_bytes(feebase, length=4, byteorder="big", signed=False) + route += int.to_bytes(feerate, length=4, byteorder="big", signed=False) + route += int.to_bytes(cltv, length=2, byteorder="big", signed=False) + data5 += tagged8('t', route) elif k == 'f': if v is not None: - data += encode_fallback(v, addr.net) + data5 += encode_fallback_addr(v, addr.net) elif k == 'd': # truncate to max length: 1024*5 bits = 639 bytes - data += tagged_bytes('d', v.encode()[0:639]) + data5 += tagged8('d', v.encode()[0:639]) elif k == 'x': - expirybits = bitstring.pack('intbe:64', v) - expirybits = trim_to_min_length(expirybits) - data += tagged('x', expirybits) + expirybits = int_to_data5(v) + data5 += tagged5('x', expirybits) elif k == 'h': - data += tagged_bytes('h', sha256(v.encode('utf-8')).digest()) + data5 += tagged8('h', sha256(v.encode('utf-8')).digest()) elif k == 'n': - data += tagged_bytes('n', v) + data5 += tagged8('n', v) elif k == 'c': - finalcltvbits = bitstring.pack('intbe:64', v) - finalcltvbits = trim_to_min_length(finalcltvbits) - data += tagged('c', finalcltvbits) + finalcltvbits = int_to_data5(v) + data5 += tagged5('c', finalcltvbits) elif k == '9': if v == 0: continue - feature_bits = bitstring.BitArray(uint=v, length=v.bit_length()) - feature_bits = trim_to_min_length(feature_bits) - data += tagged('9', feature_bits) + feature_bits = int_to_data5(v) + data5 += tagged5('9', feature_bits) else: # FIXME: Support unknown tags? raise LnEncodeException("Unknown tag {}".format(k)) @@ -251,15 +244,16 @@ def lnencode(addr: 'LnAddr', privkey) -> str: raise ValueError("Must include either 'd' or 'h'") # We actually sign the hrp, then data (padded to 8 bits with zeroes). - msg = hrp.encode("ascii") + data.tobytes() + msg = hrp.encode("ascii") + bytes(convertbits(data5, 5, 8)) msg32 = sha256(msg).digest() privkey = ecc.ECPrivkey(privkey) sig = privkey.ecdsa_sign_recoverable(msg32, is_compressed=False) recovery_flag = bytes([sig[0] - 27]) sig = bytes(sig[1:]) + recovery_flag - data += sig + sig = bytes(convertbits(sig, 8, 5, False)) + data5 += sig - return bech32_encode(segwit_addr.Encoding.BECH32, hrp, bitarray_to_u5(data)) + return bech32_encode(segwit_addr.Encoding.BECH32, hrp, data5) class LnAddr(object): @@ -393,6 +387,7 @@ class SerializableKey: def serialize(self): return self.pubkey.get_public_key_bytes(True) + def lndecode(invoice: str, *, verbose=False, net=None) -> LnAddr: """Parses a string into an LnAddr object. Can raise LnDecodeException or IncompatibleOrInsaneFeatures. @@ -401,7 +396,7 @@ def lndecode(invoice: str, *, verbose=False, net=None) -> LnAddr: net = constants.net decoded_bech32 = bech32_decode(invoice, ignore_long_length=True) hrp = decoded_bech32.hrp - data = decoded_bech32.data + data5 = decoded_bech32.data # "5" as in list of 5-bit integers if decoded_bech32.encoding is None: raise LnDecodeException("Bad bech32 checksum") if decoded_bech32.encoding != segwit_addr.Encoding.BECH32: @@ -416,13 +411,12 @@ def lndecode(invoice: str, *, verbose=False, net=None) -> LnAddr: if not hrp[2:].startswith(net.BOLT11_HRP): raise LnDecodeException(f"Wrong Lightning invoice HRP {hrp[2:]}, should be {net.BOLT11_HRP}") - data = u5_to_bitarray(data) - # Final signature 65 bytes, split it off. - if len(data) < 65*8: + if len(data5) < 65*8//5: raise LnDecodeException("Too short to contain signature") - sigdecoded = data[-65*8:].tobytes() - data = bitstring.ConstBitStream(data[:-65*8]) + sigdecoded = bytes(convertbits(data5[-65*8//5:], 5, 8, False)) + data5 = data5[:-65*8//5] + data5_remaining = bytearray(data5) # note: bytearray is faster than list of ints addr = LnAddr() addr.pubkey = None @@ -439,17 +433,18 @@ def lndecode(invoice: str, *, verbose=False, net=None) -> LnAddr: if amountstr != '': addr.amount = unshorten_amount(amountstr) - addr.date = data.read(35).uint + addr.date = int_from_data5(data5_remaining[:7]) + data5_remaining = data5_remaining[7:] - while data.pos != data.len: - tag, tagdata, data = pull_tagged(data) + while data5_remaining: + tag, tagdata = pull_tagged(data5_remaining) # mutates arg # BOLT #11: # # A reader MUST skip over unknown fields, an `f` field with unknown # `version`, or a `p`, `h`, or `n` field which does not have # `data_length` 52, 52, or 53 respectively. - data_length = len(tagdata) / 5 + data_length = len(tagdata) if tag == 'r': # BOLT #11: @@ -462,24 +457,43 @@ def lndecode(invoice: str, *, verbose=False, net=None) -> LnAddr: # * `feebase` (32 bits, big-endian) # * `feerate` (32 bits, big-endian) # * `cltv_expiry_delta` (16 bits, big-endian) - route=[] - s = bitstring.ConstBitStream(tagdata) - while s.pos + 264 + 64 + 32 + 32 + 16 < s.len: - route.append((s.read(264).tobytes(), - s.read(64).tobytes(), - s.read(32).uintbe, - s.read(32).uintbe, - s.read(16).uintbe)) - addr.tags.append(('r',route)) + tagdata = convertbits(tagdata, 5, 8, False) + if not tagdata: + continue + route = [] + with io.BytesIO(bytes(tagdata)) as s: + while True: + pubkey = s.read(33) + scid = s.read(8) + feebase = s.read(4) + feerate = s.read(4) + cltv = s.read(2) + if len(cltv) != 2: + break # EOF + feebase = int.from_bytes(feebase, byteorder="big") + feerate = int.from_bytes(feerate, byteorder="big") + cltv = int.from_bytes(cltv, byteorder="big") + route.append((pubkey, scid, feebase, feerate, cltv)) + if route: + addr.tags.append(('r',route)) elif tag == 't': - s = bitstring.ConstBitStream(tagdata) - e = (s.read(264).tobytes(), - s.read(32).uintbe, - s.read(32).uintbe, - s.read(16).uintbe) - addr.tags.append(('t', e)) + tagdata = convertbits(tagdata, 5, 8, False) + if not tagdata: + continue + route = [] + with io.BytesIO(bytes(tagdata)) as s: + pubkey = s.read(33) + feebase = s.read(4) + feerate = s.read(4) + cltv = s.read(2) + if len(cltv) == 2: # no EOF + feebase = int.from_bytes(feebase, byteorder="big") + feerate = int.from_bytes(feerate, byteorder="big") + cltv = int.from_bytes(cltv, byteorder="big") + route.append((pubkey, feebase, feerate, cltv)) + addr.tags.append(('t', route)) elif tag == 'f': - fallback = parse_fallback(tagdata, addr.net) + fallback = parse_fallback_addr(tagdata, addr.net) if fallback: addr.tags.append(('f', fallback)) else: @@ -488,41 +502,41 @@ def lndecode(invoice: str, *, verbose=False, net=None) -> LnAddr: continue elif tag == 'd': - addr.tags.append(('d', trim_to_bytes(tagdata).decode('utf-8'))) + addr.tags.append(('d', bytes(convertbits(tagdata, 5, 8, False)).decode('utf-8'))) elif tag == 'h': if data_length != 52: addr.unknown_tags.append((tag, tagdata)) continue - addr.tags.append(('h', trim_to_bytes(tagdata))) + addr.tags.append(('h', bytes(convertbits(tagdata, 5, 8, False)))) elif tag == 'x': - addr.tags.append(('x', tagdata.uint)) + addr.tags.append(('x', int_from_data5(tagdata))) elif tag == 'p': if data_length != 52: addr.unknown_tags.append((tag, tagdata)) continue - addr.paymenthash = trim_to_bytes(tagdata) + addr.paymenthash = bytes(convertbits(tagdata, 5, 8, False)) elif tag == 's': if data_length != 52: addr.unknown_tags.append((tag, tagdata)) continue - addr.payment_secret = trim_to_bytes(tagdata) + addr.payment_secret = bytes(convertbits(tagdata, 5, 8, False)) elif tag == 'n': if data_length != 53: addr.unknown_tags.append((tag, tagdata)) continue - pubkeybytes = trim_to_bytes(tagdata) + pubkeybytes = bytes(convertbits(tagdata, 5, 8, False)) addr.pubkey = pubkeybytes elif tag == 'c': - addr.tags.append(('c', tagdata.uint)) + addr.tags.append(('c', int_from_data5(tagdata))) elif tag == '9': - features = tagdata.uint + features = int_from_data5(tagdata) addr.tags.append(('9', features)) # note: The features are not validated here in the parser, # instead, validation is done just before we try paying the invoice (in lnworker._check_invoice). @@ -536,16 +550,17 @@ def lndecode(invoice: str, *, verbose=False, net=None) -> LnAddr: print('hex of signature data (32 byte r, 32 byte s): {}' .format(hexlify(sigdecoded[0:64]))) print('recovery flag: {}'.format(sigdecoded[64])) + data8 = bytes(convertbits(data5, 5, 8, True)) print('hex of data for signing: {}' - .format(hexlify(hrp.encode("ascii") + data.tobytes()))) - print('SHA256 of above: {}'.format(sha256(hrp.encode("ascii") + data.tobytes()).hexdigest())) + .format(hexlify(hrp.encode("ascii") + data8))) + print('SHA256 of above: {}'.format(sha256(hrp.encode("ascii") + data8).hexdigest())) # BOLT #11: # # A reader MUST check that the `signature` is valid (see the `n` tagged # field specified below). addr.signature = sigdecoded[:65] - hrp_hash = sha256(hrp.encode("ascii") + data.tobytes()).digest() + hrp_hash = sha256(hrp.encode("ascii") + bytes(convertbits(data5, 5, 8, True))).digest() if addr.pubkey: # Specified by `n` # BOLT #11: # diff --git a/electrum/segwit_addr.py b/electrum/segwit_addr.py index 9d121c3b9..224947086 100644 --- a/electrum/segwit_addr.py +++ b/electrum/segwit_addr.py @@ -22,10 +22,10 @@ """Reference implementation for Bech32/Bech32m and segwit addresses.""" from enum import Enum -from typing import Tuple, Optional, Sequence, NamedTuple, List +from typing import Tuple, Optional, Sequence, NamedTuple, List, Mapping, Iterable CHARSET = "qpzry9x8gf2tvdw0s3jn54khce6mua7l" -_CHARSET_INVERSE = {c: i for (i, c) in enumerate(CHARSET)} +CHARSET_INVERSE = {c: i for (i, c) in enumerate(CHARSET)} # type: Mapping[str, int] BECH32_CONST = 1 BECH32M_CONST = 0x2bc830a3 @@ -99,7 +99,7 @@ def bech32_decode(bech: str, *, ignore_long_length=False) -> DecodedBech32: bech = bech_lower hrp = bech[:pos] try: - data = [_CHARSET_INVERSE[x] for x in bech[pos+1:]] + data = [CHARSET_INVERSE[x] for x in bech[pos + 1:]] except KeyError: return DecodedBech32(None, None, None) encoding = bech32_verify_checksum(hrp, data) @@ -108,7 +108,7 @@ def bech32_decode(bech: str, *, ignore_long_length=False) -> DecodedBech32: return DecodedBech32(encoding=encoding, hrp=hrp, data=data[:-6]) -def convertbits(data, frombits, tobits, pad=True): +def convertbits(data: Iterable[int], frombits: int, tobits: int, pad: bool = True) -> Optional[Sequence[int]]: """General power-of-2 base conversion.""" acc = 0 bits = 0 diff --git a/electrum/trampoline.py b/electrum/trampoline.py index f1dd92dc6..39a3cc905 100644 --- a/electrum/trampoline.py +++ b/electrum/trampoline.py @@ -1,8 +1,7 @@ +import io import os -import bitstring import random - -from typing import Mapping, DefaultDict, Tuple, Optional, Dict, List, Iterable, Sequence, Set +from typing import Mapping, DefaultDict, Tuple, Optional, Dict, List, Iterable, Sequence, Set, Any from .lnutil import LnFeatures, PaymentFeeBudget from .lnonion import calc_hops_data_for_payment, new_onion_packet, OnionPacket @@ -91,33 +90,38 @@ def trampolines_by_id(): def is_hardcoded_trampoline(node_id: bytes) -> bool: return node_id in trampolines_by_id() -def encode_routing_info(r_tags): - result = bitstring.BitArray() +def encode_routing_info(r_tags: Sequence[Sequence[Sequence[Any]]]) -> bytes: + result = bytearray() for route in r_tags: - result.append(bitstring.pack('uint:8', len(route))) + result += bytes([len(route)]) for step in route: pubkey, scid, feebase, feerate, cltv = step - result.append( - bitstring.BitArray(pubkey) \ - + bitstring.BitArray(scid)\ - + bitstring.pack('intbe:32', feebase)\ - + bitstring.pack('intbe:32', feerate)\ - + bitstring.pack('intbe:16', cltv)) - return result.tobytes() - -def decode_routing_info(s: bytes): - s = bitstring.BitArray(s) + result += pubkey + result += scid + result += int.to_bytes(feebase, length=4, byteorder="big", signed=False) + result += int.to_bytes(feerate, length=4, byteorder="big", signed=False) + result += int.to_bytes(cltv, length=2, byteorder="big", signed=False) + return bytes(result) + + +def decode_routing_info(rinfo: bytes) -> Sequence[Sequence[Sequence[Any]]]: + if not rinfo: + return [] r_tags = [] - n = 8*(33 + 8 + 4 + 4 + 2) - while s: - route = [] - length, s = s[0:8], s[8:] - length = length.unpack('uint:8')[0] - for i in range(length): - chunk, s = s[0:n], s[n:] - item = chunk.unpack('bytes:33, bytes:8, intbe:32, intbe:32, intbe:16') - route.append(item) - r_tags.append(route) + with io.BytesIO(bytes(rinfo)) as s: + while True: + route = [] + route_len = s.read(1) + if not route_len: + break + for step in range(route_len[0]): + pubkey = s.read(33) + scid = s.read(8) + feebase = int.from_bytes(s.read(4), byteorder="big") + feerate = int.from_bytes(s.read(4), byteorder="big") + cltv = int.from_bytes(s.read(2), byteorder="big") + route.append((pubkey, scid, feebase, feerate, cltv)) + r_tags.append(route) return r_tags diff --git a/tests/test_bolt11.py b/tests/test_bolt11.py index 433af331c..c4f756b1b 100644 --- a/tests/test_bolt11.py +++ b/tests/test_bolt11.py @@ -4,7 +4,7 @@ from binascii import unhexlify, hexlify import pprint import unittest -from electrum.lnaddr import shorten_amount, unshorten_amount, LnAddr, lnencode, lndecode, u5_to_bitarray, bitarray_to_u5 +from electrum.lnaddr import shorten_amount, unshorten_amount, LnAddr, lnencode, lndecode from electrum.segwit_addr import bech32_encode, bech32_decode from electrum import segwit_addr from electrum.lnutil import UnknownEvenFeatureBits, derive_payment_secret_from_payment_preimage, LnFeatures, IncompatibleLightningFeatures @@ -125,19 +125,17 @@ class TestBolt11(ElectrumTestCase): _, hrp, data = bech32_decode( lnencode(LnAddr(paymenthash=RHASH, payment_secret=PAYMENT_SECRET, amount=24, tags=[('d', ''), ('9', 33282)]), PRIVKEY), ignore_long_length=True) - databits = u5_to_bitarray(data) - databits.invert(-1) - lnaddr = lndecode(bech32_encode(segwit_addr.Encoding.BECH32, hrp, bitarray_to_u5(databits)), verbose=True) - assert lnaddr.pubkey.serialize() != PUBKEY + data[-1] ^= 1 + lnaddr = lndecode(bech32_encode(segwit_addr.Encoding.BECH32, hrp, data), verbose=True) + self.assertNotEqual(lnaddr.pubkey.serialize(), PUBKEY) # But not if we supply expliciy `n` specifier! _, hrp, data = bech32_decode( lnencode(LnAddr(paymenthash=RHASH, payment_secret=PAYMENT_SECRET, amount=24, tags=[('d', ''), ('n', PUBKEY), ('9', 33282)]), PRIVKEY), ignore_long_length=True) - databits = u5_to_bitarray(data) - databits.invert(-1) - lnaddr = lndecode(bech32_encode(segwit_addr.Encoding.BECH32, hrp, bitarray_to_u5(databits)), verbose=True) - assert lnaddr.pubkey.serialize() == PUBKEY + data[-1] ^= 1 + lnaddr = lndecode(bech32_encode(segwit_addr.Encoding.BECH32, hrp, data), verbose=True) + self.assertEqual(lnaddr.pubkey.serialize(), PUBKEY) def test_min_final_cltv_expiry_decoding(self): lnaddr = lndecode("lnsb500u1pdsgyf3pp5nmrqejdsdgs4n9ukgxcp2kcq265yhrxd4k5dyue58rxtp5y83s3qsp5qyqszqgpqyqszqgpqyqszqgpqyqszqgpqyqszqgpqyqszqgpqyqsdqqcqzys9qypqsqp2h6a5xeytuc3fad2ed4gxvhd593lwjdna3dxsyeem0qkzjx6guk44jend0xq4zzvp6f3fy07wnmxezazzsxgmvqee8shxjuqu2eu0qpnvc95x",