Browse Source

descriptors: add more sanity checks, and unit tests

master
SomberNight 3 years ago
parent
commit
144aac4523
No known key found for this signature in database
GPG Key ID: B33B5F232C6271E9
  1. 81
      electrum/descriptor.py
  2. 102
      electrum/tests/test_descriptor.py

81
electrum/descriptor.py

@ -9,21 +9,12 @@
# See https://github.com/bitcoin/bitcoin/blob/master/doc/descriptors.md
#
# TODO allow xprv
# TODO hardened derivation
# TODO allow WIF privkeys
# TODO impl ADDR descriptors
# TODO impl RAW descriptors
# TODO disable descs we cannot solve: TRDescriptor
# TODO add checks to validate nestings
# https://github.com/bitcoin/bitcoin/blob/94070029fb6b783833973f9fe08a3a871994492f/doc/descriptors.md#reference
# e.g. sh is top-level only, wsh is top-level or directly inside sh
#
# TODO tests
# - port https://github.com/bitcoin-core/HWI/blob/master/test/test_descriptor.py
# - ranged descriptors (that have a "*")
#
# TODO solver? integrate with transaction.py...
# Transaction.input_script/get_preimage_script/serialize_witness
import enum
from .bip32 import convert_bip32_path_to_list_of_uint32, BIP32Node, KeyOriginInfo
from . import bitcoin
@ -188,8 +179,13 @@ class PubkeyProvider(object):
self.origin = origin
self.pubkey = pubkey
self.deriv_path = deriv_path
# TODO check that deriv_path only has a single "*" (and that it is in the last pos. but can end with e.g. "*h")
if deriv_path:
wildcard_count = deriv_path.count("*")
if wildcard_count > 1:
raise ValueError("only one wildcard(*) is allowed in a descriptor")
if wildcard_count == 1:
if deriv_path[-1] != "*":
raise ValueError("wildcard in descriptor only allowed in last position")
# Make ExtendedKey from pubkey if it isn't hex
self.extkey = None
try:
@ -240,6 +236,7 @@ class PubkeyProvider(object):
def get_pubkey_bytes(self, *, pos: Optional[int] = None) -> bytes:
if self.is_range() and pos is None:
raise ValueError("pos must be set for ranged descriptor")
# note: if not ranged, we ignore pos.
if self.extkey is not None:
compressed = True # bip32 implies compressed pubkeys
if self.deriv_path is None:
@ -298,11 +295,21 @@ class PubkeyProvider(object):
return True
return False
def has_uncompressed_pubkey(self) -> bool:
if self.is_range(): # bip32 implies compressed
return False
return b"\x04" == self.get_pubkey_bytes()[:1]
class Descriptor(object):
r"""
An abstract class for Descriptors themselves.
Descriptors can contain multiple :class:`PubkeyProvider`\ s and multiple ``Descriptor`` as subdescriptors.
Note: a significant portion of input validation logic is in parse_descriptor(),
maybe these checks should be moved to (or also done in) this class?
For example, sh() must be top-level, or segwit mandates compressed pubkeys,
or bare-multisig cannot have >3 pubkeys.
"""
def __init__(
self,
@ -823,7 +830,7 @@ def _get_expr(s: str) -> Tuple[str, str]:
break
return s[0:i], s[i:]
def parse_pubkey(expr: str) -> Tuple['PubkeyProvider', str]:
def parse_pubkey(expr: str, *, ctx: '_ParseDescriptorContext') -> Tuple['PubkeyProvider', str]:
"""
Parses an individual pubkey expression from a string that may contain more than one pubkey expression.
@ -836,7 +843,11 @@ def parse_pubkey(expr: str) -> Tuple['PubkeyProvider', str]:
if comma_idx != -1:
end = comma_idx
next_expr = expr[end + 1:]
return PubkeyProvider.parse(expr[:end]), next_expr
pubkey_provider = PubkeyProvider.parse(expr[:end])
permit_uncompressed = ctx in (_ParseDescriptorContext.TOP, _ParseDescriptorContext.P2SH)
if not permit_uncompressed and pubkey_provider.has_uncompressed_pubkey():
raise ValueError("uncompressed pubkeys are not allowed")
return pubkey_provider, next_expr
class _ParseDescriptorContext(Enum):
@ -847,20 +858,14 @@ class _ParseDescriptorContext(Enum):
Some expressions aren't allowed at certain levels, this helps us track those.
"""
TOP = 1
"""The top level, not within any descriptor"""
P2SH = 2
"""Within a ``sh()`` descriptor"""
P2WSH = 3
"""Within a ``wsh()`` descriptor"""
P2TR = 4
"""Within a ``tr()`` descriptor"""
TOP = enum.auto() # The top level, not within any descriptor
P2SH = enum.auto() # Within an sh() descriptor
P2WPKH = enum.auto() # Within wpkh() descriptor
P2WSH = enum.auto() # Within a wsh() descriptor
P2TR = enum.auto() # Within a tr() descriptor
def _parse_descriptor(desc: str, ctx: '_ParseDescriptorContext') -> 'Descriptor':
def _parse_descriptor(desc: str, *, ctx: '_ParseDescriptorContext') -> 'Descriptor':
"""
:meta private:
@ -874,14 +879,14 @@ def _parse_descriptor(desc: str, ctx: '_ParseDescriptorContext') -> 'Descriptor'
"""
func, expr = _get_func_expr(desc)
if func == "pk":
pubkey, expr = parse_pubkey(expr)
pubkey, expr = parse_pubkey(expr, ctx=ctx)
if expr:
raise ValueError("more than one pubkey in pk descriptor")
return PKDescriptor(pubkey)
if func == "pkh":
if not (ctx == _ParseDescriptorContext.TOP or ctx == _ParseDescriptorContext.P2SH or ctx == _ParseDescriptorContext.P2WSH):
raise ValueError("Can only have pkh at top level, in sh(), or in wsh()")
pubkey, expr = parse_pubkey(expr)
pubkey, expr = parse_pubkey(expr, ctx=ctx)
if expr:
raise ValueError("More than one pubkey in pkh descriptor")
return PKHDescriptor(pubkey)
@ -894,10 +899,10 @@ def _parse_descriptor(desc: str, ctx: '_ParseDescriptorContext') -> 'Descriptor'
expr = expr[comma_idx + 1:]
pubkeys = []
while expr:
pubkey, expr = parse_pubkey(expr)
pubkey, expr = parse_pubkey(expr, ctx=ctx)
pubkeys.append(pubkey)
if len(pubkeys) == 0 or len(pubkeys) > 16:
raise ValueError("Cannot have {} keys in a multisig; must have between 1 and 16 keys, inclusive".format(len(pubkeys)))
if len(pubkeys) == 0 or len(pubkeys) > 15:
raise ValueError("Cannot have {} keys in a multisig; must have between 1 and 15 keys, inclusive".format(len(pubkeys)))
elif thresh < 1:
raise ValueError("Multisig threshold cannot be {}, must be at least 1".format(thresh))
elif thresh > len(pubkeys):
@ -908,24 +913,24 @@ def _parse_descriptor(desc: str, ctx: '_ParseDescriptorContext') -> 'Descriptor'
if func == "wpkh":
if not (ctx == _ParseDescriptorContext.TOP or ctx == _ParseDescriptorContext.P2SH):
raise ValueError("Can only have wpkh() at top level or inside sh()")
pubkey, expr = parse_pubkey(expr)
pubkey, expr = parse_pubkey(expr, ctx=_ParseDescriptorContext.P2WPKH)
if expr:
raise ValueError("More than one pubkey in pkh descriptor")
return WPKHDescriptor(pubkey)
if func == "sh":
if ctx != _ParseDescriptorContext.TOP:
raise ValueError("Can only have sh() at top level")
subdesc = _parse_descriptor(expr, _ParseDescriptorContext.P2SH)
subdesc = _parse_descriptor(expr, ctx=_ParseDescriptorContext.P2SH)
return SHDescriptor(subdesc)
if func == "wsh":
if not (ctx == _ParseDescriptorContext.TOP or ctx == _ParseDescriptorContext.P2SH):
raise ValueError("Can only have wsh() at top level or inside sh()")
subdesc = _parse_descriptor(expr, _ParseDescriptorContext.P2WSH)
subdesc = _parse_descriptor(expr, ctx=_ParseDescriptorContext.P2WSH)
return WSHDescriptor(subdesc)
if func == "tr":
if ctx != _ParseDescriptorContext.TOP:
raise ValueError("Can only have tr at top level")
internal_key, expr = parse_pubkey(expr)
internal_key, expr = parse_pubkey(expr, ctx=ctx)
subscripts = []
depths = []
if expr:
@ -945,7 +950,7 @@ def _parse_descriptor(desc: str, ctx: '_ParseDescriptorContext') -> 'Descriptor'
raise ValueError(f"tr() supports at most {MAX_TAPROOT_NODES} nesting levels") # TODO xxxx fixed upstream bug here
# Process script expression
sarg, expr = _get_expr(expr)
subscripts.append(_parse_descriptor(sarg, _ParseDescriptorContext.P2TR))
subscripts.append(_parse_descriptor(sarg, ctx=_ParseDescriptorContext.P2TR))
depths.append(len(branches))
# Process closing braces
while len(branches) > 0 and branches[-1]:
@ -982,7 +987,7 @@ def parse_descriptor(desc: str) -> 'Descriptor':
computed = DescriptorChecksum(desc)
if computed != checksum:
raise ValueError("The checksum does not match; Got {}, expected {}".format(checksum, computed))
return _parse_descriptor(desc, _ParseDescriptorContext.TOP)
return _parse_descriptor(desc, ctx=_ParseDescriptorContext.TOP)
#####

102
electrum/tests/test_descriptor.py

@ -1,10 +1,12 @@
# Copyright (c) 2018-2023 The HWI developers
# Copyright (c) 2023 The Electrum developers
# Distributed under the MIT software license, see the accompanying
# file COPYING or http://www.opensource.org/licenses/mit-license.php.
#
# originally from https://github.com/bitcoin-core/HWI/blob/f5a9b29c00e483cc99a1b8f4f5ef75413a092869/test/test_descriptor.py
from binascii import unhexlify
import unittest
from electrum.descriptor import (
parse_descriptor,
@ -15,6 +17,8 @@ from electrum.descriptor import (
WPKHDescriptor,
WSHDescriptor,
)
from electrum import ecc
from electrum.util import bfh
from . import ElectrumTestCase, as_testnet
@ -35,6 +39,7 @@ class TestDescriptor(ElectrumTestCase):
self.assertEqual(e.output_script, unhexlify("0014d95fc47eada9e4c3cf59a2cbf9e96517c3ba2efa"))
self.assertEqual(e.redeem_script, None)
self.assertEqual(e.witness_script, None)
self.assertEqual(e.address(), "tb1qm90ugl4d48jv8n6e5t9ln6t9zlpm5th690vysp")
@as_testnet
def test_parse_multisig_descriptor_with_origin(self):
@ -216,3 +221,100 @@ class TestDescriptor(ElectrumTestCase):
self.assertEqual(desc.pubkeys[0].deriv_path, "/0/0")
self.assertEqual(desc.depths, [1, 3, 3, 2])
self.assertEqual(desc.to_string_no_checksum(), d)
@as_testnet
def test_parse_descriptor_with_range(self):
d = "wpkh([00000001/84h/1h/0h]tpubD6NzVbkrYhZ4WaWSyoBvQwbpLkojyoTZPRsgXELWz3Popb3qkjcJyJUGLnL4qHHoQvao8ESaAstxYSnhyswJ76uZPStJRJCTKvosUCJZL5B/0/*)"
desc = parse_descriptor(d)
self.assertTrue(isinstance(desc, WPKHDescriptor))
self.assertEqual(desc.pubkeys[0].origin.fingerprint.hex(), "00000001")
self.assertEqual(desc.pubkeys[0].origin.get_derivation_path(), "m/84h/1h/0h")
self.assertEqual(desc.pubkeys[0].pubkey, "tpubD6NzVbkrYhZ4WaWSyoBvQwbpLkojyoTZPRsgXELWz3Popb3qkjcJyJUGLnL4qHHoQvao8ESaAstxYSnhyswJ76uZPStJRJCTKvosUCJZL5B")
self.assertEqual(desc.pubkeys[0].deriv_path, "/0/*")
self.assertEqual(desc.to_string_no_checksum(), d)
with self.assertRaises(ValueError): # "pos" arg needed due to "*"
e = desc.expand()
e = desc.expand(pos=7)
self.assertEqual(e.output_script, unhexlify("0014c5f80de08f6ae8dd720bf4e4948ba498c96256a1"))
self.assertEqual(e.redeem_script, None)
self.assertEqual(e.witness_script, None)
with self.assertRaises(ValueError): # wildcard only allowed in last position
parse_descriptor("wpkh([00000001/84h/1h/0h]tpubD6NzVbkrYhZ4WaWSyoBvQwbpLkojyoTZPRsgXELWz3Popb3qkjcJyJUGLnL4qHHoQvao8ESaAstxYSnhyswJ76uZPStJRJCTKvosUCJZL5B/*/0)")
with self.assertRaises(ValueError): # only one wildcard(*) is allowed
parse_descriptor("wpkh([00000001/84h/1h/0h]tpubD6NzVbkrYhZ4WaWSyoBvQwbpLkojyoTZPRsgXELWz3Popb3qkjcJyJUGLnL4qHHoQvao8ESaAstxYSnhyswJ76uZPStJRJCTKvosUCJZL5B/*/*)")
@as_testnet
def test_parse_multisig_descriptor_with_range(self):
d = "wsh(multi(2,[00000001/48h/0h/0h/2h]tpubD6NzVbkrYhZ4WaWSyoBvQwbpLkojyoTZPRsgXELWz3Popb3qkjcJyJUGLnL4qHHoQvao8ESaAstxYSnhyswJ76uZPStJRJCTKvosUCJZL5B/0/*,[00000002/48h/0h/0h/2h]tpubDFHiBJDeNvqPWNJbzzxqDVXmJZoNn2GEtoVcFhMjXipQiorGUmps3e5ieDGbRrBPTFTh9TXEKJCwbAGW9uZnfrVPbMxxbFohuFzfT6VThty/0/*))"
desc = parse_descriptor(d)
self.assertTrue(isinstance(desc, WSHDescriptor))
self.assertTrue(isinstance(desc.subdescriptors[0], MultisigDescriptor))
self.assertEqual(desc.subdescriptors[0].pubkeys[0].origin.fingerprint.hex(), "00000001")
self.assertEqual(desc.subdescriptors[0].pubkeys[0].origin.get_derivation_path(), "m/48h/0h/0h/2h")
self.assertEqual(desc.subdescriptors[0].pubkeys[0].pubkey, "tpubD6NzVbkrYhZ4WaWSyoBvQwbpLkojyoTZPRsgXELWz3Popb3qkjcJyJUGLnL4qHHoQvao8ESaAstxYSnhyswJ76uZPStJRJCTKvosUCJZL5B")
self.assertEqual(desc.subdescriptors[0].pubkeys[0].deriv_path, "/0/*")
self.assertEqual(desc.subdescriptors[0].pubkeys[1].origin.fingerprint.hex(), "00000002")
self.assertEqual(desc.subdescriptors[0].pubkeys[1].origin.get_derivation_path(), "m/48h/0h/0h/2h")
self.assertEqual(desc.subdescriptors[0].pubkeys[1].pubkey, "tpubDFHiBJDeNvqPWNJbzzxqDVXmJZoNn2GEtoVcFhMjXipQiorGUmps3e5ieDGbRrBPTFTh9TXEKJCwbAGW9uZnfrVPbMxxbFohuFzfT6VThty")
self.assertEqual(desc.subdescriptors[0].pubkeys[1].deriv_path, "/0/*")
self.assertEqual(desc.to_string_no_checksum(), d)
e = desc.expand(pos=7)
self.assertEqual(e.output_script, unhexlify("0020453cdf90aef0997947bc0605481f81dd2978ecd2d04ac36fb57397a82341682d"))
self.assertEqual(e.redeem_script, None)
self.assertEqual(e.witness_script, unhexlify("5221034e703dfcd64f23ad5d6156ee3b9dd7566137626c663bb521bf710957599723342102c35627535d26de98ae749b7a7849df99cbe53af795005437ca647c8af9a006af52ae"))
@as_testnet
def test_multisig_descriptor_with_mixed_range(self):
d = "sh(wsh(multi(2,[00000001/48h/0h/0h/2h]tpubD6NzVbkrYhZ4WaWSyoBvQwbpLkojyoTZPRsgXELWz3Popb3qkjcJyJUGLnL4qHHoQvao8ESaAstxYSnhyswJ76uZPStJRJCTKvosUCJZL5B/0/*,[00000002/48h/0h/0h/2h]tpubDFHiBJDeNvqPWNJbzzxqDVXmJZoNn2GEtoVcFhMjXipQiorGUmps3e5ieDGbRrBPTFTh9TXEKJCwbAGW9uZnfrVPbMxxbFohuFzfT6VThty/0/0)))"
desc = parse_descriptor(d)
e = desc.expand(pos=7)
self.assertEqual(e.output_script, bfh("a914644ece12bab2f84ad6de96ec18de51e6168c028987"))
self.assertEqual(e.redeem_script, bfh("0020824ce4ffab74a8d09c2f77ed447fb040ea5dfbed06f8e3b3327127a18634f6a7"))
self.assertEqual(e.witness_script, bfh("5221034e703dfcd64f23ad5d6156ee3b9dd7566137626c663bb521bf7109575997233421033a4f18d2b498273ed7439c59f6d8a673d5b9c67a03163d530e12c941ca22be3352ae"))
self.assertEqual(e.address(), "2N2Pbxw3HNJ9jrUw8LCSfXyDWx9TKGRT2an")
@as_testnet
def test_uncompressed_pubkey_in_segwit(self):
pubkey = ecc.ECPubkey(bfh("02a0507c8bb3d96dfd7731bafb0ae30e6ed10bbadd6a9f9f88eaf0602b9cc99adc"))
pubkey_comp_hex = pubkey.get_public_key_hex(compressed=True)
pubkey_uncomp_hex = pubkey.get_public_key_hex(compressed=False)
self.assertEqual(pubkey_comp_hex, "02a0507c8bb3d96dfd7731bafb0ae30e6ed10bbadd6a9f9f88eaf0602b9cc99adc")
self.assertEqual(pubkey_uncomp_hex, "04a0507c8bb3d96dfd7731bafb0ae30e6ed10bbadd6a9f9f88eaf0602b9cc99adc3ccfc29410b8f23c15d88413a6b88c8cd44b016a7f1dd91a8d64c3107c6bce1a")
# pkh
desc = parse_descriptor(f"pkh({pubkey_comp_hex})")
self.assertEqual(desc.expand().output_script, bfh("76a9140297bde2689a3c79ffe050583b62f86f2d9dae5488ac"))
desc = parse_descriptor(f"pkh({pubkey_uncomp_hex})")
self.assertEqual(desc.expand().output_script, bfh("76a914e1f4a76b122f0288b013404cd52a9d1de0ced3c488ac"))
# wpkh
desc = parse_descriptor(f"wpkh({pubkey_comp_hex})")
self.assertEqual(desc.expand().output_script, bfh("00140297bde2689a3c79ffe050583b62f86f2d9dae54"))
with self.assertRaises(ValueError): # only compressed public keys can be used in segwit scripts
desc = parse_descriptor(f"wpkh({pubkey_uncomp_hex})")
# sh(wsh(multi()))
desc = parse_descriptor(f"sh(wsh(multi(2,[00000001/48h/0h/0h/2h]tpubD6NzVbkrYhZ4WaWSyoBvQwbpLkojyoTZPRsgXELWz3Popb3qkjcJyJUGLnL4qHHoQvao8ESaAstxYSnhyswJ76uZPStJRJCTKvosUCJZL5B/0/*,{pubkey_comp_hex})))")
self.assertEqual(desc.expand(pos=2).output_script, bfh("a9148f162cce29ad81e63ed45cd09aff83418316eab687"))
with self.assertRaises(ValueError): # only compressed public keys can be used in segwit scripts
desc = parse_descriptor(f"sh(wsh(multi(2,[00000001/48h/0h/0h/2h]tpubD6NzVbkrYhZ4WaWSyoBvQwbpLkojyoTZPRsgXELWz3Popb3qkjcJyJUGLnL4qHHoQvao8ESaAstxYSnhyswJ76uZPStJRJCTKvosUCJZL5B/0/*,{pubkey_uncomp_hex})))")
@as_testnet
def test_parse_descriptor_context(self):
desc = parse_descriptor("sh(wsh(sortedmulti(2,[00000001/48h/0h/0h/2h]tpubD6NzVbkrYhZ4WaWSyoBvQwbpLkojyoTZPRsgXELWz3Popb3qkjcJyJUGLnL4qHHoQvao8ESaAstxYSnhyswJ76uZPStJRJCTKvosUCJZL5B/0/0,[00000002/48h/0h/0h/2h]tpubDFHiBJDeNvqPWNJbzzxqDVXmJZoNn2GEtoVcFhMjXipQiorGUmps3e5ieDGbRrBPTFTh9TXEKJCwbAGW9uZnfrVPbMxxbFohuFzfT6VThty/0/0)))")
self.assertTrue(isinstance(desc, SHDescriptor))
with self.assertRaises(ValueError): # Can only have sh() at top level
desc = parse_descriptor("wsh(sh(multi(2,[00000001/48h/0h/0h/2h]tpubD6NzVbkrYhZ4WaWSyoBvQwbpLkojyoTZPRsgXELWz3Popb3qkjcJyJUGLnL4qHHoQvao8ESaAstxYSnhyswJ76uZPStJRJCTKvosUCJZL5B/0/0,[00000002/48h/0h/0h/2h]tpubDFHiBJDeNvqPWNJbzzxqDVXmJZoNn2GEtoVcFhMjXipQiorGUmps3e5ieDGbRrBPTFTh9TXEKJCwbAGW9uZnfrVPbMxxbFohuFzfT6VThty/0/0)))")
with self.assertRaises(ValueError): # Can only have wsh() at top level or inside sh()
desc = parse_descriptor("wsh(wsh(multi(2,[00000001/48h/0h/0h/2h]tpubD6NzVbkrYhZ4WaWSyoBvQwbpLkojyoTZPRsgXELWz3Popb3qkjcJyJUGLnL4qHHoQvao8ESaAstxYSnhyswJ76uZPStJRJCTKvosUCJZL5B/0/0,[00000002/48h/0h/0h/2h]tpubDFHiBJDeNvqPWNJbzzxqDVXmJZoNn2GEtoVcFhMjXipQiorGUmps3e5ieDGbRrBPTFTh9TXEKJCwbAGW9uZnfrVPbMxxbFohuFzfT6VThty/0/0)))")
desc = parse_descriptor("wpkh([00000001/84h/1h/0h]tpubD6NzVbkrYhZ4WaWSyoBvQwbpLkojyoTZPRsgXELWz3Popb3qkjcJyJUGLnL4qHHoQvao8ESaAstxYSnhyswJ76uZPStJRJCTKvosUCJZL5B/0/0)")
self.assertTrue(isinstance(desc, WPKHDescriptor))
with self.assertRaises(ValueError): # Can only have wpkh() at top level or inside sh()
desc = parse_descriptor("wsh(wpkh([00000001/84h/1h/0h]tpubD6NzVbkrYhZ4WaWSyoBvQwbpLkojyoTZPRsgXELWz3Popb3qkjcJyJUGLnL4qHHoQvao8ESaAstxYSnhyswJ76uZPStJRJCTKvosUCJZL5B/0/0))")
def test_parse_descriptor_ypub_zpub_forbidden(self):
desc = parse_descriptor("wpkh([535e473f/0h]xpub68W3CJPrQzHhTQcHM6tbCvNVB9ih4tbzsFBLwe7zZUj5uHuhxBUhvnXe1RQhbKCTiTj3D7kXni6yAD88i2xnjKHaJ5NqTtHawKnPFCDnmo4/0/*)")
with self.assertRaises(ValueError): # only standard xpub/xprv allowed
desc = parse_descriptor("wpkh([535e473f/0h]ypub6TLJVy4mZfqBJhoQBTgDR1TzM7s91WbVnMhZj31swV6xxPiwCqeGYrBn2dNHbDrP86qqxbM6FNTX3VjhRjNoXYyBAR5G3o75D3r2djmhZwM/0/*)")
with self.assertRaises(ValueError): # only standard xpub/xprv allowed
desc = parse_descriptor("wpkh([535e473f/0h]zpub6nAZodjgiMNf9zzX1pTqd6ZVX61ax8azhUDnWRumKVUr1VYATVoqAuqv3qKsb8WJXjxei4wei2p4vnMG9RnpKnen2kmgdhvZUmug2NnHNsr/0/*)")

Loading…
Cancel
Save