You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

405 lines
14 KiB

#!/usr/bin/env python3
"""Tests for ChillDKG reference implementation"""
import pytest
from itertools import combinations
from random import randint
from typing import Tuple, List, Optional
from secrets import token_bytes as random_bytes
from jmfrost.secp256k1lab.secp256k1 import GE, G, Scalar
from jmfrost.secp256k1lab.keys import pubkey_gen_plain
from jmfrost.chilldkg_ref.util import (
FaultyParticipantOrCoordinatorError,
FaultyCoordinatorError,
UnknownFaultyParticipantOrCoordinatorError,
tagged_hash_bip_dkg,
)
from jmfrost.chilldkg_ref.vss import Polynomial, VSS, VSSCommitment
import jmfrost.chilldkg_ref.simplpedpop as simplpedpop
import jmfrost.chilldkg_ref.encpedpop as encpedpop
import jmfrost.chilldkg_ref.chilldkg as chilldkg
from chilldkg_example import simulate_chilldkg_full as simulate_chilldkg_full_example
def test_chilldkg_params_validate():
    """Check that `chilldkg.params_id` rejects malformed session parameters.

    Four cases are exercised, each of which must raise a specific exception:
    a duplicated host public key, an invalid host public key, a threshold
    larger than the number of host public keys, and a negative threshold.
    """
    hostseckeys = [random_bytes(32) for _ in range(3)]
    hostpubkeys = [chilldkg.hostpubkey_gen(hostseckey) for hostseckey in hostseckeys]

    # The key at index 1 is repeated at index 3; both indices must be reported.
    with_duplicate = [hostpubkeys[0], hostpubkeys[1], hostpubkeys[2], hostpubkeys[1]]
    params_with_duplicate = chilldkg.SessionParams(with_duplicate, 2)
    with pytest.raises(chilldkg.DuplicateHostPubkeyError) as dup_info:
        chilldkg.params_id(params_with_duplicate)
    assert {dup_info.value.participant1, dup_info.value.participant2} == {1, 3}

    invalid_hostpubkey = b"\x03" + 31 * b"\x00" + b"\x05"  # Invalid x-coordinate
    params_with_invalid = chilldkg.SessionParams(
        [hostpubkeys[1], invalid_hostpubkey, hostpubkeys[2]], 1
    )
    with pytest.raises(chilldkg.InvalidHostPubkeyError) as invalid_info:
        chilldkg.params_id(params_with_invalid)
    # The invalid key sits at index 1 of the list.
    assert invalid_info.value.participant == 1

    # Threshold exceeds the number of host public keys.
    with pytest.raises(chilldkg.ThresholdOrCountError):
        chilldkg.params_id(chilldkg.SessionParams(hostpubkeys, len(hostpubkeys) + 1))

    # Negative threshold.
    with pytest.raises(chilldkg.ThresholdOrCountError):
        chilldkg.params_id(chilldkg.SessionParams(hostpubkeys, -2))
def test_vss_correctness():
    """Check that VSS secret shares verify against their public commitments.

    For t in {1, 2} and every n from t to 2 * t, a random polynomial with t
    coefficients is shared among n parties.  Each share must verify against
    the matching pubshare, both for the plain commitment and for the
    commitment returned by `invalid_taproot_commit`.
    """

    def random_poly(num_coeffs):
        # Coefficients are drawn from [1, GE.ORDER - 1].
        coeffs = [randint(1, GE.ORDER - 1) for _ in range(num_coeffs)]
        return Polynomial(coeffs)

    for t in (1, 2):
        for n in range(t, 2 * t + 1):
            vss = VSS(random_poly(t))
            secshares = vss.secshares(n)
            assert len(secshares) == n
            for idx in range(n):
                assert VSSCommitment.verify_secshare(
                    secshares[idx], vss.commit().pubshare(idx)
                )

            vssc_tweaked, tweak, pubtweak = vss.commit().invalid_taproot_commit()
            # The tweaked secret must match the tweaked commitment to the secret.
            tweaked_secret = vss.secret() + tweak
            tweaked_commitment = vss.commit().commitment_to_secret() + pubtweak
            assert VSSCommitment.verify_secshare(tweaked_secret, tweaked_commitment)
            # Likewise for every individual share.
            for idx in range(n):
                assert VSSCommitment.verify_secshare(
                    secshares[idx] + tweak, vssc_tweaked.pubshare(idx)
                )
