You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
405 lines
14 KiB
405 lines
14 KiB
#!/usr/bin/env python3 |
|
|
|
"""Tests for ChillDKG reference implementation""" |
|
|
|
import pytest |
|
from itertools import combinations |
|
from random import randint |
|
from typing import Tuple, List, Optional |
|
from secrets import token_bytes as random_bytes |
|
|
|
from jmfrost.secp256k1lab.secp256k1 import GE, G, Scalar |
|
from jmfrost.secp256k1lab.keys import pubkey_gen_plain |
|
|
|
from jmfrost.chilldkg_ref.util import ( |
|
FaultyParticipantOrCoordinatorError, |
|
FaultyCoordinatorError, |
|
UnknownFaultyParticipantOrCoordinatorError, |
|
tagged_hash_bip_dkg, |
|
) |
|
from jmfrost.chilldkg_ref.vss import Polynomial, VSS, VSSCommitment |
|
import jmfrost.chilldkg_ref.simplpedpop as simplpedpop |
|
import jmfrost.chilldkg_ref.encpedpop as encpedpop |
|
import jmfrost.chilldkg_ref.chilldkg as chilldkg |
|
|
|
from chilldkg_example import simulate_chilldkg_full as simulate_chilldkg_full_example |
|
|
|
|
|
def test_chilldkg_params_validate():
    """Check that chilldkg.params_id rejects malformed session parameters.

    Covers a duplicated host public key, an undecodable host public key, and
    thresholds that are too large or negative.
    """
    hostseckeys = [random_bytes(32) for _ in range(3)]
    hostpubkeys = [chilldkg.hostpubkey_gen(hostseckey) for hostseckey in hostseckeys]

    # hostpubkeys[1] appears at positions 1 and 3.
    with_duplicate = [hostpubkeys[0], hostpubkeys[1], hostpubkeys[2], hostpubkeys[1]]
    params_with_duplicate = chilldkg.SessionParams(with_duplicate, 2)
    with pytest.raises(chilldkg.DuplicateHostPubkeyError) as excinfo:
        chilldkg.params_id(params_with_duplicate)
    assert {excinfo.value.participant1, excinfo.value.participant2} == {1, 3}

    invalid_hostpubkey = b"\x03" + 31 * b"\x00" + b"\x05"  # Invalid x-coordinate
    params_with_invalid = chilldkg.SessionParams(
        [hostpubkeys[1], invalid_hostpubkey, hostpubkeys[2]], 1
    )
    with pytest.raises(chilldkg.InvalidHostPubkeyError) as excinfo:
        chilldkg.params_id(params_with_invalid)
    assert excinfo.value.participant == 1

    # Threshold larger than the number of participants.
    with pytest.raises(chilldkg.ThresholdOrCountError):
        chilldkg.params_id(chilldkg.SessionParams(hostpubkeys, len(hostpubkeys) + 1))

    # Negative threshold.
    with pytest.raises(chilldkg.ThresholdOrCountError):
        chilldkg.params_id(chilldkg.SessionParams(hostpubkeys, -2))
|
|
|
|
|
def test_vss_correctness():
    """Check that VSS secret shares verify against their commitment.

    Also checks that, after applying the tweak returned by
    invalid_taproot_commit, the tweaked secret and tweaked shares verify
    against the tweaked commitment.
    """

    def rand_polynomial(t):
        # t random non-zero coefficients.
        return Polynomial([randint(1, GE.ORDER - 1) for _ in range(t)])

    for t in range(1, 3):
        for n in range(t, 2 * t + 1):
            f = rand_polynomial(t)
            vss = VSS(f)
            secshares = vss.secshares(n)
            assert len(secshares) == n

            # The commitment depends only on f, so compute it once instead of
            # once per share and per check.
            vssc = vss.commit()
            assert all(
                VSSCommitment.verify_secshare(secshares[i], vssc.pubshare(i))
                for i in range(n)
            )

            vssc_tweaked, tweak, pubtweak = vssc.invalid_taproot_commit()
            assert VSSCommitment.verify_secshare(
                vss.secret() + tweak, vssc.commitment_to_secret() + pubtweak
            )
            assert all(
                VSSCommitment.verify_secshare(
                    secshares[i] + tweak, vssc_tweaked.pubshare(i)
                )
                for i in range(n)
            )
