Browse Source

aiorpcx: pin certificates

master
Janus 7 years ago committed by SomberNight
parent
commit
89a01a6463
No known key found for this signature in database
GPG Key ID: B33B5F232C6271E9
  1. 2
      electrum/__init__.py
  2. 388
      electrum/interface.py
  3. 37
      electrum/network.py

2
electrum/__init__.py

@ -4,7 +4,7 @@ from .wallet import Wallet
from .storage import WalletStorage from .storage import WalletStorage
from .coinchooser import COIN_CHOOSERS from .coinchooser import COIN_CHOOSERS
from .network import Network, pick_random_server from .network import Network, pick_random_server
from .interface import Connection, Interface from .interface import Interface
from .simple_config import SimpleConfig, get_config, set_config from .simple_config import SimpleConfig, get_config, set_config
from . import bitcoin from . import bitcoin
from . import transaction from . import transaction

388
electrum/interface.py

@ -28,343 +28,131 @@ import socket
import ssl import ssl
import sys import sys
import threading import threading
import time
import traceback import traceback
import aiorpcx
import asyncio
import requests import requests
from .util import print_error from .util import PrintError
ca_path = requests.certs.where() ca_path = requests.certs.where()
from . import util from . import util
from . import x509 from . import x509
from . import pem from . import pem
from .version import ELECTRUM_VERSION, PROTOCOL_VERSION
class Interface(PrintError):
def Connection(server, queue, config_path): def __init__(self, server, config_path, connecting):
"""Makes asynchronous connections to a remote Electrum server. self.connecting = connecting
Returns the running thread that is making the connection.
Once the thread has connected, it finishes, placing a tuple on the
queue of the form (server, socket), where socket is None if
connection failed.
"""
host, port, protocol = server.rsplit(':', 2)
if not protocol in 'st':
raise Exception('Unknown protocol: %s' % protocol)
c = TcpConnection(server, queue, config_path)
c.start()
return c
class TcpConnection(threading.Thread, util.PrintError):
verbosity_filter = 'i'
def __init__(self, server, queue, config_path):
threading.Thread.__init__(self)
self.config_path = config_path
self.queue = queue
self.server = server self.server = server
self.host, self.port, self.protocol = self.server.rsplit(':', 2) self.host, self.port, self.protocol = self.server.split(':')
self.host = str(self.host) self.config_path = config_path
self.port = int(self.port) self.cert_path = os.path.join(self.config_path, 'certs', self.host)
self.use_ssl = (self.protocol == 's') self.fut = asyncio.get_event_loop().create_task(self.run())
self.daemon = True
def diagnostic_name(self): def diagnostic_name(self):
return self.host return self.host
def check_host_name(self, peercert, name): async def is_server_ca_signed(self, sslc):
"""Simple certificate/host name checker. Returns True if the try:
certificate matches, False otherwise. Does not support await self.open_session(sslc, do_sleep=False)
wildcards.""" except ssl.SSLError as e:
# Check that the peer has supplied a certificate. assert e.reason == 'CERTIFICATE_VERIFY_FAILED'
# None/{} is not acceptable.
if not peercert:
return False return False
if 'subjectAltName' in peercert:
for typ, val in peercert["subjectAltName"]:
if typ == "DNS" and val == name:
return True return True
else:
# Only check the subject DN if there is no subject alternative
# name.
cn = None
for attr, val in peercert["subject"]:
# Use most-specific (last) commonName attribute.
if attr == "commonName":
cn = val
if cn is not None:
return cn == name
return False
def get_simple_socket(self): @util.aiosafe
try: async def run(self):
l = socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM) if self.protocol != 's':
except socket.gaierror: await self.open_session(None, execute_after_connect=lambda: self.connecting.remove(self.server))
self.print_error("cannot resolve hostname")
return return
e = None
for res in l:
try:
s = socket.socket(res[0], socket.SOCK_STREAM)
s.settimeout(10)
s.connect(res[4])
s.settimeout(2)
s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
return s
except BaseException as _e:
e = _e
continue
else:
self.print_error("failed to connect", str(e))
@staticmethod
def get_ssl_context(cert_reqs, ca_certs):
context = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH, cafile=ca_certs)
context.check_hostname = False
context.verify_mode = cert_reqs
context.options |= ssl.OP_NO_SSLv2
context.options |= ssl.OP_NO_SSLv3
context.options |= ssl.OP_NO_TLSv1
return context ca_sslc = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
exists = os.path.exists(self.cert_path)
def get_socket(self): if exists:
if self.use_ssl: with open(self.cert_path, 'r') as f:
cert_path = os.path.join(self.config_path, 'certs', self.host) contents = f.read()
if not os.path.exists(cert_path): if contents != '': # if not CA signed
is_new = True
s = self.get_simple_socket()
if s is None:
return
# try with CA first
try: try:
context = self.get_ssl_context(cert_reqs=ssl.CERT_REQUIRED, ca_certs=ca_path) b = pem.dePem(contents, 'CERTIFICATE')
s = context.wrap_socket(s, do_handshake_on_connect=True) except SyntaxError:
except ssl.SSLError as e: exists = False
self.print_error(e)
except:
return
else: else:
x = x509.X509(b)
try: try:
peer_cert = s.getpeercert() x.check_date()
except OSError: except x509.CertificateError:
return self.print_error("certificate has expired:", self.cert_path)
if self.check_host_name(peer_cert, self.host): os.unlink(self.cert_path)
self.print_error("SSL certificate signed by CA") exists = False
return s if not exists:
# get server certificate. ca_signed = await self.is_server_ca_signed(ca_sslc)
# Do not use ssl.get_server_certificate because it does not work with proxy if ca_signed:
s = self.get_simple_socket() with open(self.cert_path, 'w') as f:
if s is None: # empty file means this is CA signed, not self-signed
return f.write('')
try: else:
context = self.get_ssl_context(cert_reqs=ssl.CERT_NONE, ca_certs=None) await self.save_certificate()
s = context.wrap_socket(s) siz = os.stat(self.cert_path).st_size
except ssl.SSLError as e: if siz == 0: # if CA signed
self.print_error("SSL error retrieving SSL certificate:", e) sslc = ca_sslc
return else:
except: sslc = ssl.create_default_context(ssl.Purpose.SERVER_AUTH, cafile=self.cert_path)
return sslc.check_hostname = 0
await self.open_session(sslc, execute_after_connect=lambda: self.connecting.remove(self.server))
try:
dercert = s.getpeercert(True) async def save_certificate(self):
except OSError: if not os.path.exists(self.cert_path):
return # we may need to retry this a few times, in case the handshake hasn't completed
s.close() for _ in range(10):
dercert = await self.get_certificate()
if dercert:
self.print_error("succeeded in getting cert")
with open(self.cert_path, 'w') as f:
cert = ssl.DER_cert_to_PEM_cert(dercert) cert = ssl.DER_cert_to_PEM_cert(dercert)
# workaround android bug # workaround android bug
cert = re.sub("([^\n])-----END CERTIFICATE-----","\\1\n-----END CERTIFICATE-----",cert) cert = re.sub("([^\n])-----END CERTIFICATE-----","\\1\n-----END CERTIFICATE-----",cert)
temporary_path = cert_path + '.temp'
util.assert_datadir_available(self.config_path)
with open(temporary_path, "w", encoding='utf-8') as f:
f.write(cert) f.write(cert)
# even though close flushes we can't fsync when closed.
# and we must flush before fsyncing, cause flush flushes to OS buffer
# fsync writes to OS buffer to disk
f.flush() f.flush()
os.fsync(f.fileno()) os.fsync(f.fileno())
else: break
is_new = False await asyncio.sleep(1)
assert False, "could not get certificate"
s = self.get_simple_socket()
if s is None:
return
if self.use_ssl:
try:
context = self.get_ssl_context(cert_reqs=ssl.CERT_REQUIRED,
ca_certs=(temporary_path if is_new else cert_path))
s = context.wrap_socket(s, do_handshake_on_connect=True)
except socket.timeout:
self.print_error('timeout')
return
except ssl.SSLError as e:
self.print_error("SSL error:", e)
if e.errno != 1:
return
if is_new:
rej = cert_path + '.rej'
if os.path.exists(rej):
os.unlink(rej)
os.rename(temporary_path, rej)
else:
util.assert_datadir_available(self.config_path)
with open(cert_path, encoding='utf-8') as f:
cert = f.read()
try:
b = pem.dePem(cert, 'CERTIFICATE')
x = x509.X509(b)
except:
traceback.print_exc(file=sys.stderr)
self.print_error("wrong certificate")
return
try:
x.check_date()
except:
self.print_error("certificate has expired:", cert_path)
os.unlink(cert_path)
return
self.print_error("wrong certificate")
if e.errno == 104:
return
return
except BaseException as e:
self.print_error(e)
traceback.print_exc(file=sys.stderr)
return
if is_new:
self.print_error("saving certificate")
os.rename(temporary_path, cert_path)
return s
def run(self):
socket = self.get_socket()
if socket:
self.print_error("connected")
self.queue.put((self.server, socket))
class Interface(util.PrintError):
"""The Interface class handles a socket connected to a single remote
Electrum server. Its exposed API is:
- Member functions close(), fileno(), get_responses(), has_timed_out(),
ping_required(), queue_request(), send_requests()
- Member variable server.
"""
def __init__(self, server, socket):
self.server = server
self.host, _, _ = server.rsplit(':', 2)
self.socket = socket
self.pipe = util.SocketPipe(socket)
self.pipe.set_timeout(0.0) # Don't wait for data
# Dump network messages. Set at runtime from the console.
self.debug = False
self.unsent_requests = []
self.unanswered_requests = {}
self.last_send = time.time()
self.closed_remotely = False
def diagnostic_name(self):
return self.host
def fileno(self):
# Needed for select
return self.socket.fileno()
def close(self):
if not self.closed_remotely:
try:
self.socket.shutdown(socket.SHUT_RDWR)
except socket.error:
pass
self.socket.close()
def queue_request(self, *args): # method, params, _id
'''Queue a request, later to be send with send_requests when the
socket is available for writing.
'''
self.request_time = time.time()
self.unsent_requests.append(args)
def num_requests(self):
'''Keep unanswered requests below 100'''
n = 100 - len(self.unanswered_requests)
return min(n, len(self.unsent_requests))
def send_requests(self): async def get_certificate(self):
'''Sends queued requests. Returns False on failure.''' sslc = ssl.SSLContext()
self.last_send = time.time()
make_dict = lambda m, p, i: {'method': m, 'params': p, 'id': i}
n = self.num_requests()
wire_requests = self.unsent_requests[0:n]
try: try:
self.pipe.send_all([make_dict(*r) for r in wire_requests]) async with aiorpcx.ClientSession(self.host, self.port, ssl=sslc) as session:
except BaseException as e: return session.transport._ssl_protocol._sslpipe._sslobj.getpeercert(True)
self.print_error("pipe send error:", e) except ValueError:
return False return None
self.unsent_requests = self.unsent_requests[n:]
for request in wire_requests: async def open_session(self, sslc, do_sleep=True, execute_after_connect=lambda: None):
if self.debug: async with aiorpcx.ClientSession(self.host, self.port, ssl=sslc) as session:
self.print_error("-->", request) ver = await session.send_request('server.version', [ELECTRUM_VERSION, PROTOCOL_VERSION])
self.unanswered_requests[request[2]] = request print(ver)
return True connect_hook_executed = False
while do_sleep:
def ping_required(self): if not connect_hook_executed:
'''Returns True if a ping should be sent.''' connect_hook_executed = True
return time.time() - self.last_send > 300 execute_after_connect()
await asyncio.wait_for(session.send_request('server.ping'), 5)
await asyncio.sleep(300)
def has_timed_out(self): def has_timed_out(self):
'''Returns True if the interface has timed out.''' return self.fut.done()
if (self.unanswered_requests and time.time() - self.request_time > 10
and self.pipe.idle_time() > 10):
self.print_error("timeout", len(self.unanswered_requests))
return True
return False def queue_request(self, method, params, msg_id):
pass
def get_responses(self):
'''Call if there is data available on the socket. Returns a list of
(request, response) pairs. Notifications are singleton
unsolicited responses presumably as a result of prior
subscriptions, so request is None and there is no 'id' member.
Otherwise it is a response, which has an 'id' member and a
corresponding request. If the connection was closed remotely
or the remote server is misbehaving, a (None, None) will appear.
'''
responses = []
while True:
try:
response = self.pipe.get()
except util.timeout:
break
if not type(response) is dict:
responses.append((None, None))
if response is None:
self.closed_remotely = True
self.print_error("connection closed remotely")
break
if self.debug:
self.print_error("<--", response)
wire_id = response.get('id', None)
if wire_id is None: # Notification
responses.append((None, response))
else:
request = self.unanswered_requests.pop(wire_id, None)
if request:
responses.append((request, response))
else:
self.print_error("unknown wire ID", wire_id)
responses.append((None, None)) # Signal
break
return responses
def close(self):
self.fut.cancel()
def check_cert(host, cert): def check_cert(host, cert):
try: try:

37
electrum/network.py

@ -47,38 +47,14 @@ from . import blockchain
from .version import ELECTRUM_VERSION, PROTOCOL_VERSION from .version import ELECTRUM_VERSION, PROTOCOL_VERSION
from .i18n import _ from .i18n import _
from .blockchain import InvalidHeader from .blockchain import InvalidHeader
from .interface import Interface
import aiorpcx, asyncio, ssl import asyncio
import concurrent.futures import concurrent.futures
NODES_RETRY_INTERVAL = 60 NODES_RETRY_INTERVAL = 60
SERVER_RETRY_INTERVAL = 10 SERVER_RETRY_INTERVAL = 10
class Interface(PrintError):
@util.aiosafe
async def run(self):
self.host, self.port, self.protocol = self.server.split(':')
sslc = ssl.SSLContext(ssl.PROTOCOL_TLS) if self.protocol == 's' else None
async with aiorpcx.ClientSession(self.host, self.port, ssl=sslc) as session:
ver = await session.send_request('server.version', [ELECTRUM_VERSION, PROTOCOL_VERSION])
print(ver)
while True:
print("sleeping")
await asyncio.sleep(1)
def __init__(self, server):
self.exception = None
self.server = server
self.fut = asyncio.get_event_loop().create_task(self.run())
def has_timed_out(self):
return self.fut.done()
def queue_request(self, method, params, msg_id):
pass
def close(self):
self.fut.cancel()
def parse_servers(result): def parse_servers(result):
""" parse servers list into dict format""" """ parse servers list into dict format"""
@ -539,7 +515,7 @@ class Network(PrintError):
self.close_interface(self.interface) self.close_interface(self.interface)
assert self.interface is None assert self.interface is None
assert not self.interfaces assert not self.interfaces
self.connecting = set() self.connecting.clear()
# Get a new queue - no old pending connections thanks! # Get a new queue - no old pending connections thanks!
self.socket_queue = queue.Queue() self.socket_queue = queue.Queue()
@ -810,7 +786,7 @@ class Network(PrintError):
def new_interface(self, server): def new_interface(self, server):
# todo: get tip first, then decide which checkpoint to use. # todo: get tip first, then decide which checkpoint to use.
self.add_recent_server(server) self.add_recent_server(server)
interface = Interface(server) interface = Interface(server, self.config.path, self.connecting)
interface.blockchain = None interface.blockchain = None
interface.tip_header = None interface.tip_header = None
interface.tip = 0 interface.tip = 0
@ -1368,9 +1344,12 @@ class Network(PrintError):
for k, i in self.interfaces.items(): for k, i in self.interfaces.items():
if i.has_timed_out(): if i.has_timed_out():
remove.append(k) remove.append(k)
changed = False
for k in remove: for k in remove:
self.connection_down(k) self.connection_down(k)
changed = True
for i in range(self.num_server - len(self.interfaces)): for i in range(self.num_server - len(self.interfaces)):
self.start_random_interface() self.start_random_interface()
self.notify('updated') changed = True
if changed: self.notify('updated')
await asyncio.sleep(1) await asyncio.sleep(1)

Loading…
Cancel
Save