def simulate_simplpedpop(
    seeds, t, investigation: bool
) -> Optional[List[Tuple[simplpedpop.DKGOutput, bytes]]]:
    """Run a SimplPedPop session in memory for `len(seeds)` participants.

    When `investigation` is set, one randomly chosen partial share handed to
    each participant is corrupted.  As soon as a participant's step 2 raises
    `UnknownFaultyParticipantOrCoordinatorError`, the investigation phase is
    run, its blame assignment is checked, and `None` is returned.

    Otherwise the result is a list holding the coordinator's pair first,
    followed by each participant's return value from step 2.
    """
    n = len(seeds)
    step1_rets = [
        simplpedpop.participant_step1(seed, t, n, idx)
        for idx, seed in enumerate(seeds)
    ]
    pstates = [state for (state, _msg, _shares) in step1_rets]
    pmsgs = [msg for (_state, msg, _shares) in step1_rets]

    cmsg, cout, ceq = simplpedpop.coordinator_step(pmsgs, t, n)
    results = [(cout, ceq)]
    for i in range(n):
        # Gather the partial share that every sender produced for participant i.
        shares_for_i = [shares[i] for (_state, _msg, shares) in step1_rets]
        if investigation:
            # Let a random participant send incorrect shares to participant i.
            faulty_idx = randint(0, n - 1)
            shares_for_i[faulty_idx] += Scalar(17)

        secshare = simplpedpop.participant_step2_prepare_secshare(shares_for_i)
        try:
            step2_ret = simplpedpop.participant_step2(pstates[i], cmsg, secshare)
        except UnknownFaultyParticipantOrCoordinatorError as unknown_fault:
            if not investigation:
                raise
            inv_msgs = simplpedpop.coordinator_investigate(pmsgs)
            assert len(inv_msgs) == len(pmsgs)
            try:
                simplpedpop.participant_investigate(
                    unknown_fault, inv_msgs[i], shares_for_i
                )
            except FaultyParticipantOrCoordinatorError as blamed:
                # If we're not faulty, we should blame the faulty party.
                assert i != faulty_idx
                assert blamed.participant == faulty_idx
            except FaultyCoordinatorError:
                # If we're faulty, we'll blame the coordinator.
                assert i == faulty_idx
            return None
        results.append(step2_ret)
    return results
def encpedpop_keys(seed: bytes) -> Tuple[bytes, bytes]:
    """Derive a (decryption key, encryption key) pair from `seed`.

    The decryption key is a tagged hash of the seed; the encryption key is
    the plain public key generated from that decryption key.
    """
    deckey = tagged_hash_bip_dkg("encpedpop deckey", seed)
    return deckey, pubkey_gen_plain(deckey)
