You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
349 lines
13 KiB
349 lines
13 KiB
# -*- coding: utf-8 -*- |
|
|
|
import asyncio |
|
import time |
|
|
|
from unittest import IsolatedAsyncioTestCase |
|
|
|
import jmclient # noqa: F401 install asyncioreactor |
|
|
|
import pytest |
|
|
|
import jmbitcoin as btc |
|
from jmbase import get_log |
|
from jmclient import ( |
|
load_test_config, jm_single, get_network, cryptoengine, VolatileStorage, |
|
FrostWallet, WalletService) |
|
from jmclient import FrostIPCServer, FrostIPCClient |
|
from jmclient.frost_clients import FROSTClient |
|
|
|
from test_frost_clients import populate_dkg_session |
|
|
|
|
|
# Request the "setup_regtest_frost_bitcoind" fixture for every test collected
# from this module (pytest reads the module-level ``pytestmark`` name).
pytestmark = pytest.mark.usefixtures("setup_regtest_frost_bitcoind")

# Module-level logger used by the test doubles defined below.
log = get_log()
|
|
|
|
|
async def get_populated_wallet(entropy=None):
    """Create, initialise and return a FrostWallet on volatile storage.

    Three independent ``VolatileStorage`` objects are created (wallet, DKG
    and recovery data). They are first passed to ``FrostWallet.initialize``
    together with the current network and ``entropy``, and then used to
    construct the wallet, whose ``async_init`` is awaited before returning.

    :param entropy: forwarded unchanged as the ``entropy`` keyword of
        ``FrostWallet.initialize``; ``None`` by default.
    :returns: the constructed ``FrostWallet`` instance.
    """
    storage = VolatileStorage()
    dkg_storage = VolatileStorage()
    recovery_storage = VolatileStorage()
    FrostWallet.initialize(storage, dkg_storage, recovery_storage,
                           get_network(), entropy=entropy)
    wlt = FrostWallet(storage, dkg_storage, recovery_storage)
    await wlt.async_init(storage)
    return wlt