|
|
|
|
|
def simulate_simplpedpop(
    seeds, t, investigation: bool
) -> Optional[List[Tuple[simplpedpop.DKGOutput, bytes]]]:
    """Run an in-process SimplPedPop session with one participant per seed.

    Args:
        seeds: one seed per participant; n = len(seeds).
        t: the threshold.
        investigation: if True, corrupt one partial secret share destined for
            each participant and exercise the blame-investigation path.

    Returns:
        None if investigation is True. Blame assignment is checked with
        assertions, and the session stops at the first participant that
        detects the fault. Otherwise, a list whose entry 0 is the
        coordinator's (cout, ceq) pair and whose entries 1..n are the
        participants' participant_step2 results.
    """
    n = len(seeds)
    prets = [
        simplpedpop.participant_step1(seed, t, n, i) for i, seed in enumerate(seeds)
    ]
    pstates = [pstate for (pstate, _, _) in prets]
    pmsgs = [pmsg for (_, pmsg, _) in prets]

    cmsg, cout, ceq = simplpedpop.coordinator_step(pmsgs, t, n)
    pre_finalize_rets = [(cout, ceq)]
    for i in range(n):
        # The partial secret shares that every participant produced for i.
        partial_secshares = [
            partial_secshares_for[i] for (_, _, partial_secshares_for) in prets
        ]
        if investigation:
            # Let a random participant send an incorrect share to participant i.
            faulty_idx = randint(0, n - 1)
            partial_secshares[faulty_idx] += Scalar(17)

        secshare = simplpedpop.participant_step2_prepare_secshare(partial_secshares)
        try:
            pre_finalize_rets.append(
                simplpedpop.participant_step2(pstates[i], cmsg, secshare)
            )
        except UnknownFaultyParticipantOrCoordinatorError as err:
            if not investigation:
                raise
            inv_msgs = simplpedpop.coordinator_investigate(pmsgs)
            assert len(inv_msgs) == len(pmsgs)
            try:
                simplpedpop.participant_investigate(
                    err, inv_msgs[i], partial_secshares
                )
            except FaultyParticipantOrCoordinatorError as blame:
                # If we're not faulty, we must blame the faulty party.
                assert i != faulty_idx
                assert blame.participant == faulty_idx
            except FaultyCoordinatorError:
                # If we're the faulty sender, we blame the coordinator.
                assert i == faulty_idx
            else:
                # Investigation must always end in blame; a silent return
                # would otherwise make this path pass vacuously.
                pytest.fail("participant_investigate returned without assigning blame")
            return None
    return pre_finalize_rets
|
|
|
|
|
def encpedpop_keys(seed: bytes) -> Tuple[bytes, bytes]:
    """Derive an EncPedPop (deckey, enckey) pair from a seed.

    The decryption key is the "encpedpop deckey" tagged hash of the seed, and
    the encryption key is its plain public key.
    """
    secret = tagged_hash_bip_dkg("encpedpop deckey", seed)
    public = pubkey_gen_plain(secret)
    return secret, public