def simulate_encpedpop(
    seeds, t, investigation: bool
) -> Optional[List[Tuple[simplpedpop.DKGOutput, bytes]]]:
    """Run an EncPedPop session in memory for `len(seeds)` participants.

    When `investigation` is set, for every recipient i the encrypted share
    sent by one randomly chosen participant is corrupted before the
    coordinator step.  As soon as a participant's step 2 raises
    `UnknownFaultyParticipantOrCoordinatorError`, the investigation phase is
    run, its blame assignment is checked, and `None` is returned.

    Otherwise the result is a list holding the coordinator's pair first,
    followed by each participant's return value from step 2.
    """
    n = len(seeds)
    # One (deckey, enckey) pair per participant, derived from its seed.
    keypairs = [encpedpop_keys(seed) for seed in seeds]
    enckeys = [keypair[1] for keypair in keypairs]

    step1_rets = []
    for i in range(n):
        deckey = keypairs[i][0]
        fresh_randomness = random_bytes(32)
        step1_rets.append(
            encpedpop.participant_step1(
                seeds[i], deckey, enckeys, t, i, fresh_randomness
            )
        )

    pstates = [state for (state, _msg) in step1_rets]
    pmsgs = [msg for (_state, msg) in step1_rets]
    if investigation:
        faulty_idx: List[int] = []
        for i in range(n):
            # Let a random participant faulty_idx[i] send incorrect shares to i.
            faulty_idx.append(randint(0, n - 1))
            pmsgs[faulty_idx[i]].enc_shares[i] += Scalar(17)

    cmsg, cout, ceq, enc_secshares = encpedpop.coordinator_step(pmsgs, t, enckeys)
    results = [(cout, ceq)]
    for i in range(n):
        deckey = keypairs[i][0]
        try:
            step2_ret = encpedpop.participant_step2(
                pstates[i], deckey, cmsg, enc_secshares[i]
            )
        except UnknownFaultyParticipantOrCoordinatorError as unknown_fault:
            if not investigation:
                raise
            inv_msgs = encpedpop.coordinator_investigate(pmsgs)
            assert len(inv_msgs) == len(pmsgs)
            try:
                encpedpop.participant_investigate(unknown_fault, inv_msgs[i])
            except FaultyParticipantOrCoordinatorError as blamed:
                # If we're not faulty, we should blame the faulty party.
                assert i != faulty_idx[i]
                assert blamed.participant == faulty_idx[i]
            except FaultyCoordinatorError:
                # If we're faulty, we'll blame the coordinator.
                assert i == faulty_idx[i]
            return None
        results.append(step2_ret)
    return results
def simulate_chilldkg(
    hostseckeys, t, investigation: bool
) -> Optional[List[Tuple[chilldkg.DKGOutput, chilldkg.RecoveryData]]]:
    """Run a ChillDKG session in memory for `len(hostseckeys)` participants.

    When `investigation` is set, for every recipient i the encrypted share
    sent by one randomly chosen participant is corrupted before the
    coordinator's first step.  As soon as a participant's step 2 raises
    `UnknownFaultyParticipantOrCoordinatorError`, the investigation phase is
    run, its blame assignment is checked, and `None` is returned.

    Otherwise the result is a list holding the coordinator's
    (output, recovery data) pair first, followed by each participant's
    return value from `participant_finalize`.
    """
    n = len(hostseckeys)
    hostpubkeys = [chilldkg.hostpubkey_gen(seckey) for seckey in hostseckeys]
    params = chilldkg.SessionParams(hostpubkeys, t)

    step1_rets = []
    for seckey in hostseckeys:
        fresh_randomness = random_bytes(32)
        step1_rets.append(
            chilldkg.participant_step1(seckey, params, fresh_randomness)
        )

    pstates1 = [ret[0] for ret in step1_rets]
    pmsgs = [ret[1] for ret in step1_rets]
    if investigation:
        faulty_idx: List[int] = []
        for i in range(n):
            # Let a random participant faulty_idx[i] send incorrect shares to i.
            faulty_idx.append(randint(0, n - 1))
            pmsgs[faulty_idx[i]].enc_pmsg.enc_shares[i] += Scalar(17)

    cstate, cmsg1 = chilldkg.coordinator_step1(pmsgs, params)
    step2_rets = []
    for i in range(n):
        try:
            step2_ret = chilldkg.participant_step2(hostseckeys[i], pstates1[i], cmsg1)
        except UnknownFaultyParticipantOrCoordinatorError as unknown_fault:
            if not investigation:
                raise
            inv_msgs = chilldkg.coordinator_investigate(pmsgs)
            assert len(inv_msgs) == len(pmsgs)
            try:
                chilldkg.participant_investigate(unknown_fault, inv_msgs[i])
            except FaultyParticipantOrCoordinatorError as blamed:
                # If we're not faulty, we should blame the faulty party.
                assert i != faulty_idx[i]
                assert blamed.participant == faulty_idx[i]
            except FaultyCoordinatorError:
                # If we're faulty, we'll blame the coordinator.
                assert i == faulty_idx[i]
            return None
        step2_rets.append(step2_ret)

    pmsgs2 = [ret[1] for ret in step2_rets]
    cmsg2, cout, crec = chilldkg.coordinator_finalize(cstate, pmsgs2)
    outputs = [(cout, crec)]
    for i in range(n):
        participant_out = chilldkg.participant_finalize(step2_rets[i][0], cmsg2)
        assert participant_out is not None
        outputs.append(participant_out)
    return outputs