|
|
|
|
|
class DummyFrostJMClientProtocol:
    """Test double that wires several FROST clients together in-process.

    Every instance wraps the client reachable as ``factory.client`` and keeps
    a ``party_clients`` mapping of peer nick -> peer protocol object.
    Messages are delivered by directly awaiting (or scheduling) the matching
    ``on_*`` handler on the peer object, passing ``self.nick`` as the sender.
    """

    def __init__(self, factory, client, nick):
        """Store the nick, owning factory and client; start with no peers."""
        self.nick = nick
        self.factory = factory
        self.client = client
        # peer nick -> peer protocol object, filled in by the factory's
        # add_party_client().
        self.party_clients = {}

    async def dkg_gen(self):
        """Run DKG sessions until ``client.dkg_gen()`` returns ``None``.

        For each ``(mixdepth, address_type, index)`` tuple handed out by the
        client a session is started through :meth:`dkg_init` and its output
        is awaited. A falsy result from ``wait_on_dkg_output`` discards the
        session so that a new one is started for the same tuple; a session
        with ``dkg_output`` set pops the first entry of
        ``client.dkg_gen_list`` and moves on to the next request.
        """
        log.debug('Coordinator call dkg_gen')
        client = self.factory.client
        md_type_idx = None
        session_id = None
        session = None

        while True:
            if md_type_idx is None:
                md_type_idx = await client.dkg_gen()
                if md_type_idx is None:
                    log.debug('finished dkg_gen execution')
                    break

            if session_id is None:
                session_id, _, session = self.dkg_init(*md_type_idx)
                if session_id is None:
                    log.warning('could not get session_id from dkg_init')
                    # Back off before retrying the same request.
                    await asyncio.sleep(5)
                    continue

            pub = await client.wait_on_dkg_output(session_id)
            if not pub:
                # No output for this session: drop it, keep md_type_idx so
                # the next iteration starts a fresh session for it.
                session_id = None
                session = None
                continue

            if session.dkg_output:
                md_type_idx = None
                session_id = None
                session = None
                client.dkg_gen_list.pop(0)
                continue

    def dkg_init(self, mixdepth, address_type, index):
        """Start a DKG session and schedule ``on_dkg_init`` on every peer.

        Must be called while an event loop is running because the peer
        notifications are scheduled with ``asyncio.create_task``.

        :returns: ``(session_id, coordinator, session)`` when the client
            produced a session id and both the session and its coordinator
            are registered on the client, otherwise ``(None, None, None)``.
        """
        log.debug(f'Coordinator call dkg_init '
                  f'({mixdepth}, {address_type}, {index})')
        client = self.factory.client
        hostpubkeyhash, session_id, sig = client.dkg_init(
            mixdepth, address_type, index)
        coordinator = client.dkg_coordinators.get(session_id)
        session = client.dkg_sessions.get(session_id)
        if not (session_id and session and coordinator):
            return None, None, None

        session.dkg_init_sec = time.time()
        # Fire-and-forget: peers handle the init message concurrently.
        for pc in self.party_clients.values():
            asyncio.create_task(pc.on_dkg_init(
                self.nick, hostpubkeyhash, session_id, sig))
        return session_id, coordinator, session

    async def on_dkg_init(self, nick, hostpubkeyhash, session_id, sig):
        """Handle a DKG init message and answer the sender with pmsg1.

        The values returned by ``client.on_dkg_init`` replace the incoming
        ones; the reply goes to the peer registered under the returned nick.
        """
        client = self.factory.client
        nick, hostpubkeyhash, session_id, sig, pmsg1 = client.on_dkg_init(
            nick, hostpubkeyhash, session_id, sig)
        if pmsg1:
            pc = self.party_clients[nick]
            # The client returns the session id hex-encoded; peers are
            # called with raw bytes.
            session_id = bytes.fromhex(session_id)
            await pc.on_dkg_pmsg1(
                self.nick, hostpubkeyhash, session_id, sig, pmsg1)

    async def on_dkg_pmsg1(self, nick, hostpubkeyhash, session_id, sig, pmsg1):
        """Feed a peer's pmsg1 to the client; send cmsg1 to ready peers."""
        client = self.factory.client
        pmsg1 = client.deserialize_pmsg1(pmsg1)
        ready_nicks, cmsg1 = client.on_dkg_pmsg1(
            nick, hostpubkeyhash, session_id, sig, pmsg1)
        if ready_nicks and cmsg1:
            for party_nick in ready_nicks:
                pc = self.party_clients[party_nick]
                await pc.on_dkg_cmsg1(self.nick, session_id, cmsg1)

    async def on_dkg_cmsg1(self, nick, session_id, cmsg1):
        """Handle cmsg1 from ``nick`` and reply with pmsg2.

        :returns: ``{'accepted': True}`` when the session is unknown (after
            logging an error), otherwise ``None``.
        """
        client = self.factory.client
        session = client.dkg_sessions.get(session_id)
        if not session:
            log.error(f'on_dkg_cmsg1: session {session_id} not found')
            return {'accepted': True}
        # Only the nick recorded as the session's coordinator may send this.
        if session.coord_nick == nick:
            cmsg1 = client.deserialize_cmsg1(cmsg1)
            pmsg2 = client.party_step2(session_id, cmsg1)
            if pmsg2:
                pc = self.party_clients[nick]
                await pc.on_dkg_pmsg2(self.nick, session_id, pmsg2)
        else:
            log.error(f'on_dkg_cmsg1: not coordinator nick {nick}')

    async def on_dkg_pmsg2(self, nick, session_id, pmsg2):
        """Feed a peer's pmsg2 to the client; send cmsg2 to ready peers."""
        client = self.factory.client
        pmsg2 = client.deserialize_pmsg2(pmsg2)
        ready_nicks, cmsg2, ext_recovery = client.on_dkg_pmsg2(
            nick, session_id, pmsg2)
        if ready_nicks and cmsg2 and ext_recovery:
            for party_nick in ready_nicks:
                pc = self.party_clients[party_nick]
                await pc.on_dkg_cmsg2(
                    self.nick, session_id, cmsg2, ext_recovery)

    async def on_dkg_cmsg2(self, nick, session_id, cmsg2, ext_recovery):
        """Handle cmsg2 from ``nick``; report finalisation back to it.

        :returns: ``{'accepted': True}`` when the session is unknown (after
            logging an error), otherwise ``None``.
        """
        client = self.factory.client
        session = client.dkg_sessions.get(session_id)
        if not session:
            log.error(f'on_dkg_cmsg2: session {session_id} not found')
            return {'accepted': True}
        # Only the nick recorded as the session's coordinator may send this.
        if session.coord_nick == nick:
            cmsg2 = client.deserialize_cmsg2(cmsg2)
            finalized = client.finalize(session_id, cmsg2, ext_recovery)
            if finalized:
                pc = self.party_clients[nick]
                await pc.on_dkg_finalized(self.nick, session_id)
        else:
            log.error(f'on_dkg_cmsg2: not coordinator nick {nick}')

    async def on_dkg_finalized(self, nick, session_id):
        """Forward a peer's finalisation notice to the client."""
        client = self.factory.client
        log.debug('Coordinator get dkgfinalized')
        client.on_dkg_finalized(nick, session_id)

    def frost_req(self, dkg_session_id, msg_bytes):
        """Start a signing session and schedule ``on_frost_req`` on peers.

        Must be called while an event loop is running because the peer
        notifications are scheduled with ``asyncio.create_task``.

        :returns: ``(session_id, coordinator, session)`` when the client
            produced a session id and both the session and its coordinator
            are registered on the client; ``None`` otherwise.
        """
        log.debug('Coordinator call frost_req')
        client = self.factory.client
        hostpubkeyhash, sig, session_id = client.frost_req(
            dkg_session_id, msg_bytes)
        coordinator = client.frost_coordinators.get(session_id)
        session = client.frost_sessions.get(session_id)
        if not (session_id and session and coordinator):
            # NOTE(review): dkg_init() returns (None, None, None) in the
            # equivalent situation; a bare None is kept here because the
            # callers are not visible from this file.
            return None

        coordinator.frost_req_sec = time.time()
        # Fire-and-forget: peers handle the request concurrently.
        for pc in self.party_clients.values():
            asyncio.create_task(pc.on_frost_req(
                self.nick, hostpubkeyhash, sig, session_id))
        return session_id, coordinator, session

    async def on_frost_req(self, nick, hostpubkeyhash, sig, session_id):
        """Handle a signing request from ``nick`` and acknowledge it.

        The acknowledgement is sent to the peer registered under the
        incoming ``nick``; the nick returned by the client is not used.
        """
        client = self.factory.client
        (
            _,
            hostpubkeyhash,
            sig,
            session_id,
        ) = client.on_frost_req(nick, hostpubkeyhash, sig, session_id)
        if sig:
            pc = self.party_clients[nick]
            # The client returns the session id hex-encoded; peers are
            # called with raw bytes.
            session_id = bytes.fromhex(session_id)
            await pc.on_frost_ack(
                self.nick, hostpubkeyhash, sig, session_id)

    async def on_frost_ack(self, nick, hostpubkeyhash, sig, session_id):
        """Check a peer's acknowledgement, then send it ``frost_init``.

        :raises AssertionError: if ``client.on_frost_ack`` returns a falsy
            value.
        """
        client = self.factory.client
        assert client.on_frost_ack(nick, hostpubkeyhash, sig, session_id)
        pc = self.party_clients[nick]
        await pc.on_frost_init(self.nick, session_id)

    async def on_frost_init(self, nick, session_id):
        """Handle ``frost_init`` from ``nick``; reply with the public nonce.

        The reply is sent to the peer registered under the incoming
        ``nick``; the nick returned by the client is not used.
        """
        client = self.factory.client
        (
            _,
            session_id,
            pub_nonce,
        ) = client.on_frost_init(nick, session_id)
        if pub_nonce:
            pc = self.party_clients[nick]
            # The client returns the session id hex-encoded; peers are
            # called with raw bytes.
            session_id = bytes.fromhex(session_id)
            await pc.on_frost_round1(self.nick, session_id, pub_nonce)

    async def on_frost_round1(self, nick, session_id, pub_nonce):
        """Record a peer's public nonce; send the aggregate to ready peers.

        When the client reports ready nicks together with an aggregated
        nonce, ``on_frost_agg1`` is awaited once on each ready peer.
        """
        client = self.factory.client
        (
            ready_nicks,
            nonce_agg,
            dkg_session_id,
            ids,
            msg,
        ) = client.on_frost_round1(nick, session_id, pub_nonce)
        if ready_nicks and nonce_agg:
            for party_nick in ready_nicks:
                # Address each ready peer by its own nick. Looking up the
                # sender's ``nick`` here would deliver the aggregate to the
                # same peer repeatedly and skip the others.
                pc = self.party_clients[party_nick]
                await pc.on_frost_agg1(
                    self.nick, session_id, nonce_agg, dkg_session_id, ids, msg)

    async def on_frost_agg1(self, nick, session_id,
                            nonce_agg, dkg_session_id, ids, msg):
        """Handle the aggregated nonce from ``nick``; reply with a partial sig.

        Logs an error and returns when the session is unknown or ``nick`` is
        not the coordinator recorded on the session.
        """
        client = self.factory.client
        session = client.frost_sessions.get(session_id)
        if not session:
            log.error(f'on_frost_agg1: session {session_id} not found')
            return
        # Only the nick recorded as the session's coordinator may send this.
        if session.coord_nick == nick:
            partial_sig = client.frost_round2(
                session_id, nonce_agg, dkg_session_id, ids, msg)
            if partial_sig:
                pc = self.party_clients[nick]
                await pc.on_frost_round2(self.nick, session_id, partial_sig)
        else:
            log.error(f'on_frost_agg1: not coordinator nick {nick}')

    async def on_frost_round2(self, nick, session_id, partial_sig):
        """Feed a peer's partial signature to the client; log a result."""
        client = self.factory.client
        sig = client.on_frost_round2(nick, session_id, partial_sig)
        if sig:
            log.debug(f'Successfully get signature {sig.hex()[:8]}...')