|
|
|
|
|
def simulate_encpedpop(
    seeds, t, investigation: bool
) -> Optional[List[Tuple[simplpedpop.DKGOutput, bytes]]]:
    """Run an in-process EncPedPop session with one participant per seed.

    Args:
        seeds: one seed per participant; n = len(seeds).
        t: the threshold.
        investigation: if True, for each participant i, a randomly chosen
            participant sends it an incorrect share. The blame-investigation
            path is then exercised.

    Returns:
        None if investigation is True. Blame assignment is checked with
        assertions, and the session stops at the first participant that
        detects the fault. Otherwise, a list whose entry 0 is the
        coordinator's (cout, ceq) pair and whose entries 1..n are the
        participants' participant_step2 results.
    """
    n = len(seeds)
    keypairs = [encpedpop_keys(seed) for seed in seeds]
    deckeys = [deckey for (deckey, _) in keypairs]
    enckeys = [enckey for (_, enckey) in keypairs]

    enc_prets1 = [
        encpedpop.participant_step1(
            seeds[i], deckeys[i], enckeys, t, i, random_bytes(32)
        )
        for i in range(n)
    ]
    pstates = [pstate for (pstate, _) in enc_prets1]
    pmsgs = [pmsg for (_, pmsg) in enc_prets1]

    if investigation:
        # faulty_idx[i] is the participant that sends an incorrect share to i.
        faulty_idx = [randint(0, n - 1) for _ in range(n)]
        for i, faulty in enumerate(faulty_idx):
            pmsgs[faulty].enc_shares[i] += Scalar(17)

    cmsg, cout, ceq, enc_secshares = encpedpop.coordinator_step(pmsgs, t, enckeys)
    pre_finalize_rets = [(cout, ceq)]
    for i in range(n):
        try:
            pre_finalize_rets.append(
                encpedpop.participant_step2(
                    pstates[i], deckeys[i], cmsg, enc_secshares[i]
                )
            )
        except UnknownFaultyParticipantOrCoordinatorError as err:
            if not investigation:
                raise
            inv_msgs = encpedpop.coordinator_investigate(pmsgs)
            assert len(inv_msgs) == len(pmsgs)
            try:
                encpedpop.participant_investigate(err, inv_msgs[i])
            except FaultyParticipantOrCoordinatorError as blame:
                # If we're not faulty, we must blame the faulty party.
                assert i != faulty_idx[i]
                assert blame.participant == faulty_idx[i]
            except FaultyCoordinatorError:
                # If we're the faulty sender, we blame the coordinator.
                assert i == faulty_idx[i]
            else:
                # Investigation must always end in blame; a silent return
                # would otherwise make this path pass vacuously.
                pytest.fail("participant_investigate returned without assigning blame")
            return None
    return pre_finalize_rets
|
|
|
|
|
def simulate_chilldkg(
    hostseckeys, t, investigation: bool
) -> Optional[List[Tuple[chilldkg.DKGOutput, chilldkg.RecoveryData]]]:
    """Run an in-process ChillDKG session with one participant per host key.

    Args:
        hostseckeys: one host secret key per participant; n = len(hostseckeys).
        t: the threshold.
        investigation: if True, for each participant i, a randomly chosen
            participant sends it an incorrect share. The blame-investigation
            path is then exercised.

    Returns:
        None if investigation is True. Blame assignment is checked with
        assertions, and the session stops at the first participant that
        detects the fault. Otherwise, a list whose entry 0 is the
        coordinator's (cout, crec) pair and whose entries 1..n are the
        participants' participant_finalize results.
    """
    n = len(hostseckeys)
    hostpubkeys = [chilldkg.hostpubkey_gen(hostseckey) for hostseckey in hostseckeys]
    params = chilldkg.SessionParams(hostpubkeys, t)

    prets1 = [
        chilldkg.participant_step1(hostseckey, params, random_bytes(32))
        for hostseckey in hostseckeys
    ]
    pstates1 = [pret[0] for pret in prets1]
    pmsgs = [pret[1] for pret in prets1]

    if investigation:
        # faulty_idx[i] is the participant that sends an incorrect share to i.
        faulty_idx = [randint(0, n - 1) for _ in range(n)]
        for i, faulty in enumerate(faulty_idx):
            pmsgs[faulty].enc_pmsg.enc_shares[i] += Scalar(17)

    cstate, cmsg1 = chilldkg.coordinator_step1(pmsgs, params)

    prets2 = []
    for i in range(n):
        try:
            prets2.append(
                chilldkg.participant_step2(hostseckeys[i], pstates1[i], cmsg1)
            )
        except UnknownFaultyParticipantOrCoordinatorError as err:
            if not investigation:
                raise
            inv_msgs = chilldkg.coordinator_investigate(pmsgs)
            assert len(inv_msgs) == len(pmsgs)
            try:
                chilldkg.participant_investigate(err, inv_msgs[i])
            except FaultyParticipantOrCoordinatorError as blame:
                # If we're not faulty, we must blame the faulty party.
                assert i != faulty_idx[i]
                assert blame.participant == faulty_idx[i]
            except FaultyCoordinatorError:
                # If we're the faulty sender, we blame the coordinator.
                assert i == faulty_idx[i]
            else:
                # Investigation must always end in blame; a silent return
                # would otherwise make this path pass vacuously.
                pytest.fail("participant_investigate returned without assigning blame")
            return None

    cmsg2, cout, crec = chilldkg.coordinator_finalize(
        cstate, [pret[1] for pret in prets2]
    )
    outputs = [(cout, crec)]
    for pret in prets2:
        out = chilldkg.participant_finalize(pret[0], cmsg2)
        assert out is not None
        outputs.append(out)
    return outputs