def simulate_chilldkg_full(
    hostseckeys,
    t,
    investigation: bool,
) -> List[Optional[Tuple[chilldkg.DKGOutput, chilldkg.RecoveryData]]]:
    """Adapt `simulate_chilldkg_full` from `chilldkg_example` to this file's
    simulator signature.

    Builds the session parameters from the host secret keys and threshold
    `t`, then delegates to the example simulation with `faulty_idx=None`.
    """
    # Investigating is not supported by this wrapper
    assert not investigation

    hostpubkeys = [chilldkg.hostpubkey_gen(seckey) for seckey in hostseckeys]
    session_params = chilldkg.SessionParams(hostpubkeys, t)
    return simulate_chilldkg_full_example(
        hostseckeys, session_params, faulty_idx=None
    )
def derive_interpolating_value(L, x_i):
    """Return the Lagrange interpolation coefficient of `x_i` over `L`.

    The result is the product of x_j / (x_j - x_i) over every x_j in `L`
    other than `x_i`, computed with `Scalar` arithmetic.  This is the weight
    by which the share at `x_i` is multiplied when interpolating at zero.

    `L` must contain `x_i` and must not contain duplicates; both conditions
    are checked with assertions.
    """
    assert x_i in L
    assert all(L.count(x_j) <= 1 for x_j in L)
    # Convert x_i once up front instead of on every loop iteration, and keep
    # the converted values under their own names rather than rebinding the
    # parameter and the loop variable.
    x_i_scalar = Scalar(x_i)
    lam = Scalar(1)
    for x_j in L:
        x_j_scalar = Scalar(x_j)
        if x_j_scalar == x_i_scalar:
            continue
        lam *= x_j_scalar / (x_j_scalar - x_i_scalar)
    return lam
def recover_secret(participant_indices, shares) -> Scalar:
    """Interpolate the shared secret from shares and their participant indices.

    `shares[k]` is the share belonging to `participant_indices[k]`; the two
    sequences must have the same length.  Each share is weighted by its
    interpolating value and the weighted shares are summed.
    """
    t = len(shares)
    assert len(participant_indices) == t
    weighted_shares = [
        derive_interpolating_value(participant_indices, index) * share
        for index, share in zip(participant_indices, shares)
    ]
    return Scalar.sum(*weighted_shares)
def test_recover_secret():
    """Check that any two of three shares of a two-coefficient polynomial
    recover its constant coefficient."""
    f = Polynomial([23, 42])
    share_at = {x: f(x) for x in (1, 2, 3)}
    for index_pair in ([1, 2], [1, 3], [2, 3]):
        pair_shares = [share_at[x] for x in index_pair]
        assert recover_secret(index_pair, pair_shares) == f.coeffs[0]