|
|
|
|
|
class DummyFrostJMClientProtocolFactory:
    """Factory that owns one ``DummyFrostJMClientProtocol`` per client.

    The protocol object is created eagerly in ``__init__`` and exposed both
    as the ``proto_client`` attribute and through :meth:`getClient`.
    """

    # Protocol class instantiated for every factory.
    protocol = DummyFrostJMClientProtocol

    def __init__(self, client, nick):
        """Keep ``client`` and build the protocol object for ``nick``."""
        self.client = client
        self.proto_client = self.protocol(self, self.client, nick)

    def add_party_client(self, nick, party_client):
        """Register ``party_client`` as the peer reachable under ``nick``."""
        self.proto_client.party_clients[nick] = party_client

    def getClient(self):
        """Return the protocol object created for this factory."""
        return self.proto_client
|
|
|
|
|
class FrostIPCTestCaseBase(IsolatedAsyncioTestCase):
    """Shared fixture: three connected FROST wallets plus an IPC pair.

    After ``asyncSetUp`` the instance exposes ``wlt1..3``, ``wlt_svc1..3``,
    ``fc1..3`` and ``nick1..3``, as well as ``ipcs`` (``FrostIPCServer``)
    and ``ipcc`` (``FrostIPCClient``), both built on ``wlt1``.
    """

    def setUp(self):
        """Load the test configuration and select regtest parameters."""
        load_test_config(config_path='./test_frost')
        btc.select_chain_params("bitcoin/regtest")
        # NOTE(review): overrides a class-level constant for the whole
        # process; the meaning of the value 100 is not visible here.
        cryptoengine.BTC_P2TR.VBYTE = 100
        jm_single().bc_interface.tick_forward_chain_interval = 2

    async def asyncSetUp(self):
        """Create the three parties, cross-link them and set up IPC."""
        self.nick1, self.nick2, self.nick3 = ['nick1', 'nick2', 'nick3']
        # Fixed entropy for each of the three wallets.
        entropy1 = bytes.fromhex('8e5e5677fb302874a607b63ad03ba434')
        entropy2 = bytes.fromhex('38dfa80fbb21b32b2b2740e00a47de9d')
        entropy3 = bytes.fromhex('3ad9c77fcd1d537b6ef396952d1221a0')

        self.wlt1 = await get_populated_wallet(entropy1)
        self.wlt_svc1 = WalletService(self.wlt1)
        self.fc1 = FROSTClient(self.wlt_svc1)
        cfactory1 = DummyFrostJMClientProtocolFactory(self.fc1, self.nick1)
        self.wlt1.set_client_factory(cfactory1)

        self.wlt2 = await get_populated_wallet(entropy2)
        self.wlt_svc2 = WalletService(self.wlt2)
        self.fc2 = FROSTClient(self.wlt_svc2)
        cfactory2 = DummyFrostJMClientProtocolFactory(self.fc2, self.nick2)
        self.wlt2.set_client_factory(cfactory2)

        self.wlt3 = await get_populated_wallet(entropy3)
        self.wlt_svc3 = WalletService(self.wlt3)
        self.fc3 = FROSTClient(self.wlt_svc3)
        cfactory3 = DummyFrostJMClientProtocolFactory(self.fc3, self.nick3)
        self.wlt3.set_client_factory(cfactory3)

        # Every party gets a direct reference to the two other parties'
        # protocol objects.
        cfactory1.add_party_client(self.nick2, cfactory2.proto_client)
        cfactory1.add_party_client(self.nick3, cfactory3.proto_client)

        cfactory2.add_party_client(self.nick1, cfactory1.proto_client)
        cfactory2.add_party_client(self.nick3, cfactory3.proto_client)

        cfactory3.add_party_client(self.nick1, cfactory1.proto_client)
        cfactory3.add_party_client(self.nick2, cfactory2.proto_client)

        await populate_dkg_session(self)

        # IPC server and client are both created on the first wallet; the
        # client is then registered on that wallet.
        self.ipcs = FrostIPCServer(self.wlt1)
        await self.ipcs.async_init()
        self.ipcc = FrostIPCClient(self.wlt1)
        await self.ipcc.async_init()
        self.wlt1.set_ipc_client(self.ipcc)