|
|
|
|
|
def simulate_chilldkg_full(
    hostseckeys,
    t,
    investigation: bool,
) -> List[Optional[Tuple[chilldkg.DKGOutput, chilldkg.RecoveryData]]]:
    """Run ChillDKG through simulate_chilldkg_full from chilldkg_example.

    Builds the session parameters from the host keys and threshold, then
    delegates to the example simulation with no faulty participant.
    Investigation is not supported by this wrapper.
    """
    assert not investigation

    params = chilldkg.SessionParams(
        [chilldkg.hostpubkey_gen(hostseckey) for hostseckey in hostseckeys], t
    )
    return simulate_chilldkg_full_example(hostseckeys, params, faulty_idx=None)
|
|
|
|
|
def derive_interpolating_value(L, x_i):
    """Return the Lagrange coefficient of x_i for interpolation at 0.

    Computes prod over x_j in L with x_j != x_i of x_j / (x_j - x_i).
    Multiplying each share by its coefficient and summing yields the
    polynomial's value at 0.

    Args:
        L: distinct evaluation points (values accepted by Scalar(), e.g. ints).
        x_i: the point whose coefficient is wanted; must be an element of L.
    """
    assert x_i in L
    assert all(L.count(x_j) <= 1 for x_j in L)
    # x_i does not change inside the loop, so convert it once.
    x_i_scalar = Scalar(x_i)
    lam = Scalar(1)
    for x in L:
        x_j = Scalar(x)
        if x_j == x_i_scalar:
            continue
        lam *= x_j / (x_j - x_i_scalar)
    return lam
|
|
|
|
|
def recover_secret(participant_indices, shares) -> Scalar:
    """Reconstruct the polynomial's value at 0 from t shares.

    shares[k] must be the evaluation at participant_indices[k]. Both sequences
    must have the same length.
    """
    assert len(participant_indices) == len(shares)
    weighted = [
        derive_interpolating_value(participant_indices, index) * share
        for index, share in zip(participant_indices, shares)
    ]
    return Scalar.sum(*weighted)
|
|
|
|
|
def test_recover_secret():
    """Any two of three shares of a degree-1 polynomial recover its constant term."""
    f = Polynomial([23, 42])
    points = [1, 2, 3]
    shares = [f(x) for x in points]
    for a, b in combinations(range(len(points)), 2):
        assert recover_secret([points[a], points[b]], [shares[a], shares[b]]) == f.coeffs[0]
