diff --git a/conftest.py b/conftest.py index f6f4f71..044a6b3 100644 --- a/conftest.py +++ b/conftest.py @@ -264,6 +264,13 @@ def setup_regtest_frost_bitcoind(pytestconfig): "must be 27 or greater.\n") local_command(create_wallet) local_command(f'{root_cmd} loadwallet {wallet_name} true true') + for i in range(2): + cpe = local_command(f'{root_cmd} -rpcwallet={wallet_name} getnewaddress') + if cpe.returncode != 0: + pytest.exit(f"Cannot setup tests, bitcoin-cli failing.\n{cpe.stdout.decode('utf-8')}") + destn_addr = cpe.stdout[:-1].decode('utf-8') + local_command(f'{root_cmd} -rpcwallet={wallet_name} generatetoaddress 1 {destn_addr}') + sleep(1) yield # shut down bitcoind local_command(stop_cmd) diff --git a/src/jmbase/commands.py b/src/jmbase/commands.py index a64309e..ba7f149 100644 --- a/src/jmbase/commands.py +++ b/src/jmbase/commands.py @@ -127,20 +127,32 @@ class JMDKGFinalized(JMCommand): """Messages used by FROST parties""" - -class JMFROSTInit(JMCommand): +class JMFROSTReq(JMCommand): arguments = [ (b'hostpubkeyhash', Unicode()), + (b'sig', Unicode()), (b'session_id', Unicode()), + ] + +class JMFROSTAck(JMCommand): + arguments = [ + (b'nick', Unicode()), + (b'hostpubkeyhash', Unicode()), (b'sig', Unicode()), + (b'session_id', Unicode()), + ] + +class JMFROSTInit(JMCommand): + arguments = [ + (b'nick', Unicode()), + (b'session_id', Unicode()), ] class JMFROSTRound1(JMCommand): arguments = [ (b'nick', Unicode()), - (b'hostpubkeyhash', Unicode()), (b'session_id', Unicode()), - (b'sig', Unicode()), + (b'hostpubkeyhash', Unicode()), (b'pub_nonce', Unicode()), ] @@ -327,21 +339,33 @@ class JMDKGCMsg2Seen(JMCommand): """Messages used by FROST parties""" - -class JMFROSTInitSeen(JMCommand): +class JMFROSTReqSeen(JMCommand): arguments = [ (b'nick', Unicode()), (b'hostpubkeyhash', Unicode()), + (b'sig', Unicode()), (b'session_id', Unicode()), + ] + +class JMFROSTAckSeen(JMCommand): + arguments = [ + (b'nick', Unicode()), + (b'hostpubkeyhash', Unicode()), (b'sig', Unicode()), + (b'session_id', Unicode()), + ] + +class JMFROSTInitSeen(JMCommand): + arguments = [ + (b'nick', Unicode()), + (b'session_id', Unicode()), ] class JMFROSTRound1Seen(JMCommand): arguments = [ (b'nick', Unicode()), - (b'hostpubkeyhash', Unicode()), (b'session_id', Unicode()), - (b'sig', Unicode()), + (b'hostpubkeyhash', Unicode()), (b'pub_nonce', Unicode()), ] diff --git a/src/jmclient/client_protocol.py b/src/jmclient/client_protocol.py index e0dd62d..a08b2f2 100644 --- a/src/jmclient/client_protocol.py +++ b/src/jmclient/client_protocol.py @@ -583,45 +583,76 @@ class JMClientProtocol(BaseClientProtocol): """FROST specifics """ - def frost_init(self, dkg_session_id, msg_bytes): - jlog.debug(f'Coordinator call frost_init') + def frost_req(self, dkg_session_id, msg_bytes): + jlog.debug(f'Coordinator call frost_req') client = self.factory.client - hostpubkeyhash, session_id, sig = client.frost_init( + hostpubkeyhash, sig, session_id = client.frost_req( dkg_session_id, msg_bytes) coordinator = client.frost_coordinators.get(session_id) session = client.frost_sessions.get(session_id) if session_id and session and coordinator: - d = self.callRemote(commands.JMFROSTInit, - hostpubkeyhash=hostpubkeyhash, - session_id=bintohex(session_id), - sig=sig) + d = self.callRemote(commands.JMFROSTReq, + hostpubkeyhash=hostpubkeyhash, sig=sig, + session_id=bintohex(session_id)) self.defaultCallbacks(d) - coordinator.frost_init_sec = time.time() + coordinator.frost_req_sec = time.time() return session_id, coordinator, session return None, None, None + @commands.JMFROSTReqSeen.responder + def on_JM_FROST_REQ_SEEN(self, nick, hostpubkeyhash, sig, session_id): + wallet = self.client.wallet_service.wallet + if not isinstance(wallet, FrostWallet) or wallet._dkg is None: + return {'accepted': True} + + client = self.factory.client + session_id = hextobin(session_id) + nick, hostpubkeyhash, sig, session_id = \ + client.on_frost_req(nick, hostpubkeyhash, sig, session_id) + if sig: + d = self.callRemote(commands.JMFROSTAck, + nick=nick, + hostpubkeyhash=hostpubkeyhash, sig=sig, + session_id=session_id) + self.defaultCallbacks(d) + return {'accepted': True} + + @commands.JMFROSTAckSeen.responder + def on_JM_FROST_ACK_SEEN(self, nick, hostpubkeyhash, sig, session_id): + wallet = self.client.wallet_service.wallet + if not isinstance(wallet, FrostWallet) or wallet._dkg is None: + return {'accepted': True} + + client = self.factory.client + bin_session_id = hextobin(session_id) + if client.on_frost_ack(nick, hostpubkeyhash, sig, bin_session_id): + d = self.callRemote(commands.JMFROSTInit, + nick=nick, session_id=session_id) + self.defaultCallbacks(d) + return {'accepted': True} + @commands.JMFROSTInitSeen.responder - def on_JM_FROST_INIT_SEEN(self, nick, hostpubkeyhash, session_id, sig): + def on_JM_FROST_INIT_SEEN(self, nick, session_id): wallet = self.client.wallet_service.wallet if not isinstance(wallet, FrostWallet) or wallet._dkg is None: return {'accepted': True} client = self.factory.client session_id = hextobin(session_id) - nick, hostpubkeyhash, session_id, sig, pub_nonce = \ - client.on_frost_init(nick, hostpubkeyhash, session_id, sig) + nick, session_id, pubkeyhash, pub_nonce = \ + client.on_frost_init(nick, session_id) if pub_nonce: pub_nonce_b64 = base64.b64encode(pub_nonce).decode('ascii') d = self.callRemote(commands.JMFROSTRound1, - nick=nick, hostpubkeyhash=hostpubkeyhash, - session_id=session_id, sig=sig, + nick=nick, session_id=session_id, + hostpubkeyhash=pubkeyhash, pub_nonce=pub_nonce_b64) self.defaultCallbacks(d) return {'accepted': True} @commands.JMFROSTRound1Seen.responder - def on_JM_FROST_ROUND1_SEEN(self, nick, hostpubkeyhash, - session_id, sig, pub_nonce): + def on_JM_FROST_ROUND1_SEEN(self, nick, session_id, + hostpubkeyhash, pub_nonce): wallet = self.client.wallet_service.wallet if not isinstance(wallet, FrostWallet) or wallet._dkg is None: return {'accepted': True} @@ -630,8 +661,8 @@ class JMClientProtocol(BaseClientProtocol): bin_session_id = hextobin(session_id) pub_nonce = base64.b64decode(pub_nonce) ready_nicks, nonce_agg, dkg_session_id, ids, msg = \ - client.on_frost_round1(nick, hostpubkeyhash, bin_session_id, - sig, pub_nonce) + client.on_frost_round1( + nick, bin_session_id, hostpubkeyhash, pub_nonce) if ready_nicks and nonce_agg: for nick in ready_nicks: self.frost_agg1(nick, session_id, nonce_agg, diff --git a/src/jmclient/cryptoengine.py b/src/jmclient/cryptoengine.py index 2ce2757..1f7f172 100644 --- a/src/jmclient/cryptoengine.py +++ b/src/jmclient/cryptoengine.py @@ -542,7 +542,7 @@ class BTC_P2TR_FROST(BTC_P2TR): spent_outputs = kwargs['spent_outputs'] sighash = SignatureHashSchnorr(tx, i, spent_outputs) mixdepth, address_type, index = wallet.get_details(path) - sig, pubkey, tweaked_pubkey = await wallet.ipc_client.frost_sign( + sig, pubkey, tweaked_pubkey = await wallet.ipc_client.frost_req( mixdepth, address_type, index, sighash) if not sig: return None, "FROST signing failed" diff --git a/src/jmclient/frost_clients.py b/src/jmclient/frost_clients.py index fe1fd44..4780e40 100644 --- a/src/jmclient/frost_clients.py +++ b/src/jmclient/frost_clients.py @@ -665,7 +665,7 @@ class FROSTCoordinator: def __init__(self, *, session_id, hostpubkey, dkg_session_id, msg): self.session_id = session_id - self.frost_init_sec = 0 + self.frost_req_sec = 0 self.hostpubkey = hostpubkey self.dkg_session_id = dkg_session_id self.msg = msg @@ -681,7 +681,7 @@ class FROSTCoordinator: def __repr__(self): return (f'FROSTCoordinator(session_id={self.session_id}, ' - f'frost_init_sec={self.frost_init_sec}, ' + f'frost_req_sec={self.frost_req_sec}, ' f'hostpubkey={self.hostpubkey}, ' f'dkg_session_id={self.dkg_session_id}, ' f'msg={self.msg}, ' @@ -726,7 +726,7 @@ class FROSTClient(DKGClient): self.frost_coordinators = dict() self.frost_sessions = dict() - def frost_init(self, dkg_session_id, msg_bytes): + def frost_req(self, dkg_session_id, msg_bytes): try: wallet = self.wallet_service.wallet hostseckey = wallet._hostseckey[:32] @@ -760,18 +760,18 @@ class FROSTClient(DKGClient): coordinator.sessions[hostpubkey]['pub_nonce'] = pub_nonce coin_key = CCoinKey.from_secret_bytes(hostseckey) sig = coin_key.sign_schnorr_no_tweak(session_id) - return hostpubkeyhash.hex(), session_id, sig.hex() + return hostpubkeyhash.hex(), sig.hex(), session_id except Exception as e: - jlog.error(f'frost_init: {repr(e)}') + jlog.error(f'frost_req: {repr(e)}') return None, None, None - def on_frost_init(self, nick, pubkeyhash, session_id, sig): + def on_frost_req(self, nick, hostpubkeyhash, sig, session_id): try: if session_id in self.frost_sessions: raise Exception(f'session {session_id.hex()} already exists') - pubkey = self.find_pubkey_by_pubkeyhash(pubkeyhash) + pubkey = self.find_pubkey_by_pubkeyhash(hostpubkeyhash) if not pubkey: - raise Exception(f'pubkey for {pubkeyhash.hex()} not found') + raise Exception(f'pubkey for {hostpubkeyhash} not found') xpubkey = XOnlyPubKey(pubkey[1:]) if not xpubkey.verify_schnorr(session_id, hextobin(sig)): raise Exception('signature verification failed') @@ -784,7 +784,7 @@ class FROSTClient(DKGClient): self.my_id = i break assert self.my_id is not None - hostpubkeyhash = sha256(hostpubkey).digest() + my_hostpubkeyhash = sha256(hostpubkey).digest() session = FROSTSession(session_id=session_id, hostpubkey=hostpubkey, coord_nick=nick, @@ -792,12 +792,41 @@ class FROSTClient(DKGClient): self.frost_sessions[session_id] = session coin_key = CCoinKey.from_secret_bytes(hostseckey) sig = coin_key.sign_schnorr_no_tweak(session_id) + return (nick, my_hostpubkeyhash.hex(), sig.hex(), session_id.hex()) + except Exception as e: + jlog.error(f'on_frost_req: {repr(e)}') + return None, None, None, None + + def on_frost_ack(self, nick, hostpubkeyhash, sig, session_id): + try: + pubkey = self.find_pubkey_by_pubkeyhash(hostpubkeyhash) + if not pubkey: + raise Exception(f'pubkey for {hostpubkeyhash} not found') + xpubkey = XOnlyPubKey(pubkey[1:]) + if not xpubkey.verify_schnorr(session_id, hextobin(sig)): + raise Exception('signature verification failed') + return True + except Exception as e: + jlog.error(f'on_frost_ack: {repr(e)}') + return False + + def on_frost_init(self, nick, session_id): + try: + session = self.frost_sessions.get(session_id) + if not session: + raise Exception(f'session {session_id.hex()} not found') + if session.sec_nonce: + raise Exception(f'session.sec_nonce already set ' + f'for {session_id.hex()}') + wallet = self.wallet_service.wallet + hostseckey = wallet._hostseckey[:32] + hostpubkey = hostpubkey_gen(hostseckey) + pubkeyhash = sha256(hostpubkey).digest() pub_nonce = self.frost_round1(session_id) - return (nick, hostpubkeyhash.hex(), session_id.hex(), - sig.hex(), pub_nonce) + return (nick, session_id.hex(), pubkeyhash.hex(), pub_nonce) except Exception as e: jlog.error(f'on_frost_init: {repr(e)}') - return None, None, None, None, None + return None, None, None, None def frost_round1(self, session_id): try: @@ -815,7 +844,7 @@ class FROSTClient(DKGClient): except Exception as e: jlog.error(f'frost_round1: {repr(e)}') - def on_frost_round1(self, nick, pubkeyhash, session_id, sig, pub_nonce): + def on_frost_round1(self, nick, session_id, pubkeyhash, pub_nonce): try: coordinator = self.frost_coordinators.get(session_id) if not coordinator: @@ -827,9 +856,6 @@ class FROSTClient(DKGClient): pubkey = self.find_pubkey_by_pubkeyhash(pubkeyhash) if not pubkey: raise Exception(f'pubkey for {pubkeyhash} not found') - xpubkey = XOnlyPubKey(pubkey[1:]) - if not xpubkey.verify_schnorr(session_id, hextobin(sig)): - raise Exception(f'signature verification failed') if pubkey in coordinator.parties: jlog.debug(f'pubkey {pubkey.hex()} already in' f' coordinator parties') @@ -1003,7 +1029,7 @@ class FROSTClient(DKGClient): await asyncio.sleep(1) if coordinator.sig: break - waiting_sec = time.time() - coordinator.frost_init_sec + waiting_sec = time.time() - coordinator.frost_req_sec if waiting_sec > self.FROST_WAIT_SEC: raise Exception(f'timed out FROST session ' f'{session_id.hex()}') diff --git a/src/jmclient/frost_ipc.py b/src/jmclient/frost_ipc.py index 519494c..05955d7 100644 --- a/src/jmclient/frost_ipc.py +++ b/src/jmclient/frost_ipc.py @@ -77,9 +77,9 @@ class FrostIPCServer(IPCBase): if cmd == 'get_dkg_pubkey': task = self.loop.create_task( self.on_get_dkg_pubkey(msg_id, *data)) - elif cmd == 'frost_sign': + elif cmd == 'frost_req': task = self.loop.create_task( - self.on_frost_sign(msg_id, *data)) + self.on_frost_req(msg_id, *data)) if task: self.tasks.add(task) except Exception as e: @@ -118,7 +118,7 @@ class FrostIPCServer(IPCBase): except Exception as e: jlog.error(f'FrostIPCServer.send_dkg_pubkey: {repr(e)}') - async def on_frost_sign(self, msg_id, mixdepth, address_type, index, + async def on_frost_req(self, msg_id, mixdepth, address_type, index, sighash): try: wallet = self.wallet @@ -126,12 +126,12 @@ class FrostIPCServer(IPCBase): frost_client = wallet.client_factory.client dkg = wallet.dkg dkg_session_id = dkg.find_session(mixdepth, address_type, index) - session_id, _, _ = client.frost_init(dkg_session_id, sighash) + session_id, _, _ = client.frost_req(dkg_session_id, sighash) sig, tweaked_pubkey = await frost_client.wait_on_sig(session_id) pubkey = dkg.find_dkg_pubkey(mixdepth, address_type, index) await self.send_frost_sig(msg_id, sig, pubkey, tweaked_pubkey) except Exception as e: - jlog.error(f'FrostIPCServer.on_frost_sign: {repr(e)}') + jlog.error(f'FrostIPCServer.on_frost_req: {repr(e)}') await self.send_frost_sig(msg_id, None, None, None) async def send_frost_sig(self, msg_id, sig, pubkey, tweaked_pubkey): @@ -232,15 +232,15 @@ class FrostIPCClient(IPCBase): except Exception as e: jlog.error(f'FrostIPCClient.get_dkg_pubkey: {repr(e)}') - async def frost_sign(self, mixdepth, address_type, index, sighash): - jlog.debug(f'FrostIPCClient.frost_sign for mixdepth={mixdepth}, ' + async def frost_req(self, mixdepth, address_type, index, sighash): + jlog.debug(f'FrostIPCClient.frost_req for mixdepth={mixdepth}, ' f'address_type={address_type}, index={index}, ' f'sighash={sighash.hex()}') try: self.msg_id += 1 msg_dict = { 'msg_id': self.msg_id, - 'cmd': 'frost_sign', + 'cmd': 'frost_req', 'data': (mixdepth, address_type, index, sighash), } self.sw.write(self.encrypt_msg(msg_dict)) @@ -251,16 +251,16 @@ class FrostIPCClient(IPCBase): sig, pubkey, tweaked_pubkey = fut.result() if sig is None: jlog.error( - f'FrostIPCClient.frost_sign got None sig value from ' + f'FrostIPCClient.frost_req got None sig value from ' f'FrostIPCServer for mixdepth={mixdepth}, ' f'address_type={address_type}, index={index}, ' f'sighash={sighash.hex()}') return sig, pubkey, tweaked_pubkey jlog.debug( - f'FrostIPCClient.frost_sign successfully got signature ' + f'FrostIPCClient.frost_req successfully got signature ' f'for mixdepth={mixdepth}, address_type={address_type}, ' f'index={index}, sighash={sighash.hex()}') return sig, pubkey, tweaked_pubkey except Exception as e: - jlog.error(f'FrostIPCClient.frost_sign: {repr(e)}') + jlog.error(f'FrostIPCClient.frost_req: {repr(e)}') return None, None, None diff --git a/src/jmclient/maker.py b/src/jmclient/maker.py index 72c5eb0..f0a5d36 100644 --- a/src/jmclient/maker.py +++ b/src/jmclient/maker.py @@ -127,7 +127,7 @@ class Maker(object): path = wallet.addr_to_path(auth_address) md, address_type, index = wallet.get_details(path) kphex_hash = hashlib.sha256(bintohex(kphex).encode()).digest() - sig, _, tweaked_pubkey = await wallet.ipc_client.frost_sign( + sig, _, tweaked_pubkey = await wallet.ipc_client.frost_req( md, address_type, index, kphex_hash) sig = base64.b64encode(sig).decode('ascii') if not sig: diff --git a/src/jmclient/wallet_utils.py b/src/jmclient/wallet_utils.py index a78ea2b..bc6d674 100644 --- a/src/jmclient/wallet_utils.py +++ b/src/jmclient/wallet_utils.py @@ -1962,7 +1962,7 @@ async def wallet_tool_main(wallet_root_path): msg = 'testmsg' md = address_type = index = 0 msghash = sha256(msg.encode()).digest() - sig, pubkey, tweaked_pubkey = await wallet.ipc_client.frost_sign( + sig, pubkey, tweaked_pubkey = await wallet.ipc_client.frost_req( md, address_type, index, msghash) verify_pubkey = XOnlyPubKey(tweaked_pubkey[1:]) if verify_pubkey.verify_schnorr(msghash, sig): diff --git a/src/jmdaemon/daemon_protocol.py b/src/jmdaemon/daemon_protocol.py index 30cb96a..80be05b 100644 --- a/src/jmdaemon/daemon_protocol.py +++ b/src/jmdaemon/daemon_protocol.py @@ -26,9 +26,11 @@ from twisted.web import server from txtorcon.socks import HostUnreachableError from twisted.python import log import urllib.parse as urlparse +from collections import defaultdict from urllib.parse import urlencode import json import threading +import time import os from io import BytesIO import copy @@ -485,6 +487,32 @@ class JMDaemonServerProtocol(amp.AMP, OrderbookWatch): self.use_fidelity_bond = False self.offerlist = None self.kp = None + self.frost_crypto_boxes = {} + self.frost_expected_msgs = defaultdict(lambda: defaultdict(dict)) + self.frost_cleanup_loop = task.LoopingCall(self.frost_cleanup) + + def frost_cleanup(self): + now = time.time() + boxes = self.frost_crypto_boxes + cleanup_list = [] + for nick, sessions in boxes.items(): + for session_id, box in sessions.items(): + if now - box['created'] > 120: + cleanup_list.append((nick, session_id)) + for nick, session_id in cleanup_list: + boxes[nick].pop(session_id) + if not boxes[nick]: + boxes.pop(nick) + msgs = self.frost_expected_msgs + cleanup_list = [] + for nick, cmds in msgs.items(): + for cmd, cmd_data in cmds.items(): + if now - cmd_data['created'] > 120: + cleanup_list.append((nick, cmd)) + for nick, cmd in cleanup_list: + msgs[nick].pop(cmd) + if not msgs[nick]: + msgs.pop(nick) def checkClientResponse(self, response): """A generic check of client acceptance; any failure @@ -555,6 +583,8 @@ class JMDaemonServerProtocol(amp.AMP, OrderbookWatch): self.on_dkgfinalized, self.on_dkgcmsg1, self.on_dkgcmsg2, + self.on_frostreq, + self.on_frostack, self.on_frostinit, self.on_frostround1, self.on_frostround2, @@ -582,7 +612,7 @@ class JMDaemonServerProtocol(amp.AMP, OrderbookWatch): assert self.jm_state == 0 self.role = role self.crypto_boxes = {} - self.kp = init_keypair() + self.kp = init_keypair() # FIXME not used by maker, mv to taker code? d = self.callRemote(JMSetupDone) self.defaultCallbacks(d) #Request orderbook here, on explicit setup request from client, @@ -658,29 +688,114 @@ class JMDaemonServerProtocol(amp.AMP, OrderbookWatch): """FROST specific responders """ + @JMFROSTReq.responder + def on_JM_FROST_REQ(self, hostpubkeyhash, sig, session_id): + if not self.frost_cleanup_loop.running: + self.frost_cleanup_loop.start(30.0) + boxes = self.frost_crypto_boxes + nick_boxes = boxes.get(None, {}) # None for self + session_box = nick_boxes.get(session_id, {}) + if session_box: + log.msg(f'on_JM_FROST_REQ: session_id "{session_id}" ' + f'setup is incorrect. ' + f'FROST request aborted.') + return {'accepted': True} + kp = init_keypair() + session_box['kp'] = kp + session_box['created'] = time.time() + nick_boxes[session_id] = session_box + boxes[None] = nick_boxes # None for self + dh_pubk = kp.hex_pk().decode('ascii') + req_msg = f'!frostreq {hostpubkeyhash} {sig} {session_id} {dh_pubk}' + self.mcc.pubmsg(req_msg) + return {'accepted': True} + + @JMFROSTAck.responder + def on_JM_FROST_ACK(self, nick, hostpubkeyhash, sig, session_id): + boxes = self.frost_crypto_boxes + self_boxes = boxes.get(None, {}) # None for self + session_box = self_boxes.get(session_id, {}) + if session_box: + log.msg(f'on_JM_FROST_ACK: session_id "{session_id}" ' + f'setup is incorrect. ' + f'FROST request aborted.') + return {'accepted': True} + kp = init_keypair() + session_box['kp'] = kp + session_box['created'] = time.time() + self_boxes[session_id] = session_box + boxes[None] = self_boxes # None for self + + nick_boxes = boxes.get(nick, {}) + nick_session_box = nick_boxes.get(session_id, {}) + if not nick_session_box: + log.msg(f'on_JM_FROST_ACK: nick {nick}, session_id "{session_id}" ' + f'setup is incorrect. ' + f'FROST request aborted.') + return {'accepted': True} + try: + nick_dh_pubk = nick_session_box['dh_pubk'] + nick_session_box['crypto_box'] = as_init_encryption( + kp, init_pubkey(nick_dh_pubk)) + except NaclError as e: + log.msg('on frostround1: error creating crypto_box. ' + 'FROST session aborted') + return {'accepted': True} + + self.frost_expected_msgs[nick]['frostinit'] = { + 'session_id': session_id, + 'created': time.time(), + } + dh_pubk = kp.hex_pk().decode('ascii') + ack_msg = f'{hostpubkeyhash} {sig} {session_id} {dh_pubk}' + self.mcc.prepare_privmsg(nick, 'frostack', ack_msg) + return {'accepted': True} + @JMFROSTInit.responder - def on_JM_FROST_INIT(self, hostpubkeyhash, session_id, sig): - self.mcc.pubmsg(f'!frostinit {hostpubkeyhash} {session_id} {sig}') + def on_JM_FROST_INIT(self, nick, session_id): + self.frost_expected_msgs[nick]['frostround1'] = { + 'session_id': session_id, + 'created': time.time(), + } + init_msg = f'{session_id}' + self.mcc.prepare_privmsg(nick, 'frostinit', init_msg) return {'accepted': True} @JMFROSTRound1.responder - def on_JM_FROST_ROUND1(self, nick, hostpubkeyhash, - session_id, sig, pub_nonce): - msg = f'{hostpubkeyhash} {session_id} {sig} {pub_nonce}' - self.mcc.prepare_privmsg(nick, "frostround1", msg) + def on_JM_FROST_ROUND1(self, nick, hostpubkeyhash, session_id, pub_nonce): + self.frost_expected_msgs[nick]['frostagg1'] = { + 'session_id': session_id, + 'created': time.time(), + } + round1_msg = f'{session_id} {hostpubkeyhash} {pub_nonce}' + self.mcc.prepare_privmsg(nick, "frostround1", round1_msg) return {'accepted': True} @JMFROSTAgg1.responder def on_JM_FROST_AGG1(self, nick, session_id, nonce_agg, dkg_session_id, ids, msg): - msg = f'{session_id} {nonce_agg} {dkg_session_id} {ids} {msg}' - self.mcc.prepare_privmsg(nick, "frostagg1", msg) + self.frost_expected_msgs[nick]['frostround2'] = { + 'session_id': session_id, + 'created': time.time(), + } + agg1_msg = f'{session_id} {nonce_agg} {dkg_session_id} {ids} {msg}' + self.mcc.prepare_privmsg(nick, "frostagg1", agg1_msg) return {'accepted': True} @JMFROSTRound2.responder def on_JM_FROST_ROUND2(self, nick, session_id, partial_sig): msg = f'{session_id} {partial_sig}' self.mcc.prepare_privmsg(nick, "frostround2", msg) + # cleanup frost_crypto_boxes + boxes = self.frost_crypto_boxes + cleanup_list = [] + for nick, sessions in boxes.items(): + if session_id in sessions: + cleanup_list.append((nick, session_id)) + for nick, session_id in cleanup_list: + boxes[nick].pop(session_id) + if not boxes[nick]: + boxes.pop(nick) return {'accepted': True} @@ -848,16 +963,60 @@ class JMDaemonServerProtocol(amp.AMP, OrderbookWatch): ext_recovery=ext_recovery) self.defaultCallbacks(d) - def on_frostinit(self, nick, hostpubkeyhash, session_id, sig): - d = self.callRemote(JMFROSTInitSeen, + def on_frostreq(self, nick, hostpubkeyhash, sig, session_id, dh_pubk): + boxes = self.frost_crypto_boxes + nick_boxes = boxes.get(nick, {}) + session_box = nick_boxes.get(session_id, {}) + if not session_box and not 'dh_pubk' in session_box: + session_box['dh_pubk'] = dh_pubk + session_box['created'] = time.time() + nick_boxes[session_id] = session_box + boxes[nick] = nick_boxes + + d = self.callRemote(JMFROSTReqSeen, + nick=nick, hostpubkeyhash=hostpubkeyhash, + sig=sig, session_id=session_id) + self.defaultCallbacks(d) + + def on_frostack(self, nick, hostpubkeyhash, sig, session_id, dh_pubk): + boxes = self.frost_crypto_boxes + nick_boxes = boxes.get(nick, {}) + session_box = nick_boxes.get(session_id, {}) + if not session_box and not 'dh_pubk' in session_box: + session_box['dh_pubk'] = dh_pubk + session_box['created'] = time.time() + nick_boxes[session_id] = session_box + boxes[nick] = nick_boxes + + self_boxes = boxes.get(None, {}) + self_session_box = self_boxes.get(session_id, {}) + if not self_session_box or not 'kp' in self_session_box: + log.msg(f'on_frostack: session_id "{session_id}" ' + f' setup is incorrect. ' + f'FROST session aborted.') + return + try: + kp = self_session_box['kp'] + session_box['crypto_box'] = as_init_encryption( + kp, init_pubkey(dh_pubk)) + except NaclError as e: + log.msg('on frostround1: error creating crypto_box. ' + 'FROST session aborted') + return + d = self.callRemote(JMFROSTAckSeen, nick=nick, hostpubkeyhash=hostpubkeyhash, - session_id=session_id, sig=sig) + sig=sig, session_id=session_id) self.defaultCallbacks(d) - def on_frostround1(self, nick, hostpubkeyhash, session_id, sig, pub_nonce): + def on_frostinit(self, nick, session_id): + d = self.callRemote(JMFROSTInitSeen, + nick=nick, session_id=session_id) + self.defaultCallbacks(d) + + def on_frostround1(self, nick, session_id, hostpubkeyhash, pub_nonce): d = self.callRemote(JMFROSTRound1Seen, - nick=nick, hostpubkeyhash=hostpubkeyhash, - session_id=session_id, sig=sig, + nick=nick, session_id=session_id, + hostpubkeyhash=hostpubkeyhash, pub_nonce=pub_nonce) self.defaultCallbacks(d) @@ -866,6 +1025,16 @@ class JMDaemonServerProtocol(amp.AMP, OrderbookWatch): nick=nick, session_id=session_id, partial_sig=partial_sig) self.defaultCallbacks(d) + # cleanup frost_crypto_boxes + boxes = self.frost_crypto_boxes + cleanup_list = [] + for nick, sessions in boxes.items(): + if session_id in sessions: + cleanup_list.append((nick, session_id)) + for nick, session_id in cleanup_list: + boxes[nick].pop(session_id) + if not boxes[nick]: + boxes.pop(nick) def on_frostagg1(self, nick, session_id, nonce_agg, dkg_session_id, ids, msg): @@ -1169,7 +1338,9 @@ class JMDaemonServerProtocol(amp.AMP, OrderbookWatch): """Retrieve the libsodium box object for the counterparty; stored differently for Taker and Maker """ - if nick in self.crypto_boxes and self.crypto_boxes[nick] != None: + if nick in self.frost_crypto_boxes: + return self.frost_crypto_boxes[nick] + elif nick in self.crypto_boxes and self.crypto_boxes[nick] != None: return self.crypto_boxes[nick][1] elif nick in self.active_orders and self.active_orders[nick] != None \ and "crypto_box" in self.active_orders[nick]: diff --git a/src/jmdaemon/message_channel.py b/src/jmdaemon/message_channel.py index c4f0e2a..11a2d8a 100644 --- a/src/jmdaemon/message_channel.py +++ b/src/jmdaemon/message_channel.py @@ -207,7 +207,7 @@ class MessageChannelCollection(object): #END PUBLIC/BROADCAST SECTION - def get_encryption_box(self, cmd, nick): + def get_encryption_box(self, cmd, nick, extra=None): """Establish whether the message is to be encrypted/decrypted based on the command string. If so, retrieve the appropriate crypto_box object @@ -215,12 +215,25 @@ class MessageChannelCollection(object): if cmd in plaintext_commands: return None, False else: - return self.daemon.get_crypto_box_from_nick(nick), True + if cmd in ['frostinit', 'frostround1', 'frostagg1', 'frostround2']: + boxes = self.daemon.get_crypto_box_from_nick(nick) + if boxes: + if extra: + box = boxes.get(extra)['crypto_box'] + else: + box = None + else: + box = self.daemon.get_crypto_box_from_nick(nick) + return box, True @check_privmsg def prepare_privmsg(self, nick, cmd, message, mc=None): # should we encrypt? - box, encrypt = self.get_encryption_box(cmd, nick) + if cmd in ['frostinit', 'frostround1', 'frostagg1', 'frostround2']: + session_id = message.split()[0] + box, encrypt = self.get_encryption_box(cmd, nick, extra=session_id) + else: + box, encrypt = self.get_encryption_box(cmd, nick) if encrypt: if not box: log.debug('error, dont have encryption box object for ' + nick + @@ -617,6 +630,8 @@ class MessageChannelCollection(object): on_dkgfinalized=None, on_dkgcmsg1=None, on_dkgcmsg2=None, + on_frostreq=None, + on_frostack=None, on_frostinit=None, on_frostround1=None, on_frostround2=None, @@ -626,8 +641,9 @@ class MessageChannelCollection(object): on_dkginit, on_dkgpmsg1, on_dkgpmsg2, on_dkgfinalized, on_dkgcmsg1, on_dkgcmsg2, - on_frostinit, - on_frostround1, on_frostround2, on_frostagg1) + on_frostreq, on_frostack, + on_frostinit, on_frostround1, + on_frostround2, on_frostagg1) def on_verified_privmsg(self, nick, message, hostid): """Called from daemon when message was successfully verified, @@ -692,6 +708,8 @@ class MessageChannel(object): self.on_dkgfinalized = None self.on_dkgcmsg1 = None self.on_dkgcmsg2 = None + self.on_frostreq = None + self.on_frostack = None self.on_frostinit = None self.on_frostround1 = None self.on_frostround2 = None @@ -810,6 +828,8 @@ class MessageChannel(object): on_dkgfinalized=None, on_dkgcmsg1=None, on_dkgcmsg2=None, + on_frostreq=None, + on_frostack=None, on_frostinit=None, on_frostround1=None, on_frostround2=None, @@ -820,6 +840,8 @@ class MessageChannel(object): self.on_dkgfinalized = on_dkgfinalized self.on_dkgcmsg1 = on_dkgcmsg1 self.on_dkgcmsg2 = on_dkgcmsg2 + self.on_frostreq = on_frostreq + self.on_frostack = on_frostack self.on_frostinit = on_frostinit self.on_frostround1 = on_frostround1 self.on_frostround2 = on_frostround2 @@ -952,16 +974,17 @@ class MessageChannel(object): except (ValueError, IndexError) as e: log.debug("!dkginit" + repr(e)) return - elif _chunks[0] == 'frostinit': + elif _chunks[0] == 'frostreq': try: hostpubkeyhash = _chunks[1] - session_id = _chunks[2] - sig = _chunks[3] - if self.on_frostinit: - self.on_frostinit(nick, hostpubkeyhash, - session_id, sig) + sig = _chunks[2] + session_id = _chunks[3] + dh_pubk = _chunks[4] + if self.on_frostreq: + self.on_frostreq( + nick, hostpubkeyhash, sig, session_id, dh_pubk) except (ValueError, IndexError) as e: - log.debug("!frostinit" + repr(e)) + log.debug("!frostreq" + repr(e)) return elif self.check_for_orders(nick, _chunks): pass @@ -1038,9 +1061,27 @@ class MessageChannel(object): _chunks = command.split(" ") #Decrypt if necessary - if _chunks[0] in encrypted_commands: - box, encrypt = self.daemon.mcc.get_encryption_box(_chunks[0], - nick) + cmd = _chunks[0] + if cmd in encrypted_commands: + if cmd in ['frostinit', 'frostround1', + 'frostagg1', 'frostround2']: + expected_msgs = self.daemon.frost_expected_msgs + if nick in expected_msgs: + if cmd in expected_msgs[nick]: + expected_msg = expected_msgs[nick].pop(cmd) + if not expected_msgs[nick]: + expected_msgs.pop(nick) + if expected_msg: + session_id = expected_msg['session_id'] + box, encrypt = \ + self.daemon.mcc.get_encryption_box( + cmd, nick, extra=session_id) + else: + box = None + encrypt = True + else: + box, encrypt = self.daemon.mcc.get_encryption_box( + cmd, nick) if encrypt: if not box: log.debug('error, dont have encryption box object for ' @@ -1161,14 +1202,25 @@ class MessageChannel(object): ext_recovery = _chunks[3] if self.on_dkgcmsg2: self.on_dkgcmsg2(nick, session_id, cmsg2, ext_recovery) - elif _chunks[0] == 'frostround1': + elif _chunks[0] == 'frostack': hostpubkeyhash = _chunks[1] - session_id = _chunks[2] - sig = _chunks[3] - pub_nonce = _chunks[4] + sig = _chunks[2] + session_id = _chunks[3] + dh_pubk = _chunks[4] + if self.on_frostack: + self.on_frostack( + nick, hostpubkeyhash, sig, session_id, dh_pubk) + elif _chunks[0] == 'frostinit': + session_id = _chunks[1] + if self.on_frostinit: + self.on_frostinit(nick, session_id) + elif _chunks[0] == 'frostround1': + session_id = _chunks[1] + pubkeyhash = _chunks[2] + pub_nonce = _chunks[3] if self.on_frostround1: self.on_frostround1( - nick, hostpubkeyhash, session_id, sig, pub_nonce) + nick, session_id, pubkeyhash, pub_nonce) elif _chunks[0] == 'frostagg1': session_id = _chunks[1] nonce_agg = _chunks[2] diff --git a/src/jmdaemon/protocol.py b/src/jmdaemon/protocol.py index 6c9afd8..72a67b6 100644 --- a/src/jmdaemon/protocol.py +++ b/src/jmdaemon/protocol.py @@ -40,17 +40,21 @@ COMMITMENT_PREFIXES = ["P"] dkg_public_list = ['dkginit'] dkg_private_list = ['dkgpmsg1', 'dkgpmsg2', 'dkgcmsg1', 'dkgcmsg2', 'dkgfinalized'] -frost_public_list = ['frostinit'] -frost_private_list = ['frostround1', 'frostround2', 'frostagg1'] + +frost_public_list = ['frostreq'] +frost_plaintext_list = frost_public_list + ['frostack'] +frost_encrypted_list = ['frostinit', 'frostround1', + 'frostround2', 'frostagg1'] + encrypted_commands = ["auth", "ioauth", "tx", "sig"] +encrypted_commands += frost_encrypted_list plaintext_commands = ["fill", "error", "pubkey", "orderbook", "push"] commitment_broadcast_list = ["hp2"] plaintext_commands += offername_list plaintext_commands += commitment_broadcast_list plaintext_commands += dkg_public_list plaintext_commands += dkg_private_list -plaintext_commands += frost_public_list -plaintext_commands += frost_private_list +plaintext_commands += frost_plaintext_list public_commands = commitment_broadcast_list + [ "orderbook", "cancel" ] + offername_list + [ dkg_public_list + frost_public_list] diff --git a/test/jmclient/test_frost_clients.py b/test/jmclient/test_frost_clients.py index 4463583..1d18b93 100644 --- a/test/jmclient/test_frost_clients.py +++ b/test/jmclient/test_frost_clients.py @@ -627,97 +627,163 @@ class FROSTClientTestCase(DKGClientTestCaseBase): self.fc3 = FROSTClient(self.wlt_svc3) self.fc4 = FROSTClient(self.wlt_svc4) - async def test_frost_init(self): + async def test_frost_req(self): msg_bytes = bytes.fromhex('aabb'*16) # test wallet with unknown hostpubkey - hostpubkeyhash_hex, session_id, sig_hex = self.fc4.frost_init( + hostpubkeyhash_hex, sig_hex, session_id = self.fc4.frost_req( self.dkg_session_id, msg_bytes) assert hostpubkeyhash_hex is None - assert session_id is None assert sig_hex is None + assert session_id is None - hostpubkeyhash_hex, session_id, sig_hex = self.fc1.frost_init( + hostpubkeyhash_hex, sig_hex, session_id = self.fc1.frost_req( self.dkg_session_id, msg_bytes) assert hostpubkeyhash_hex and len(hostpubkeyhash_hex) == 64 - assert session_id and len(session_id) == 32 assert sig_hex and len(sig_hex) == 128 + assert session_id and len(session_id) == 32 - async def test_on_frost_init(self): + async def test_on_frost_req(self): msg_bytes = bytes.fromhex('aabb'*16) - hostpubkeyhash_hex, session_id, sig_hex = self.fc1.frost_init( + hostpubkeyhash_hex, sig_hex, session_id = self.fc1.frost_req( self.dkg_session_id, msg_bytes) # fail with wrong pubkeyhash hostpubkeyhash4_hex = sha256(self.hostpubkey4).digest() ( - nick1, + nick2, hostpubkeyhash2_hex, - session_id2_hex, sig2_hex, - pub_nonce - ) = self.fc2.on_frost_init( - self.nick1, hostpubkeyhash4_hex, session_id, sig_hex) - for v in [nick1, hostpubkeyhash2_hex, - session_id2_hex, sig2_hex, pub_nonce]: + session_id2_hex, + ) = self.fc2.on_frost_req( + self.nick1, hostpubkeyhash4_hex, sig_hex, session_id) + for v in [nick2, hostpubkeyhash2_hex, sig2_hex, session_id2_hex]: assert v is None # fail with wrong sig ( - nick1, + nick2, hostpubkeyhash2_hex, - session_id2_hex, sig2_hex, - pub_nonce - ) = self.fc2.on_frost_init( - self.nick1, hostpubkeyhash_hex, session_id, '01020304'*16) - for v in [nick1, hostpubkeyhash2_hex, - session_id2_hex, sig2_hex, pub_nonce]: + session_id2_hex, + ) = self.fc2.on_frost_req( + self.nick1, hostpubkeyhash_hex, '01020304'*16, session_id) + for v in [nick2, hostpubkeyhash2_hex, sig2_hex, session_id2_hex]: assert v is None ( - nick1, + nick2, hostpubkeyhash2_hex, + sig2_hex, session_id2_hex, + ) = self.fc2.on_frost_req( + self.nick1, hostpubkeyhash_hex, sig_hex, session_id) + assert nick2 == self.nick1 + assert hostpubkeyhash2_hex and len(hostpubkeyhash2_hex) == 64 + assert sig_hex and len(sig_hex) == 128 + assert session_id2_hex and len(session_id2_hex) == 64 + assert bytes.fromhex(session_id2_hex) == session_id + + # fail on second call with right params + ( + nick2, + hostpubkeyhash2_hex, + sig2_hex, + session_id2_hex, + ) = self.fc2.on_frost_req( + self.nick1, hostpubkeyhash_hex, sig_hex, session_id) + for v in [nick2, hostpubkeyhash2_hex, sig2_hex, session_id2_hex]: + assert v is None + + async def test_on_frost_ack(self): + msg_bytes = bytes.fromhex('aabb'*16) + hostpubkeyhash_hex, sig_hex, session_id = self.fc1.frost_req( + self.dkg_session_id, msg_bytes) + + ( + nick2, + hostpubkeyhash2, + sig2_hex, + session_id_hex + ) = self.fc2.on_frost_req( + self.nick1, hostpubkeyhash_hex, sig_hex, session_id) + + # fail with wrong pubkeyhash + hostpubkeyhash4_hex = sha256(self.hostpubkey4).digest() + assert not self.fc1.on_frost_ack( + self.nick4, hostpubkeyhash4_hex, sig2_hex, session_id) + + # fail with wrong sig + hostpubkeyhash4_hex = sha256(self.hostpubkey4).digest() + assert not self.fc1.on_frost_ack( + self.nick2, hostpubkeyhash2, '01020304'*16, session_id) + + assert self.fc1.on_frost_ack( + self.nick2, hostpubkeyhash2, sig2_hex, session_id) + + async def test_on_frost_init(self): + msg_bytes = bytes.fromhex('aabb'*16) + hostpubkeyhash_hex, sig_hex, session_id = self.fc1.frost_req( + self.dkg_session_id, msg_bytes) + + ( + nick2, + hostpubkeyhash2, sig2_hex, + session_id_hex + ) = self.fc2.on_frost_req( + self.nick1, hostpubkeyhash_hex, sig_hex, session_id) + + assert self.fc1.on_frost_ack( + self.nick2, hostpubkeyhash2, sig2_hex, session_id) + + ( + nick1, + session_id2_hex, + hostpubkeyhash2_hex, pub_nonce - ) = self.fc2.on_frost_init( - self.nick1, hostpubkeyhash_hex, session_id, sig_hex) + ) = self.fc2.on_frost_init(self.nick1, session_id) assert nick1 == self.nick1 - assert hostpubkeyhash2_hex and len(hostpubkeyhash2_hex) == 64 assert session_id2_hex and len(session_id2_hex) == 64 + assert hostpubkeyhash2_hex and len(hostpubkeyhash2_hex) == 64 assert bytes.fromhex(session_id2_hex) == session_id - assert sig_hex and len(sig_hex) == 128 assert pub_nonce and len(pub_nonce) == 66 # fail on second call with right params ( nick1, - hostpubkeyhash2_hex, session_id2_hex, - sig2_hex, - pub_nonce - ) = self.fc2.on_frost_init( - self.nick1, hostpubkeyhash_hex, session_id, sig_hex) - for v in [nick1, hostpubkeyhash2_hex, - session_id2_hex, sig2_hex, pub_nonce]: + hostpubkeyhash2_hex, + pub_nonce2 + ) = self.fc2.on_frost_init(self.nick1, session_id) + print('V'*80, nick1, session_id2_hex, hostpubkeyhash2_hex, pub_nonce2) + for v in [nick1, session_id2_hex, hostpubkeyhash2_hex, pub_nonce2]: assert v is None def test_frost_round1(self): msg_bytes = bytes.fromhex('aabb'*16) - hostpubkeyhash_hex, session_id, sig_hex = self.fc1.frost_init( + hostpubkeyhash_hex, sig_hex, session_id = self.fc1.frost_req( self.dkg_session_id, msg_bytes) + ( + nick2, + hostpubkeyhash2, + sig2_hex, + session_id_hex + ) = self.fc2.on_frost_req( + self.nick1, hostpubkeyhash_hex, sig_hex, session_id) + + assert self.fc1.on_frost_ack( + self.nick2, hostpubkeyhash2, sig2_hex, session_id) + ( nick1, - hostpubkeyhash2_hex, session_id2_hex, - sig2_hex, - pub_nonce - ) = self.fc2.on_frost_init( - self.nick1, hostpubkeyhash_hex, session_id, sig_hex) + hostpubkeyhash2_hex, + pub_nonce2 + ) = self.fc2.on_frost_init(self.nick1, session_id) # fail with unknown session_id - pub_nonce = self.fc2.party_step1(b'\x05'*32) + pub_nonce = self.fc2.frost_round1(b'\x05'*32) assert pub_nonce is None # fail with session.sec_nonce already set @@ -731,26 +797,33 @@ class FROSTClientTestCase(DKGClientTestCaseBase): def test_on_frost_round1(self): msg_bytes = bytes.fromhex('aabb'*16) - hostpubkeyhash_hex, session_id, sig_hex = self.fc1.frost_init( + hostpubkeyhash_hex, sig_hex, session_id = self.fc1.frost_req( self.dkg_session_id, msg_bytes) + ( + nick2, + hostpubkeyhash2, + sig2_hex, + session_id_hex + ) = self.fc2.on_frost_req( + self.nick1, hostpubkeyhash_hex, sig_hex, session_id) + + assert self.fc1.on_frost_ack( + self.nick2, hostpubkeyhash2, sig2_hex, session_id) + ( nick1, - hostpubkeyhash2_hex, session_id2_hex, - sig2_hex, + hostpubkeyhash2_hex, pub_nonce2 - ) = self.fc2.on_frost_init( - self.nick1, hostpubkeyhash_hex, session_id, sig_hex) + ) = self.fc2.on_frost_init(self.nick1, session_id) ( nick1, - hostpubkeyhash3_hex, session_id3_hex, - sig3_hex, + hostpubkeyhash3_hex, pub_nonce3 - ) = self.fc3.on_frost_init( - self.nick1, hostpubkeyhash_hex, session_id, sig_hex) + ) = self.fc3.on_frost_init(self.nick1, session_id) # unknown session_id ( @@ -760,8 +833,7 @@ class FROSTClientTestCase(DKGClientTestCaseBase): ids, msg ) = self.fc1.on_frost_round1( - self.nick2, hostpubkeyhash2_hex, b'\xaa'*32, - sig2_hex, pub_nonce2) + self.nick2, b'\xaa'*32, hostpubkeyhash2_hex, pub_nonce2) for v in [ready_list, nonce_agg, dkg_session_id, ids, msg]: assert v is None @@ -773,19 +845,7 @@ class FROSTClientTestCase(DKGClientTestCaseBase): ids, msg ) = self.fc1.on_frost_round1( - self.nick2, 'bb'*32, session_id, sig2_hex, pub_nonce2) - for v in [ready_list, nonce_agg, dkg_session_id, ids, msg]: - assert v is None - - # wrong sig - ( - ready_list, - nonce_agg, - dkg_session_id, - ids, - msg - ) = self.fc1.on_frost_round1( - self.nick2, hostpubkeyhash2_hex, session_id, '1234'*32, pub_nonce2) + self.nick2, session_id, 'bb'*32, pub_nonce2) for v in [ready_list, nonce_agg, dkg_session_id, ids, msg]: assert v is None @@ -796,8 +856,7 @@ class FROSTClientTestCase(DKGClientTestCaseBase): ids, msg ) = self.fc1.on_frost_round1( - self.nick2, hostpubkeyhash2_hex, session_id, - sig2_hex, pub_nonce2) + self.nick2, session_id, hostpubkeyhash2_hex, pub_nonce2) assert ready_list == set([self.nick2]) assert nonce_agg and len(nonce_agg)== 66 assert dkg_session_id and dkg_session_id == self.dkg_session_id @@ -812,23 +871,32 @@ class FROSTClientTestCase(DKGClientTestCaseBase): ids, msg ) = self.fc1.on_frost_round1( - self.nick3, hostpubkeyhash3_hex, session_id, sig3_hex, pub_nonce3) + self.nick3, session_id, hostpubkeyhash3_hex, pub_nonce3) for v in [ready_list, nonce_agg, dkg_session_id, ids, msg]: assert v is None def test_frost_agg1(self): msg_bytes = bytes.fromhex('aabb'*16) - hostpubkeyhash_hex, session_id, sig_hex = self.fc1.frost_init( + hostpubkeyhash_hex, sig_hex, session_id = self.fc1.frost_req( self.dkg_session_id, msg_bytes) + ( + nick2, + hostpubkeyhash2, + sig2_hex, + session_id_hex + ) = self.fc2.on_frost_req( + self.nick1, hostpubkeyhash_hex, sig_hex, session_id) + + assert self.fc1.on_frost_ack( + self.nick2, hostpubkeyhash2, sig2_hex, session_id) + ( nick1, - hostpubkeyhash2_hex, session_id2_hex, - sig2_hex, + hostpubkeyhash2_hex, pub_nonce2 - ) = self.fc2.on_frost_init( - self.nick1, hostpubkeyhash_hex, session_id, sig_hex) + ) = self.fc2.on_frost_init(self.nick1, session_id) ( ready_list, @@ -837,8 +905,7 @@ class FROSTClientTestCase(DKGClientTestCaseBase): ids, msg ) = self.fc1.on_frost_round1( - self.nick2, hostpubkeyhash2_hex, session_id, - sig2_hex, pub_nonce2) + self.nick2, session_id, hostpubkeyhash2_hex, pub_nonce2) # fail on unknown session_id ( @@ -875,17 +942,26 @@ class FROSTClientTestCase(DKGClientTestCaseBase): def test_frost_round2(self): msg_bytes = bytes.fromhex('aabb'*16) - hostpubkeyhash_hex, session_id, sig_hex = self.fc1.frost_init( + hostpubkeyhash_hex, sig_hex, session_id = self.fc1.frost_req( self.dkg_session_id, msg_bytes) + ( + nick2, + hostpubkeyhash2, + sig2_hex, + session_id_hex + ) = self.fc2.on_frost_req( + self.nick1, hostpubkeyhash_hex, sig_hex, session_id) + + assert self.fc1.on_frost_ack( + self.nick2, hostpubkeyhash2, sig2_hex, session_id) + ( nick1, - hostpubkeyhash2_hex, session_id2_hex, - sig2_hex, + hostpubkeyhash2_hex, pub_nonce2 - ) = self.fc2.on_frost_init( - self.nick1, hostpubkeyhash_hex, session_id, sig_hex) + ) = self.fc2.on_frost_init(self.nick1, session_id) ( ready_list, @@ -894,8 +970,7 @@ class FROSTClientTestCase(DKGClientTestCaseBase): ids, msg ) = self.fc1.on_frost_round1( - self.nick2, hostpubkeyhash2_hex, session_id, - sig2_hex, pub_nonce2) + self.nick2, session_id, hostpubkeyhash2_hex, pub_nonce2) # fail on unknown session_id partial_sig = self.fc2.frost_round2( @@ -916,17 +991,26 @@ class FROSTClientTestCase(DKGClientTestCaseBase): def test_on_frost_round2(self): msg_bytes = bytes.fromhex('aabb'*16) - hostpubkeyhash_hex, session_id, sig_hex = self.fc1.frost_init( + hostpubkeyhash_hex, sig_hex, session_id = self.fc1.frost_req( self.dkg_session_id, msg_bytes) + ( + nick2, + hostpubkeyhash2, + sig2_hex, + session_id_hex + ) = self.fc2.on_frost_req( + self.nick1, hostpubkeyhash_hex, sig_hex, session_id) + + assert self.fc1.on_frost_ack( + self.nick2, hostpubkeyhash2, sig2_hex, session_id) + ( nick1, - hostpubkeyhash2_hex, session_id2_hex, - sig2_hex, + hostpubkeyhash2_hex, pub_nonce2 - ) = self.fc2.on_frost_init( - self.nick1, hostpubkeyhash_hex, session_id, sig_hex) + ) = self.fc2.on_frost_init(self.nick1, session_id) ( ready_list, @@ -935,8 +1019,7 @@ class FROSTClientTestCase(DKGClientTestCaseBase): ids, msg ) = self.fc1.on_frost_round1( - self.nick2, hostpubkeyhash2_hex, session_id, - sig2_hex, pub_nonce2) + self.nick2, session_id, hostpubkeyhash2_hex, pub_nonce2) partial_sig = self.fc2.frost_round2( session_id, nonce_agg, self.dkg_session_id, ids, msg) diff --git a/test/jmclient/test_frost_ipc.py b/test/jmclient/test_frost_ipc.py index 6ea5b0f..bc975a7 100644 --- a/test/jmclient/test_frost_ipc.py +++ b/test/jmclient/test_frost_ipc.py @@ -170,43 +170,62 @@ class DummyFrostJMClientProtocol: log.debug(f'Coordinator get dkgfinalized') client.on_dkg_finalized(nick, session_id) - def frost_init(self, dkg_session_id, msg_bytes): - log.debug(f'Coordinator call frost_init') + def frost_req(self, dkg_session_id, msg_bytes): + log.debug(f'Coordinator call frost_req') client = self.factory.client - hostpubkeyhash, session_id, sig = client.frost_init( + hostpubkeyhash, sig, session_id = client.frost_req( dkg_session_id, msg_bytes) coordinator = client.frost_coordinators.get(session_id) session = client.frost_sessions.get(session_id) if session_id and session and coordinator: - coordinator.frost_init_sec = time.time() + coordinator.frost_req_sec = time.time() for _, pc in self.party_clients.items(): - async def on_frost_init(pc, nick, hostpubkeyhash, - session_id, sig): - await pc.on_frost_init( - nick, hostpubkeyhash, session_id, sig) + async def on_frost_req(pc, nick, hostpubkeyhash, + sig, session_id): + await pc.on_frost_req( + nick, hostpubkeyhash, sig, session_id) - asyncio.create_task(on_frost_init( - pc, self.nick, hostpubkeyhash, session_id, sig)) + asyncio.create_task(on_frost_req( + pc, self.nick, hostpubkeyhash, sig, session_id)) return session_id, coordinator, session - async def on_frost_init(self, nick, hostpubkeyhash, session_id, sig): + async def on_frost_req(self, nick, hostpubkeyhash, sig, session_id): client = self.factory.client ( - nick, + nick2, hostpubkeyhash, - session_id, sig, + session_id, + ) = client.on_frost_req(nick, hostpubkeyhash, sig, session_id) + if sig: + pc = self.party_clients[nick] + session_id = bytes.fromhex(session_id) + await pc.on_frost_ack( + self.nick, hostpubkeyhash, sig, session_id) + + async def on_frost_ack(self, nick, hostpubkeyhash, sig, session_id): + client = self.factory.client + assert client.on_frost_ack(nick, hostpubkeyhash, sig, session_id) + pc = self.party_clients[nick] + await pc.on_frost_init(self.nick, session_id) + + async def on_frost_init(self, nick, session_id): + client = self.factory.client + ( + nick2, + session_id, + hostpubkeyhash, pub_nonce - ) = client.on_frost_init(nick, hostpubkeyhash, session_id, sig) + ) = client.on_frost_init(nick, session_id) if pub_nonce: pc = self.party_clients[nick] session_id = bytes.fromhex(session_id) await pc.on_frost_round1( - self.nick, hostpubkeyhash, session_id, sig, pub_nonce) + self.nick, session_id, hostpubkeyhash, pub_nonce) - async def on_frost_round1(self, nick, hostpubkeyhash, - session_id, sig, pub_nonce): + async def on_frost_round1(self, nick, session_id, + hostpubkeyhash, pub_nonce): client = self.factory.client ( ready_nicks, @@ -215,19 +234,14 @@ class DummyFrostJMClientProtocol: ids, msg ) = client.on_frost_round1( - nick, hostpubkeyhash, session_id, sig, pub_nonce) + nick, session_id, hostpubkeyhash, pub_nonce) if ready_nicks and nonce_agg: for party_nick in ready_nicks: pc = self.party_clients[nick] - self.frost_agg1(pc, self.nick, session_id, nonce_agg, - dkg_session_id, ids, msg) - - def frost_agg1(self, pc, nick, session_id, - nonce_agg, dkg_session_id, ids, msg): - pc.on_frost_agg1( - self.nick, session_id, nonce_agg, dkg_session_id, ids, msg) + await pc.on_frost_agg1( + self.nick, session_id, nonce_agg, dkg_session_id, ids, msg) - def on_frost_agg1(self, nick, session_id, + async def on_frost_agg1(self, nick, session_id, nonce_agg, dkg_session_id, ids, msg): client = self.factory.client session = client.frost_sessions.get(session_id) @@ -239,11 +253,11 @@ class DummyFrostJMClientProtocol: session_id, nonce_agg, dkg_session_id, ids, msg) if partial_sig: pc = self.party_clients[nick] - pc.on_frost_round2(self.nick, session_id, partial_sig) + await pc.on_frost_round2(self.nick, session_id, partial_sig) else: log.error(f'on_frost_agg1: not coordinator nick {nick}') - def on_frost_round2(self, nick, session_id, partial_sig): + async def on_frost_round2(self, nick, session_id, partial_sig): client = self.factory.client sig = client.on_frost_round2(nick, session_id, partial_sig) if sig: @@ -333,9 +347,10 @@ class FrostIPCClientTestCase(FrostIPCTestCaseBase): pubkeys = list(dkg._dkg_pubkey.values()) assert pubkey and pubkey in pubkeys - async def test_frost_sign(self): + async def test_frost_req(self): sighash = bytes.fromhex('01020304'*8) - sig, pubkey, tweaked_pubkey = await self.ipcc.frost_sign(0, 0, 0, sighash) + sig, pubkey, tweaked_pubkey = await self.ipcc.frost_req( + 0, 0, 0, sighash) assert sig and len(sig) == 64 assert pubkey and len(pubkey) == 33 assert tweaked_pubkey and len(tweaked_pubkey) == 33