Browse Source

fix test_websocket.py sporadic fails

add_frost_channel_encryption
zebra-lucky 2 months ago
parent
commit
ee05fcbc73
  1. 46
      test/jmclient/test_websocket.py

46
test/jmclient/test_websocket.py

@ -2,6 +2,7 @@
import os import os
import json import json
from twisted.internet import reactor, task from twisted.internet import reactor, task
from twisted.internet.defer import inlineCallbacks, Deferred
from twisted.trial import unittest from twisted.trial import unittest
from autobahn.twisted.websocket import WebSocketClientFactory, \ from autobahn.twisted.websocket import WebSocketClientFactory, \
WebSocketClientProtocol, connectWS, listenWS WebSocketClientProtocol, connectWS, listenWS
@ -24,6 +25,13 @@ test_tx_hex_txid = "ca606efc5ba8f6669ba15e9262e5d38e745345ea96106d5a919688d1ff0d
token_authority = JMTokenAuthority() token_authority = JMTokenAuthority()
class ServerTProtocol(JmwalletdWebSocketServerProtocol):
def onMessage(self, payload, isBinary):
super().onMessage(payload, isBinary)
self.factory.on_message_d.callback(None)
class ClientTProtocol(WebSocketClientProtocol): class ClientTProtocol(WebSocketClientProtocol):
""" """
Simple client that connects to a WebSocket server, send a HELLO Simple client that connects to a WebSocket server, send a HELLO
@ -48,6 +56,7 @@ class ClientTProtocol(WebSocketClientProtocol):
payload = payload.decode("utf-8") payload = payload.decode("utf-8")
jlog.info("Text message received: {}".format(payload)) jlog.info("Text message received: {}".format(payload))
self.factory.notifs += 1 self.factory.notifs += 1
self.factory.on_message_d.callback(None)
# ensure we got the transaction message expected: # ensure we got the transaction message expected:
deser_notif = json.loads(payload) deser_notif = json.loads(payload)
assert deser_notif["txid"] == test_tx_hex_txid assert deser_notif["txid"] == test_tx_hex_txid
@ -67,36 +76,36 @@ class WebsocketTestBase(object):
self.wss_port = free_ports[0] self.wss_port = free_ports[0]
self.wss_url = "ws://127.0.0.1:" + str(self.wss_port) self.wss_url = "ws://127.0.0.1:" + str(self.wss_port)
self.wss_factory = JmwalletdWebSocketServerFactory(self.wss_url, token_authority) self.wss_factory = JmwalletdWebSocketServerFactory(self.wss_url, token_authority)
self.wss_factory.protocol = JmwalletdWebSocketServerProtocol self.wss_factory.protocol = ServerTProtocol
self.wss_factory.on_message_d = Deferred()
self.listeningport = listenWS(self.wss_factory, contextFactory=None) self.listeningport = listenWS(self.wss_factory, contextFactory=None)
self.test_tx = CTransaction.deserialize(hextobin(test_tx_hex_1)) self.test_tx = CTransaction.deserialize(hextobin(test_tx_hex_1))
def stopListening(self): def stopListening(self):
return self.listeningport.stopListening() return self.listeningport.stopListening()
@inlineCallbacks
def do_test(self): def do_test(self):
self.client_factory = WebSocketClientFactory("ws://127.0.0.1:"+str(self.wss_port)) self.client_factory = WebSocketClientFactory("ws://127.0.0.1:"+str(self.wss_port))
self.client_factory.on_message_d = Deferred()
self.client_factory.protocol = ClientTProtocol self.client_factory.protocol = ClientTProtocol
# keep track of the connector object so we can close it manually: # keep track of the connector object so we can close it manually:
self.client_connector = connectWS(self.client_factory) self.client_connector = connectWS(self.client_factory)
d = task.deferLater(reactor, 0.1, self.fire_tx_notif)
# create a small delay between the instruction to send # wait on server to receive message
# the notification, and the checking of its receipt, yield self.wss_factory.on_message_d
# otherwise the client will be queried before the notification
# arrived: yield self.fire_tx_notif()
d.addCallback(self.wait_to_receive) # wait on client to receive message
return d yield self.client_factory.on_message_d
def wait_to_receive(self, res):
d = task.deferLater(reactor, 0.1, self.checkNotifs)
return d
def checkNotifs(self):
assert self.client_factory.notifs == 1 assert self.client_factory.notifs == 1
@inlineCallbacks
def fire_tx_notif(self): def fire_tx_notif(self):
self.wss_factory.sendTxNotification(self.test_tx, yield self.wss_factory.sendTxNotification(
test_tx_hex_txid) self.test_tx, test_tx_hex_txid)
def tearDown(self): def tearDown(self):
reactor.disconnectAll() reactor.disconnectAll()
@ -106,6 +115,9 @@ class WebsocketTestBase(object):
self.client_connector.disconnect() self.client_connector.disconnect()
return self.stopListening() return self.stopListening()
class TrialTestWS(WebsocketTestBase, unittest.TestCase): class TrialTestWS(WebsocketTestBase, unittest.TestCase):
@inlineCallbacks
def test_basic_notification(self): def test_basic_notification(self):
return self.do_test() yield self.do_test()

Loading…
Cancel
Save