def check_correctness_dkg_output(t, n, dkg_outputs: List[simplpedpop.DKGOutput]):
    """Assert that the DKG outputs of one session are mutually consistent.

    `dkg_outputs` must hold n + 1 entries: the coordinator's output at
    index 0 and the participants' outputs at indices 1 to n.  Each entry is
    indexed as (secshare, threshold pubkey, pubshares).
    """
    assert len(dkg_outputs) == n + 1
    secshares = [output[0] for output in dkg_outputs]
    threshold_pubkeys = [output[1] for output in dkg_outputs]
    pubshares = [output[2] for output in dkg_outputs]

    # Check that the threshold pubkey and pubshares are the same for the
    # coordinator (at [0]) and all participants (at [1:n + 1]).
    coordinator_pubshares = pubshares[0]
    for party_pubkey, party_pubshares in zip(threshold_pubkeys, pubshares):
        assert threshold_pubkeys[0] == party_pubkey
        assert len(party_pubshares) == n
        assert coordinator_pubshares == party_pubshares
    threshold_pubkey = threshold_pubkeys[0]

    # Check that the coordinator has no secret share
    assert secshares[0] is None

    # Check that each secshare matches the corresponding pubshare
    secshares_scalar = []
    for secshare in secshares:
        if secshare is None:
            secshares_scalar.append(None)
        else:
            secshares_scalar.append(Scalar.from_bytes_checked(secshare))
    for scalar, pubshare in zip(secshares_scalar[1:], coordinator_pubshares):
        assert scalar * G == GE.from_bytes_compressed(pubshare)

    # Check that all combinations of t participants can recover the threshold pubkey
    for tsubset in combinations(range(1, n + 1), t):
        subset_shares = [secshares_scalar[i] for i in tsubset]
        recovered = recover_secret(tsubset, subset_shares)
        assert recovered * G == GE.from_bytes_compressed(threshold_pubkey)
@pytest.mark.parametrize(
    "t,n,simulate_dkg,recovery,investigation",
    [
        [t, n, simulate_dkg, recovery, investigation]
        # Every (t, n) pair is combined with every simulator variant below.
        for (t, n) in [(1, 1), (1, 2), (2, 2), (2, 3), (2, 5)]
        for (simulate_dkg, recovery, investigation) in [
            (simulate_simplpedpop, False, False),
            (simulate_simplpedpop, False, True),
            (simulate_encpedpop, False, False),
            (simulate_encpedpop, False, True),
            (simulate_chilldkg, True, False),
            (simulate_chilldkg, True, True),
            (simulate_chilldkg_full, True, False),
        ]
    ],
)
def test_correctness(t, n, simulate_dkg, recovery, investigation):
    """Run one simulated DKG session and check its results.

    With `investigation` set, the simulator must return `None`.  Otherwise
    the DKG outputs are checked for consistency, the second element of every
    return value must equal the coordinator's, and, when `recovery` is set,
    `chilldkg.recover` must reproduce each party's DKG output.
    """
    participant_seeds = [random_bytes(32) for _ in range(n)]
    # Index 0 stands for the coordinator, which has no seed.
    seeds = [None] + participant_seeds

    rets = simulate_dkg(participant_seeds, t, investigation=investigation)
    if investigation:
        assert rets is None
        # The session has failed correctly, so there's nothing further to check.
        return

    # rets[0] are the return values from the coordinator
    # rets[1 : n + 1] are from the participants
    assert len(rets) == n + 1
    dkg_outputs = [ret[0] for ret in rets]
    check_correctness_dkg_output(t, n, dkg_outputs)

    eqs_or_recs = [ret[1] for ret in rets]
    for participant_value in eqs_or_recs[1 : n + 1]:
        assert eqs_or_recs[0] == participant_value

    if not recovery:
        return

    rec = eqs_or_recs[0]
    # Check correctness of chilldkg.recover
    for seed, expected in zip(seeds, dkg_outputs):
        (secshare, threshold_pubkey, pubshares), _ = chilldkg.recover(seed, rec)
        assert secshare == expected[0]
        assert threshold_pubkey == expected[1]
        assert pubshares == expected[2]