|
|
|
|
|
class FrostIPCClientTestCase(FrostIPCTestCaseBase):
    """Exercise ``FrostIPCClient`` against a running ``FrostIPCServer``."""

    async def asyncSetUp(self):
        """Build the shared fixture and start the IPC server as a task."""
        await super().asyncSetUp()
        self.serve_task = asyncio.create_task(self.ipcs.serve_forever())

    async def asyncTearDown(self):
        """Request cancellation of the IPC server task."""
        self.serve_task.cancel("cancel from asyncTearDown")

    async def test_get_dkg_pubkey(self):
        """Pubkeys fetched over IPC are present in the wallet's DKG data."""
        pubkey = await self.ipcc.get_dkg_pubkey(0, 0, 0)
        dkg = self.wlt1.dkg
        pubkeys = list(dkg._dkg_pubkey.values())
        assert pubkey and pubkey in pubkeys

        pubkey = await self.ipcc.get_dkg_pubkey(0, 0, 1)
        # Re-read the mapping after the second request before checking.
        pubkeys = list(dkg._dkg_pubkey.values())
        assert pubkey and pubkey in pubkeys

    async def test_frost_req(self):
        """A signing request yields a 64-byte sig and two 33-byte pubkeys."""
        sighash = bytes.fromhex('01020304'*8)
        sig, pubkey, tweaked_pubkey = await self.ipcc.frost_req(
            0, 0, 0, sighash)
        assert sig and len(sig) == 64
        assert pubkey and len(pubkey) == 33
        assert tweaked_pubkey and len(tweaked_pubkey) == 33
|
|
|