|
|
|
|
|
def check_correctness_dkg_output(t, n, dkg_outputs: List[simplpedpop.DKGOutput]):
    """Assert that the DKG outputs of one session are mutually consistent.

    dkg_outputs[0] is the coordinator's output and dkg_outputs[1:] are the n
    participants' outputs. Each output is indexed as
    (secshare, threshold_pubkey, pubshares).
    """
    assert len(dkg_outputs) == n + 1
    secshares = [output[0] for output in dkg_outputs]
    threshold_pubkeys = [output[1] for output in dkg_outputs]
    pubshares = [output[2] for output in dkg_outputs]

    # Coordinator and participants must agree on the threshold pubkey and on
    # the full list of n pubshares.
    reference_pubkey = threshold_pubkeys[0]
    reference_pubshares = pubshares[0]
    for pubkey, their_pubshares in zip(threshold_pubkeys, pubshares):
        assert reference_pubkey == pubkey
        assert len(their_pubshares) == n
        assert reference_pubshares == their_pubshares

    # The coordinator holds no secret share.
    assert secshares[0] is None

    # The secret share of the participant at position i (1-based) must match
    # pubshare i - 1.
    secshares_scalar = [
        Scalar.from_bytes_checked(secshare) if secshare is not None else None
        for secshare in secshares
    ]
    for secshare_scalar, pubshare in zip(secshares_scalar[1:], reference_pubshares):
        assert secshare_scalar * G == GE.from_bytes_compressed(pubshare)

    # Every set of t participants must interpolate the threshold secret key.
    for subset in combinations(range(1, n + 1), t):
        recovered = recover_secret(subset, [secshares_scalar[i] for i in subset])
        assert recovered * G == GE.from_bytes_compressed(reference_pubkey)
|
|
|
|
|
@pytest.mark.parametrize(
    't,n,simulate_dkg,recovery,investigation',
    [
        [t, n, simulate_dkg, recovery, investigation]
        for (t, n) in [(1, 1), (1, 2), (2, 2), (2, 3), (2, 5)]
        for (simulate_dkg, recovery, investigation) in [
            (simulate_simplpedpop, False, False),
            (simulate_simplpedpop, False, True),
            (simulate_encpedpop, False, False),
            (simulate_encpedpop, False, True),
            (simulate_chilldkg, True, False),
            (simulate_chilldkg, True, True),
            # The example-based wrapper does not support investigation.
            (simulate_chilldkg_full, True, False),
        ]
    ],
)
def test_correctness(t, n, simulate_dkg, recovery, investigation):
    """Run a DKG simulation and check its outputs, equality data, and recovery.

    With investigation enabled, the simulation must abort (return None) after
    correctly assigning blame.
    """
    # Position 0 is the coordinator, which has no seed; 1..n are participants.
    seeds = [None] + [random_bytes(32) for _ in range(n)]

    rets = simulate_dkg(seeds[1:], t, investigation=investigation)
    if investigation:
        # The session has failed correctly, so there's nothing further to check.
        assert rets is None
        return

    # rets[0] comes from the coordinator, rets[1:] from the participants.
    assert len(rets) == n + 1
    dkg_outputs = [ret[0] for ret in rets]
    check_correctness_dkg_output(t, n, dkg_outputs)

    # Everyone must hold the same second element (eq input or recovery data).
    eqs_or_recs = [ret[1] for ret in rets]
    for other in eqs_or_recs[1:]:
        assert eqs_or_recs[0] == other

    if not recovery:
        return

    # Every party, including the coordinator (seed None), must be able to
    # rebuild its DKG output from its seed and the shared recovery data.
    rec = eqs_or_recs[0]
    for seed, expected in zip(seeds, dkg_outputs):
        (secshare, threshold_pubkey, pubshares), _ = chilldkg.recover(seed, rec)
        assert secshare == expected[0]
        assert threshold_pubkey == expected[1]
        assert pubshares == expected[2]
|
|
|