You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

473 lines
16 KiB

# Wrapper module for _ssl, providing some additional facilities
# implemented in Python.  Written by Bill Janssen.

"""\
This module provides some more Pythonic support for SSL.

Object types:

  SSLSocket -- subtype of socket.socket which does SSL over the socket

Exceptions:

  SSLError -- exception raised for I/O errors

Functions:

  wrap_socket -- wrap an existing socket object in an SSLSocket
  cert_time_to_seconds -- convert the time string used in a certificate's
                          notBefore and notAfter fields to seconds past
                          the Epoch (the same scale as the values
                          returned by time.time())
  DER_cert_to_PEM_cert -- convert a binary DER certificate to PEM text
  PEM_cert_to_DER_cert -- convert a PEM certificate to binary DER
  get_server_certificate (ADDR, SSL_VERSION, CA_CERTS) -- fetch the
                          certificate provided by the server listening
                          at ADDR, a (HOST, PORT) pair, and return it as
                          PEM text.  The certificate is validated only
                          if CA_CERTS is given.
  get_protocol_name -- map a PROTOCOL_* constant to its name
  sslwrap_simple -- replacement for the old socket.ssl function

Integer constants:

SSL_ERROR_ZERO_RETURN
SSL_ERROR_WANT_READ
SSL_ERROR_WANT_WRITE
SSL_ERROR_WANT_X509_LOOKUP
SSL_ERROR_SYSCALL
SSL_ERROR_SSL
SSL_ERROR_WANT_CONNECT

SSL_ERROR_EOF
SSL_ERROR_INVALID_ERROR_CODE

The following group defines certificate requirements that one side is
allowing/requiring from the other side:

CERT_NONE - no certificates from the other side are required (or will
            be looked at if provided)
CERT_OPTIONAL - certificates are not required, but if provided will be
                validated, and if validation fails, the connection will
                also fail
CERT_REQUIRED - certificates are required, and will be validated, and
                if validation fails, the connection will also fail

The following constants identify various SSL protocol variants:

PROTOCOL_SSLv2  (only if the underlying _ssl module provides it)
PROTOCOL_SSLv3
PROTOCOL_SSLv23
PROTOCOL_TLSv1
"""
  42. import textwrap
  43. import _ssl # if we can't import it, let the error propagate
  44. from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION
  45. from _ssl import SSLError
  46. from _ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED
  47. from _ssl import RAND_status, RAND_egd, RAND_add
  48. from _ssl import \
  49. SSL_ERROR_ZERO_RETURN, \
  50. SSL_ERROR_WANT_READ, \
  51. SSL_ERROR_WANT_WRITE, \
  52. SSL_ERROR_WANT_X509_LOOKUP, \
  53. SSL_ERROR_SYSCALL, \
  54. SSL_ERROR_SSL, \
  55. SSL_ERROR_WANT_CONNECT, \
  56. SSL_ERROR_EOF, \
  57. SSL_ERROR_INVALID_ERROR_CODE
  58. from _ssl import PROTOCOL_SSLv3, PROTOCOL_SSLv23, PROTOCOL_TLSv1
  59. _PROTOCOL_NAMES = {
  60. PROTOCOL_TLSv1: "TLSv1",
  61. PROTOCOL_SSLv23: "SSLv23",
  62. PROTOCOL_SSLv3: "SSLv3",
  63. }
  64. try:
  65. from _ssl import PROTOCOL_SSLv2
  66. _SSLv2_IF_EXISTS = PROTOCOL_SSLv2
  67. except ImportError:
  68. _SSLv2_IF_EXISTS = None
  69. else:
  70. _PROTOCOL_NAMES[PROTOCOL_SSLv2] = "SSLv2"
  71. from socket import socket, _fileobject, _delegate_methods, error as socket_error
  72. from socket import getnameinfo as _getnameinfo
  73. import base64 # for DER-to-PEM translation
  74. import errno
  75. # Disable weak or insecure ciphers by default
  76. # (OpenSSL's default setting is 'DEFAULT:!aNULL:!eNULL')
  77. _DEFAULT_CIPHERS = 'DEFAULT:!aNULL:!eNULL:!LOW:!EXPORT:!SSLv2'
  78. class SSLSocket(socket):
  79. """This class implements a subtype of socket.socket that wraps
  80. the underlying OS socket in an SSL context when necessary, and
  81. provides read and write methods over that channel."""
  82. def __init__(self, sock, keyfile=None, certfile=None,
  83. server_side=False, cert_reqs=CERT_NONE,
  84. ssl_version=PROTOCOL_SSLv23, ca_certs=None,
  85. do_handshake_on_connect=True,
  86. suppress_ragged_eofs=True, ciphers=None):
  87. socket.__init__(self, _sock=sock._sock)
  88. # The initializer for socket overrides the methods send(), recv(), etc.
  89. # in the instancce, which we don't need -- but we want to provide the
  90. # methods defined in SSLSocket.
  91. for attr in _delegate_methods:
  92. try:
  93. delattr(self, attr)
  94. except AttributeError:
  95. pass
  96. if ciphers is None and ssl_version != _SSLv2_IF_EXISTS:
  97. ciphers = _DEFAULT_CIPHERS
  98. if certfile and not keyfile:
  99. keyfile = certfile
  100. # see if it's connected
  101. try:
  102. socket.getpeername(self)
  103. except socket_error, e:
  104. if e.errno != errno.ENOTCONN:
  105. raise
  106. # no, no connection yet
  107. self._connected = False
  108. self._sslobj = None
  109. else:
  110. # yes, create the SSL object
  111. self._connected = True
  112. self._sslobj = _ssl.sslwrap(self._sock, server_side,
  113. keyfile, certfile,
  114. cert_reqs, ssl_version, ca_certs,
  115. ciphers)
  116. if do_handshake_on_connect:
  117. self.do_handshake()
  118. self.keyfile = keyfile
  119. self.certfile = certfile
  120. self.cert_reqs = cert_reqs
  121. self.ssl_version = ssl_version
  122. self.ca_certs = ca_certs
  123. self.ciphers = ciphers
  124. self.do_handshake_on_connect = do_handshake_on_connect
  125. self.suppress_ragged_eofs = suppress_ragged_eofs
  126. self._makefile_refs = 0
  127. def read(self, len=1024):
  128. """Read up to LEN bytes and return them.
  129. Return zero-length string on EOF."""
  130. try:
  131. return self._sslobj.read(len)
  132. except SSLError, x:
  133. if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
  134. return ''
  135. else:
  136. raise
  137. def write(self, data):
  138. """Write DATA to the underlying SSL channel. Returns
  139. number of bytes of DATA actually transmitted."""
  140. return self._sslobj.write(data)
  141. def getpeercert(self, binary_form=False):
  142. """Returns a formatted version of the data in the
  143. certificate provided by the other end of the SSL channel.
  144. Return None if no certificate was provided, {} if a
  145. certificate was provided, but not validated."""
  146. return self._sslobj.peer_certificate(binary_form)
  147. def cipher(self):
  148. if not self._sslobj:
  149. return None
  150. else:
  151. return self._sslobj.cipher()
  152. def send(self, data, flags=0):
  153. if self._sslobj:
  154. if flags != 0:
  155. raise ValueError(
  156. "non-zero flags not allowed in calls to send() on %s" %
  157. self.__class__)
  158. while True:
  159. try:
  160. v = self._sslobj.write(data)
  161. except SSLError, x:
  162. if x.args[0] == SSL_ERROR_WANT_READ:
  163. return 0
  164. elif x.args[0] == SSL_ERROR_WANT_WRITE:
  165. return 0
  166. else:
  167. raise
  168. else:
  169. return v
  170. else:
  171. return self._sock.send(data, flags)
  172. def sendto(self, data, flags_or_addr, addr=None):
  173. if self._sslobj:
  174. raise ValueError("sendto not allowed on instances of %s" %
  175. self.__class__)
  176. elif addr is None:
  177. return self._sock.sendto(data, flags_or_addr)
  178. else:
  179. return self._sock.sendto(data, flags_or_addr, addr)
  180. def sendall(self, data, flags=0):
  181. if self._sslobj:
  182. if flags != 0:
  183. raise ValueError(
  184. "non-zero flags not allowed in calls to sendall() on %s" %
  185. self.__class__)
  186. amount = len(data)
  187. count = 0
  188. while (count < amount):
  189. v = self.send(data[count:])
  190. count += v
  191. return amount
  192. else:
  193. return socket.sendall(self, data, flags)
  194. def recv(self, buflen=1024, flags=0):
  195. if self._sslobj:
  196. if flags != 0:
  197. raise ValueError(
  198. "non-zero flags not allowed in calls to recv() on %s" %
  199. self.__class__)
  200. return self.read(buflen)
  201. else:
  202. return self._sock.recv(buflen, flags)
  203. def recv_into(self, buffer, nbytes=None, flags=0):
  204. if buffer and (nbytes is None):
  205. nbytes = len(buffer)
  206. elif nbytes is None:
  207. nbytes = 1024
  208. if self._sslobj:
  209. if flags != 0:
  210. raise ValueError(
  211. "non-zero flags not allowed in calls to recv_into() on %s" %
  212. self.__class__)
  213. tmp_buffer = self.read(nbytes)
  214. v = len(tmp_buffer)
  215. buffer[:v] = tmp_buffer
  216. return v
  217. else:
  218. return self._sock.recv_into(buffer, nbytes, flags)
  219. def recvfrom(self, buflen=1024, flags=0):
  220. if self._sslobj:
  221. raise ValueError("recvfrom not allowed on instances of %s" %
  222. self.__class__)
  223. else:
  224. return self._sock.recvfrom(buflen, flags)
  225. def recvfrom_into(self, buffer, nbytes=None, flags=0):
  226. if self._sslobj:
  227. raise ValueError("recvfrom_into not allowed on instances of %s" %
  228. self.__class__)
  229. else:
  230. return self._sock.recvfrom_into(buffer, nbytes, flags)
  231. def pending(self):
  232. if self._sslobj:
  233. return self._sslobj.pending()
  234. else:
  235. return 0
  236. def unwrap(self):
  237. if self._sslobj:
  238. s = self._sslobj.shutdown()
  239. self._sslobj = None
  240. return s
  241. else:
  242. raise ValueError("No SSL wrapper around " + str(self))
  243. def shutdown(self, how):
  244. self._sslobj = None
  245. socket.shutdown(self, how)
  246. def close(self):
  247. if self._makefile_refs < 1:
  248. self._sslobj = None
  249. socket.close(self)
  250. else:
  251. self._makefile_refs -= 1
  252. def do_handshake(self):
  253. """Perform a TLS/SSL handshake."""
  254. self._sslobj.do_handshake()
  255. def _real_connect(self, addr, return_errno):
  256. # Here we assume that the socket is client-side, and not
  257. # connected at the time of the call. We connect it, then wrap it.
  258. if self._connected:
  259. raise ValueError("attempt to connect already-connected SSLSocket!")
  260. self._sslobj = _ssl.sslwrap(self._sock, False, self.keyfile, self.certfile,
  261. self.cert_reqs, self.ssl_version,
  262. self.ca_certs, self.ciphers)
  263. try:
  264. socket.connect(self, addr)
  265. if self.do_handshake_on_connect:
  266. self.do_handshake()
  267. except socket_error as e:
  268. if return_errno:
  269. return e.errno
  270. else:
  271. self._sslobj = None
  272. raise e
  273. self._connected = True
  274. return 0
  275. def connect(self, addr):
  276. """Connects to remote ADDR, and then wraps the connection in
  277. an SSL channel."""
  278. self._real_connect(addr, False)
  279. def connect_ex(self, addr):
  280. """Connects to remote ADDR, and then wraps the connection in
  281. an SSL channel."""
  282. return self._real_connect(addr, True)
  283. def accept(self):
  284. """Accepts a new connection from a remote client, and returns
  285. a tuple containing that new connection wrapped with a server-side
  286. SSL channel, and the address of the remote client."""
  287. newsock, addr = socket.accept(self)
  288. return (SSLSocket(newsock,
  289. keyfile=self.keyfile,
  290. certfile=self.certfile,
  291. server_side=True,
  292. cert_reqs=self.cert_reqs,
  293. ssl_version=self.ssl_version,
  294. ca_certs=self.ca_certs,
  295. ciphers=self.ciphers,
  296. do_handshake_on_connect=self.do_handshake_on_connect,
  297. suppress_ragged_eofs=self.suppress_ragged_eofs),
  298. addr)
  299. def makefile(self, mode='r', bufsize=-1):
  300. """Make and return a file-like object that
  301. works with the SSL connection. Just use the code
  302. from the socket module."""
  303. self._makefile_refs += 1
  304. # close=True so as to decrement the reference count when done with
  305. # the file-like object.
  306. return _fileobject(self, mode, bufsize, close=True)
  307. def wrap_socket(sock, keyfile=None, certfile=None,
  308. server_side=False, cert_reqs=CERT_NONE,
  309. ssl_version=PROTOCOL_SSLv23, ca_certs=None,
  310. do_handshake_on_connect=True,
  311. suppress_ragged_eofs=True, ciphers=None):
  312. return SSLSocket(sock, keyfile=keyfile, certfile=certfile,
  313. server_side=server_side, cert_reqs=cert_reqs,
  314. ssl_version=ssl_version, ca_certs=ca_certs,
  315. do_handshake_on_connect=do_handshake_on_connect,
  316. suppress_ragged_eofs=suppress_ragged_eofs,
  317. ciphers=ciphers)
  318. # some utility functions
  319. def cert_time_to_seconds(cert_time):
  320. """Takes a date-time string in standard ASN1_print form
  321. ("MON DAY 24HOUR:MINUTE:SEC YEAR TIMEZONE") and return
  322. a Python time value in seconds past the epoch."""
  323. import time
  324. return time.mktime(time.strptime(cert_time, "%b %d %H:%M:%S %Y GMT"))
  325. PEM_HEADER = "-----BEGIN CERTIFICATE-----"
  326. PEM_FOOTER = "-----END CERTIFICATE-----"
  327. def DER_cert_to_PEM_cert(der_cert_bytes):
  328. """Takes a certificate in binary DER format and returns the
  329. PEM version of it as a string."""
  330. if hasattr(base64, 'standard_b64encode'):
  331. # preferred because older API gets line-length wrong
  332. f = base64.standard_b64encode(der_cert_bytes)
  333. return (PEM_HEADER + '\n' +
  334. textwrap.fill(f, 64) + '\n' +
  335. PEM_FOOTER + '\n')
  336. else:
  337. return (PEM_HEADER + '\n' +
  338. base64.encodestring(der_cert_bytes) +
  339. PEM_FOOTER + '\n')
  340. def PEM_cert_to_DER_cert(pem_cert_string):
  341. """Takes a certificate in ASCII PEM format and returns the
  342. DER-encoded version of it as a byte sequence"""
  343. if not pem_cert_string.startswith(PEM_HEADER):
  344. raise ValueError("Invalid PEM encoding; must start with %s"
  345. % PEM_HEADER)
  346. if not pem_cert_string.strip().endswith(PEM_FOOTER):
  347. raise ValueError("Invalid PEM encoding; must end with %s"
  348. % PEM_FOOTER)
  349. d = pem_cert_string.strip()[len(PEM_HEADER):-len(PEM_FOOTER)]
  350. return base64.decodestring(d)
  351. def get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None):
  352. """Retrieve the certificate from the server at the specified address,
  353. and return it as a PEM-encoded string.
  354. If 'ca_certs' is specified, validate the server cert against it.
  355. If 'ssl_version' is specified, use it in the connection attempt."""
  356. host, port = addr
  357. if (ca_certs is not None):
  358. cert_reqs = CERT_REQUIRED
  359. else:
  360. cert_reqs = CERT_NONE
  361. s = wrap_socket(socket(), ssl_version=ssl_version,
  362. cert_reqs=cert_reqs, ca_certs=ca_certs)
  363. s.connect(addr)
  364. dercert = s.getpeercert(True)
  365. s.close()
  366. return DER_cert_to_PEM_cert(dercert)
  367. def get_protocol_name(protocol_code):
  368. return _PROTOCOL_NAMES.get(protocol_code, '<unknown>')
  369. # a replacement for the old socket.ssl function
  370. def sslwrap_simple(sock, keyfile=None, certfile=None):
  371. """A replacement for the old socket.ssl function. Designed
  372. for compability with Python 2.5 and earlier. Will disappear in
  373. Python 3.0."""
  374. if hasattr(sock, "_sock"):
  375. sock = sock._sock
  376. ssl_sock = _ssl.sslwrap(sock, 0, keyfile, certfile, CERT_NONE,
  377. PROTOCOL_SSLv23, None)
  378. try:
  379. sock.getpeername()
  380. except socket_error:
  381. # no, no connection yet
  382. pass
  383. else:
  384. # yes, do the handshake
  385. ssl_sock.do_handshake()
  386. return ssl_sock