You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

301 lines
8.9 KiB

#!/usr/bin/env python3
"""Example of a full ChillDKG session"""
from typing import Tuple, List, Optional
import asyncio
import pprint
from random import randint
from secrets import token_bytes as random_bytes
import sys
import argparse
from jmfrost.chilldkg_ref.chilldkg import (
params_id,
hostpubkey_gen,
participant_step1,
participant_step2,
participant_finalize,
participant_investigate,
coordinator_step1,
coordinator_finalize,
coordinator_investigate,
SessionParams,
DKGOutput,
RecoveryData,
FaultyParticipantOrCoordinatorError,
UnknownFaultyParticipantOrCoordinatorError,
)
#
# Network mocks to simulate full DKG sessions
#
class CoordinatorChannels:
    """Coordinator-side endpoint of the in-memory mock network.

    ``queues[i]`` is the inbound queue carrying messages sent by participant
    ``i``. The outbound queues used to reach each participant are registered
    later via :meth:`set_participant_queues`.
    """

    def __init__(self, n):
        self.n = n
        # Inbound queues, one per participant, in participant index order.
        self.queues = [asyncio.Queue() for _ in range(n)]
        # Outbound queues are wired up after construction; keep them None
        # until then so the send methods can detect a missing setup step.
        self.participant_queues = None

    def set_participant_queues(self, participant_queues):
        """Register the participants' inbound queues (index i -> participant i)."""
        self.participant_queues = participant_queues

    def _outbound(self):
        """Return the registered participant queues, or fail if none are set."""
        if self.participant_queues is None:
            raise RuntimeError(
                "set_participant_queues() must be called before sending"
            )
        return self.participant_queues

    def send_to(self, i, m):
        """Send message ``m`` to participant ``i`` without waiting."""
        self._outbound()[i].put_nowait(m)

    def send_all(self, m):
        """Send the same message ``m`` to each of the ``n`` participants."""
        queues = self._outbound()
        for i in range(self.n):
            queues[i].put_nowait(m)

    async def receive_from(self, i):
        """Wait for and return the next message sent by participant ``i``."""
        return await self.queues[i].get()
class ParticipantChannel:
    """Participant-side endpoint of the in-memory mock network.

    ``queue`` is this participant's inbox, filled by the coordinator.
    ``coord_queue`` is the coordinator-side queue that carries this
    participant's outgoing messages.
    """

    def __init__(self, coord_queue):
        self.coord_queue = coord_queue
        self.queue = asyncio.Queue()

    def send(self, m):
        """Hand message ``m`` to the coordinator without waiting."""
        outbox = self.coord_queue
        outbox.put_nowait(m)

    async def receive(self):
        """Wait for the next message from the coordinator and return it."""
        return await self.queue.get()
#
# Helper functions
#
def pphex(thing):
"""Pretty print an object with bytes as hex strings"""
def hexlify(thing):
if isinstance(thing, bytes):
return thing.hex()
if isinstance(thing, dict):
return {k: hexlify(v) for k, v in thing.items()}
if hasattr(thing, "_asdict"): # NamedTuple
return hexlify(thing._asdict())
if isinstance(thing, List):
return [hexlify(v) for v in thing]
return thing
pprint.pp(hexlify(thing))
#
# Protocol parties
#
async def participant(
    chan: ParticipantChannel,
    hostseckey: bytes,
    params: SessionParams,
    investigation_procedure: bool,
) -> Tuple[DKGOutput, RecoveryData]:
    """Run one honest participant's side of a ChillDKG session.

    Args:
        chan: This participant's endpoint of the mock network.
        hostseckey: This participant's host secret key.
        params: Session parameters shared by all parties.
        investigation_procedure: If True, expect an extra investigation
            message from the coordinator right after the first coordinator
            message, and use it to investigate if step 2 fails.

    Returns:
        Whatever ``participant_finalize`` returns (per the annotation, this
        participant's DKG output and the recovery data).

    Raises:
        UnknownFaultyParticipantOrCoordinatorError: If step 2 fails and the
            investigation procedure is disabled (or, defensively, if the
            investigation returns without raising).
        FaultyParticipantOrCoordinatorError: Presumably raised by
            ``participant_investigate`` to blame a specific party (main()
            catches this type) -- TODO confirm against chilldkg.
    """
    # TODO Top-level error handling
    session_random = random_bytes(32)
    state1, pmsg1 = participant_step1(hostseckey, params, session_random)
    chan.send(pmsg1)
    cmsg1 = await chan.receive()

    # Participants can implement an optional investigation procedure. This
    # allows the participant to determine which participant is faulty when an
    # `UnknownFaultyParticipantOrCoordinatorError` is raised. The investigation
    # procedure requires the participant to receive an extra "investigation
    # message" from the coordinator that contains necessary information.
    #
    # In this example, if the investigation procedure is enabled, the
    # participant expects the coordinator to send an investigation message.
    # Alternatively, an implementation of the participant can explicitly request
    # the investigation message only if participant_step2 fails.
    cinv = await chan.receive() if investigation_procedure else None

    try:
        state2, eq_round1 = participant_step2(hostseckey, state1, cmsg1)
    except UnknownFaultyParticipantOrCoordinatorError as e:
        if investigation_procedure:
            # Expected to raise an error blaming a specific party.
            participant_investigate(e, cinv)
        # Without the investigation procedure, this participant cannot
        # determine which party is faulty, so re-raise the
        # UnknownFaultyParticipantOrCoordinatorError. This also guards against
        # participant_investigate returning normally, which would otherwise
        # fall through with state2 and eq_round1 unbound.
        raise

    chan.send(eq_round1)
    cmsg2 = await chan.receive()
    return participant_finalize(state2, cmsg2)
async def coordinator(
    chans: CoordinatorChannels, params: SessionParams, investigation_procedure: bool
) -> Tuple[DKGOutput, RecoveryData]:
    """Run the coordinator's side of a ChillDKG session.

    Args:
        chans: The coordinator's endpoint of the mock network.
        params: Session parameters. The number of participants is
            ``len(params.hostpubkeys)``.
        investigation_procedure: If True, send each participant its
            investigation message right after the first coordinator message.

    Returns:
        The DKG output and recovery data produced by ``coordinator_finalize``.
    """
    n = len(params.hostpubkeys)

    # First round: collect every participant's first message (in index order)
    # and broadcast the resulting coordinator message.
    pmsgs1 = [await chans.receive_from(i) for i in range(n)]
    state, cmsg1 = coordinator_step1(pmsgs1, params)
    chans.send_all(cmsg1)

    # If the coordinator implements the investigation procedure and it is
    # enabled, it sends an extra, per-participant message to each participant.
    if investigation_procedure:
        inv_msgs = coordinator_investigate(pmsgs1)
        for i in range(n):
            chans.send_to(i, inv_msgs[i])

    # Second round: collect every participant's second message (what
    # participant() sends as `eq_round1`), finalize, and broadcast the result.
    pmsgs2 = [await chans.receive_from(i) for i in range(n)]
    cmsg2, dkg_output, recovery_data = coordinator_finalize(state, pmsgs2)
    chans.send_all(cmsg2)
    return dkg_output, recovery_data
#
# DKG Session
#
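# This is a dummy participant used to demonstrate the investigation procedure.
# It picks a random victim among the other participants, corrupts the
# encrypted share addressed to that victim in its first message, and then
# stops without sending any further messages.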
async def faulty_participant(
    chan: ParticipantChannel, hostseckey: bytes, params: SessionParams, idx: int
) -> None:
    """Send a deliberately corrupted first message, then stop.

    Args:
        chan: This participant's endpoint of the mock network.
        hostseckey: This participant's host secret key.
        params: Session parameters shared by all parties.
        idx: This participant's index, which is never chosen as the victim.

    Raises:
        ValueError: If the session has fewer than two participants, because
            there is then no other participant to target.
    """
    session_random = random_bytes(32)
    _, pmsg1 = participant_step1(hostseckey, params, session_random)
    n = len(pmsg1.enc_pmsg.enc_shares)
    if n < 2:
        raise ValueError(
            "a faulty participant needs at least one other participant to target"
        )
    # Pick a random victim that is not this participant: an offset in
    # [1, n - 1] from idx, wrapped around modulo n.
    victim = (idx + randint(1, n - 1)) % n
    pmsg1.enc_pmsg.enc_shares[victim] += 17
    chan.send(pmsg1)
def simulate_chilldkg_full(
    hostseckeys: List[bytes], params: SessionParams, faulty_idx: Optional[int]
) -> List[Optional[Tuple[DKGOutput, RecoveryData]]]:
    """Simulate a complete ChillDKG session over in-memory channels.

    The coordinator and all participants run as concurrent coroutines on a
    single asyncio event loop.

    Args:
        hostseckeys: Host secret keys, one per participant, in the same order
            as ``params.hostpubkeys``.
        params: Session parameters shared by all parties.
        faulty_idx: Index of the participant replaced by
            ``faulty_participant``, or None for an all-honest session.

    Returns:
        The gathered results: the coordinator's result first, then one entry
        per participant in index order. ``faulty_participant`` returns
        nothing, so a faulty slot would hold None.
    """
    n = len(hostseckeys)
    assert n == len(params.hostpubkeys)
    # For demonstration purposes, the investigation procedure is enabled
    # exactly when some participant is faulty.
    investigation_procedure = faulty_idx is not None

    async def session():
        coord_chans = CoordinatorChannels(n)
        # Participant i sends into the coordinator's inbound queue i ...
        participant_chans = [ParticipantChannel(q) for q in coord_chans.queues]
        # ... and the coordinator sends into participant i's own queue.
        coord_chans.set_participant_queues([pc.queue for pc in participant_chans])

        parties = [coordinator(coord_chans, params, investigation_procedure)]
        for i, pc in enumerate(participant_chans):
            if i == faulty_idx:
                parties.append(faulty_participant(pc, hostseckeys[i], params, i))
            else:
                parties.append(
                    participant(pc, hostseckeys[i], params, investigation_procedure)
                )
        return await asyncio.gather(*parties)

    return asyncio.run(session())
def main():
    """Run a simulated ChillDKG session from the command line and print results.

    Command-line arguments are the optional positionals ``t`` (signing
    threshold, default 2) and ``n`` (number of participants, default 3), plus
    the ``--faulty-participant`` flag. That flag makes one randomly chosen
    participant misbehave and enables the investigation procedure.

    Returns:
        0 when the session completes, or when it fails and the blamed
        participant is exactly the simulated faulty one.

    Raises:
        FaultyParticipantOrCoordinatorError: If the blamed participant is not
            the simulated faulty participant (including when no participant
            was supposed to be faulty).
    """
    parser = argparse.ArgumentParser(description="ChillDKG example")
    parser.add_argument(
        "--faulty-participant",
        action="store_true",
        help="When this flag is set, one random participant will send an invalid message, and the investigation procedure will be enabled for other participants and the coordinator.",
    )
    parser.add_argument(
        "t", nargs="?", type=int, default=2, help="Signing threshold [default = 2]"
    )
    parser.add_argument(
        "n", nargs="?", type=int, default=3, help="Number of participants [default = 3]"
    )
    args = parser.parse_args()
    t = args.t
    n = args.n

    # Reject unusable parameters up front with a usage message, instead of
    # failing later with a traceback from deep inside the session.
    if not 1 <= t <= n:
        parser.error(f"need 1 <= t <= n, got t = {t} and n = {n}")
    # The faulty participant picks a victim among the *other* participants
    # (see faulty_participant), so that mode needs at least two participants.
    if args.faulty_participant and n < 2:
        parser.error("--faulty-participant requires n >= 2")

    faulty_idx = randint(0, n - 1) if args.faulty_participant else None

    print("====== ChillDKG example session ======")
    print(f"Using n = {n} participants and a threshold of t = {t}.")
    if faulty_idx is not None:
        print(f"Participant {faulty_idx} is faulty.")
    print()

    # Generate common inputs for all participants and coordinator
    hostseckeys = [random_bytes(32) for _ in range(n)]
    hostpubkeys = [hostpubkey_gen(sk) for sk in hostseckeys]
    params = SessionParams(hostpubkeys, t)

    print("=== Host secret keys ===")
    pphex(hostseckeys)
    print()

    print("=== Session parameters ===")
    pphex(params)
    print()
    print(f"Session parameters identifier: {params_id(params).hex()}")
    print()

    try:
        rets = simulate_chilldkg_full(hostseckeys, params, faulty_idx)
    except FaultyParticipantOrCoordinatorError as e:
        print(
            f"A participant has failed and is blaming either participant {e.participant} or the coordinator."
        )
        # If the blamed participant is the faulty participant, exit with code 0.
        # Otherwise, re-raise the exception.
        if faulty_idx == e.participant:
            return 0
        raise

    # One result for the coordinator followed by one per participant.
    assert len(rets) == n + 1
    print("=== Coordinator's DKG output ===")
    dkg_output, _ = rets[0]
    pphex(dkg_output)
    print()

    for i in range(n):
        print(f"=== Participant {i}'s DKG output ===")
        dkg_output, _ = rets[i + 1]
        pphex(dkg_output)
        print()

    # Check that the RecoveryData of all parties is identical
    assert len({ret[1] for ret in rets}) == 1
    recovery_data = rets[0][1]
    print(f"=== Common recovery data ({len(recovery_data)} bytes) ===")
    print(recovery_data.hex())
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    status = main()
    sys.exit(status)