You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

778 lines
26 KiB

  1. """Tests for asyncio/sslproto.py."""
  2. import logging
  3. import socket
  4. import sys
  5. from test import support
  6. import unittest
  7. import weakref
  8. from unittest import mock
  9. try:
  10. import ssl
  11. except ImportError:
  12. ssl = None
  13. import asyncio
  14. from asyncio import log
  15. from asyncio import protocols
  16. from asyncio import sslproto
  17. from test import support
  18. from test.test_asyncio import utils as test_utils
  19. from test.test_asyncio import functional as func_tests
  20. def tearDownModule():
  21. asyncio.set_event_loop_policy(None)
  22. @unittest.skipIf(ssl is None, 'No ssl module')
  23. class SslProtoHandshakeTests(test_utils.TestCase):
  24. def setUp(self):
  25. super().setUp()
  26. self.loop = asyncio.new_event_loop()
  27. self.set_event_loop(self.loop)
  28. def ssl_protocol(self, *, waiter=None, proto=None):
  29. sslcontext = test_utils.dummy_ssl_context()
  30. if proto is None: # app protocol
  31. proto = asyncio.Protocol()
  32. ssl_proto = sslproto.SSLProtocol(self.loop, proto, sslcontext, waiter,
  33. ssl_handshake_timeout=0.1)
  34. self.assertIs(ssl_proto._app_transport.get_protocol(), proto)
  35. self.addCleanup(ssl_proto._app_transport.close)
  36. return ssl_proto
  37. def connection_made(self, ssl_proto, *, do_handshake=None):
  38. transport = mock.Mock()
  39. sslpipe = mock.Mock()
  40. sslpipe.shutdown.return_value = b''
  41. if do_handshake:
  42. sslpipe.do_handshake.side_effect = do_handshake
  43. else:
  44. def mock_handshake(callback):
  45. return []
  46. sslpipe.do_handshake.side_effect = mock_handshake
  47. with mock.patch('asyncio.sslproto._SSLPipe', return_value=sslpipe):
  48. ssl_proto.connection_made(transport)
  49. return transport
  50. def test_handshake_timeout_zero(self):
  51. sslcontext = test_utils.dummy_ssl_context()
  52. app_proto = mock.Mock()
  53. waiter = mock.Mock()
  54. with self.assertRaisesRegex(ValueError, 'a positive number'):
  55. sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter,
  56. ssl_handshake_timeout=0)
  57. def test_handshake_timeout_negative(self):
  58. sslcontext = test_utils.dummy_ssl_context()
  59. app_proto = mock.Mock()
  60. waiter = mock.Mock()
  61. with self.assertRaisesRegex(ValueError, 'a positive number'):
  62. sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter,
  63. ssl_handshake_timeout=-10)
  64. def test_eof_received_waiter(self):
  65. waiter = self.loop.create_future()
  66. ssl_proto = self.ssl_protocol(waiter=waiter)
  67. self.connection_made(ssl_proto)
  68. ssl_proto.eof_received()
  69. test_utils.run_briefly(self.loop)
  70. self.assertIsInstance(waiter.exception(), ConnectionResetError)
  71. def test_fatal_error_no_name_error(self):
  72. # From issue #363.
  73. # _fatal_error() generates a NameError if sslproto.py
  74. # does not import base_events.
  75. waiter = self.loop.create_future()
  76. ssl_proto = self.ssl_protocol(waiter=waiter)
  77. # Temporarily turn off error logging so as not to spoil test output.
  78. log_level = log.logger.getEffectiveLevel()
  79. log.logger.setLevel(logging.FATAL)
  80. try:
  81. ssl_proto._fatal_error(None)
  82. finally:
  83. # Restore error logging.
  84. log.logger.setLevel(log_level)
  85. def test_connection_lost(self):
  86. # From issue #472.
  87. # yield from waiter hang if lost_connection was called.
  88. waiter = self.loop.create_future()
  89. ssl_proto = self.ssl_protocol(waiter=waiter)
  90. self.connection_made(ssl_proto)
  91. ssl_proto.connection_lost(ConnectionAbortedError)
  92. test_utils.run_briefly(self.loop)
  93. self.assertIsInstance(waiter.exception(), ConnectionAbortedError)
  94. def test_close_during_handshake(self):
  95. # bpo-29743 Closing transport during handshake process leaks socket
  96. waiter = self.loop.create_future()
  97. ssl_proto = self.ssl_protocol(waiter=waiter)
  98. transport = self.connection_made(ssl_proto)
  99. test_utils.run_briefly(self.loop)
  100. ssl_proto._app_transport.close()
  101. self.assertTrue(transport.abort.called)
  102. def test_get_extra_info_on_closed_connection(self):
  103. waiter = self.loop.create_future()
  104. ssl_proto = self.ssl_protocol(waiter=waiter)
  105. self.assertIsNone(ssl_proto._get_extra_info('socket'))
  106. default = object()
  107. self.assertIs(ssl_proto._get_extra_info('socket', default), default)
  108. self.connection_made(ssl_proto)
  109. self.assertIsNotNone(ssl_proto._get_extra_info('socket'))
  110. ssl_proto.connection_lost(None)
  111. self.assertIsNone(ssl_proto._get_extra_info('socket'))
  112. def test_set_new_app_protocol(self):
  113. waiter = self.loop.create_future()
  114. ssl_proto = self.ssl_protocol(waiter=waiter)
  115. new_app_proto = asyncio.Protocol()
  116. ssl_proto._app_transport.set_protocol(new_app_proto)
  117. self.assertIs(ssl_proto._app_transport.get_protocol(), new_app_proto)
  118. self.assertIs(ssl_proto._app_protocol, new_app_proto)
  119. def test_data_received_after_closing(self):
  120. ssl_proto = self.ssl_protocol()
  121. self.connection_made(ssl_proto)
  122. transp = ssl_proto._app_transport
  123. transp.close()
  124. # should not raise
  125. self.assertIsNone(ssl_proto.data_received(b'data'))
  126. def test_write_after_closing(self):
  127. ssl_proto = self.ssl_protocol()
  128. self.connection_made(ssl_proto)
  129. transp = ssl_proto._app_transport
  130. transp.close()
  131. # should not raise
  132. self.assertIsNone(transp.write(b'data'))
  133. ##############################################################################
  134. # Start TLS Tests
  135. ##############################################################################
  136. class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
  137. PAYLOAD_SIZE = 1024 * 100
  138. TIMEOUT = support.LONG_TIMEOUT
  139. def new_loop(self):
  140. raise NotImplementedError
  141. def test_buf_feed_data(self):
  142. class Proto(asyncio.BufferedProtocol):
  143. def __init__(self, bufsize, usemv):
  144. self.buf = bytearray(bufsize)
  145. self.mv = memoryview(self.buf)
  146. self.data = b''
  147. self.usemv = usemv
  148. def get_buffer(self, sizehint):
  149. if self.usemv:
  150. return self.mv
  151. else:
  152. return self.buf
  153. def buffer_updated(self, nsize):
  154. if self.usemv:
  155. self.data += self.mv[:nsize]
  156. else:
  157. self.data += self.buf[:nsize]
  158. for usemv in [False, True]:
  159. proto = Proto(1, usemv)
  160. protocols._feed_data_to_buffered_proto(proto, b'12345')
  161. self.assertEqual(proto.data, b'12345')
  162. proto = Proto(2, usemv)
  163. protocols._feed_data_to_buffered_proto(proto, b'12345')
  164. self.assertEqual(proto.data, b'12345')
  165. proto = Proto(2, usemv)
  166. protocols._feed_data_to_buffered_proto(proto, b'1234')
  167. self.assertEqual(proto.data, b'1234')
  168. proto = Proto(4, usemv)
  169. protocols._feed_data_to_buffered_proto(proto, b'1234')
  170. self.assertEqual(proto.data, b'1234')
  171. proto = Proto(100, usemv)
  172. protocols._feed_data_to_buffered_proto(proto, b'12345')
  173. self.assertEqual(proto.data, b'12345')
  174. proto = Proto(0, usemv)
  175. with self.assertRaisesRegex(RuntimeError, 'empty buffer'):
  176. protocols._feed_data_to_buffered_proto(proto, b'12345')
  177. def test_start_tls_client_reg_proto_1(self):
  178. HELLO_MSG = b'1' * self.PAYLOAD_SIZE
  179. server_context = test_utils.simple_server_sslcontext()
  180. client_context = test_utils.simple_client_sslcontext()
  181. def serve(sock):
  182. sock.settimeout(self.TIMEOUT)
  183. data = sock.recv_all(len(HELLO_MSG))
  184. self.assertEqual(len(data), len(HELLO_MSG))
  185. sock.start_tls(server_context, server_side=True)
  186. sock.sendall(b'O')
  187. data = sock.recv_all(len(HELLO_MSG))
  188. self.assertEqual(len(data), len(HELLO_MSG))
  189. sock.shutdown(socket.SHUT_RDWR)
  190. sock.close()
  191. class ClientProto(asyncio.Protocol):
  192. def __init__(self, on_data, on_eof):
  193. self.on_data = on_data
  194. self.on_eof = on_eof
  195. self.con_made_cnt = 0
  196. def connection_made(proto, tr):
  197. proto.con_made_cnt += 1
  198. # Ensure connection_made gets called only once.
  199. self.assertEqual(proto.con_made_cnt, 1)
  200. def data_received(self, data):
  201. self.on_data.set_result(data)
  202. def eof_received(self):
  203. self.on_eof.set_result(True)
  204. async def client(addr):
  205. await asyncio.sleep(0.5)
  206. on_data = self.loop.create_future()
  207. on_eof = self.loop.create_future()
  208. tr, proto = await self.loop.create_connection(
  209. lambda: ClientProto(on_data, on_eof), *addr)
  210. tr.write(HELLO_MSG)
  211. new_tr = await self.loop.start_tls(tr, proto, client_context)
  212. self.assertEqual(await on_data, b'O')
  213. new_tr.write(HELLO_MSG)
  214. await on_eof
  215. new_tr.close()
  216. with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
  217. self.loop.run_until_complete(
  218. asyncio.wait_for(client(srv.addr),
  219. timeout=support.SHORT_TIMEOUT))
  220. # No garbage is left if SSL is closed uncleanly
  221. client_context = weakref.ref(client_context)
  222. self.assertIsNone(client_context())
  223. def test_create_connection_memory_leak(self):
  224. HELLO_MSG = b'1' * self.PAYLOAD_SIZE
  225. server_context = test_utils.simple_server_sslcontext()
  226. client_context = test_utils.simple_client_sslcontext()
  227. def serve(sock):
  228. sock.settimeout(self.TIMEOUT)
  229. sock.start_tls(server_context, server_side=True)
  230. sock.sendall(b'O')
  231. data = sock.recv_all(len(HELLO_MSG))
  232. self.assertEqual(len(data), len(HELLO_MSG))
  233. sock.shutdown(socket.SHUT_RDWR)
  234. sock.close()
  235. class ClientProto(asyncio.Protocol):
  236. def __init__(self, on_data, on_eof):
  237. self.on_data = on_data
  238. self.on_eof = on_eof
  239. self.con_made_cnt = 0
  240. def connection_made(proto, tr):
  241. # XXX: We assume user stores the transport in protocol
  242. proto.tr = tr
  243. proto.con_made_cnt += 1
  244. # Ensure connection_made gets called only once.
  245. self.assertEqual(proto.con_made_cnt, 1)
  246. def data_received(self, data):
  247. self.on_data.set_result(data)
  248. def eof_received(self):
  249. self.on_eof.set_result(True)
  250. async def client(addr):
  251. await asyncio.sleep(0.5)
  252. on_data = self.loop.create_future()
  253. on_eof = self.loop.create_future()
  254. tr, proto = await self.loop.create_connection(
  255. lambda: ClientProto(on_data, on_eof), *addr,
  256. ssl=client_context)
  257. self.assertEqual(await on_data, b'O')
  258. tr.write(HELLO_MSG)
  259. await on_eof
  260. tr.close()
  261. with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
  262. self.loop.run_until_complete(
  263. asyncio.wait_for(client(srv.addr),
  264. timeout=support.SHORT_TIMEOUT))
  265. # No garbage is left for SSL client from loop.create_connection, even
  266. # if user stores the SSLTransport in corresponding protocol instance
  267. client_context = weakref.ref(client_context)
  268. self.assertIsNone(client_context())
  269. def test_start_tls_client_buf_proto_1(self):
  270. HELLO_MSG = b'1' * self.PAYLOAD_SIZE
  271. server_context = test_utils.simple_server_sslcontext()
  272. client_context = test_utils.simple_client_sslcontext()
  273. client_con_made_calls = 0
  274. def serve(sock):
  275. sock.settimeout(self.TIMEOUT)
  276. data = sock.recv_all(len(HELLO_MSG))
  277. self.assertEqual(len(data), len(HELLO_MSG))
  278. sock.start_tls(server_context, server_side=True)
  279. sock.sendall(b'O')
  280. data = sock.recv_all(len(HELLO_MSG))
  281. self.assertEqual(len(data), len(HELLO_MSG))
  282. sock.sendall(b'2')
  283. data = sock.recv_all(len(HELLO_MSG))
  284. self.assertEqual(len(data), len(HELLO_MSG))
  285. sock.shutdown(socket.SHUT_RDWR)
  286. sock.close()
  287. class ClientProtoFirst(asyncio.BufferedProtocol):
  288. def __init__(self, on_data):
  289. self.on_data = on_data
  290. self.buf = bytearray(1)
  291. def connection_made(self, tr):
  292. nonlocal client_con_made_calls
  293. client_con_made_calls += 1
  294. def get_buffer(self, sizehint):
  295. return self.buf
  296. def buffer_updated(self, nsize):
  297. assert nsize == 1
  298. self.on_data.set_result(bytes(self.buf[:nsize]))
  299. class ClientProtoSecond(asyncio.Protocol):
  300. def __init__(self, on_data, on_eof):
  301. self.on_data = on_data
  302. self.on_eof = on_eof
  303. self.con_made_cnt = 0
  304. def connection_made(self, tr):
  305. nonlocal client_con_made_calls
  306. client_con_made_calls += 1
  307. def data_received(self, data):
  308. self.on_data.set_result(data)
  309. def eof_received(self):
  310. self.on_eof.set_result(True)
  311. async def client(addr):
  312. await asyncio.sleep(0.5)
  313. on_data1 = self.loop.create_future()
  314. on_data2 = self.loop.create_future()
  315. on_eof = self.loop.create_future()
  316. tr, proto = await self.loop.create_connection(
  317. lambda: ClientProtoFirst(on_data1), *addr)
  318. tr.write(HELLO_MSG)
  319. new_tr = await self.loop.start_tls(tr, proto, client_context)
  320. self.assertEqual(await on_data1, b'O')
  321. new_tr.write(HELLO_MSG)
  322. new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof))
  323. self.assertEqual(await on_data2, b'2')
  324. new_tr.write(HELLO_MSG)
  325. await on_eof
  326. new_tr.close()
  327. # connection_made() should be called only once -- when
  328. # we establish connection for the first time. Start TLS
  329. # doesn't call connection_made() on application protocols.
  330. self.assertEqual(client_con_made_calls, 1)
  331. with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
  332. self.loop.run_until_complete(
  333. asyncio.wait_for(client(srv.addr),
  334. timeout=self.TIMEOUT))
  335. def test_start_tls_slow_client_cancel(self):
  336. HELLO_MSG = b'1' * self.PAYLOAD_SIZE
  337. client_context = test_utils.simple_client_sslcontext()
  338. server_waits_on_handshake = self.loop.create_future()
  339. def serve(sock):
  340. sock.settimeout(self.TIMEOUT)
  341. data = sock.recv_all(len(HELLO_MSG))
  342. self.assertEqual(len(data), len(HELLO_MSG))
  343. try:
  344. self.loop.call_soon_threadsafe(
  345. server_waits_on_handshake.set_result, None)
  346. data = sock.recv_all(1024 * 1024)
  347. except ConnectionAbortedError:
  348. pass
  349. finally:
  350. sock.close()
  351. class ClientProto(asyncio.Protocol):
  352. def __init__(self, on_data, on_eof):
  353. self.on_data = on_data
  354. self.on_eof = on_eof
  355. self.con_made_cnt = 0
  356. def connection_made(proto, tr):
  357. proto.con_made_cnt += 1
  358. # Ensure connection_made gets called only once.
  359. self.assertEqual(proto.con_made_cnt, 1)
  360. def data_received(self, data):
  361. self.on_data.set_result(data)
  362. def eof_received(self):
  363. self.on_eof.set_result(True)
  364. async def client(addr):
  365. await asyncio.sleep(0.5)
  366. on_data = self.loop.create_future()
  367. on_eof = self.loop.create_future()
  368. tr, proto = await self.loop.create_connection(
  369. lambda: ClientProto(on_data, on_eof), *addr)
  370. tr.write(HELLO_MSG)
  371. await server_waits_on_handshake
  372. with self.assertRaises(asyncio.TimeoutError):
  373. await asyncio.wait_for(
  374. self.loop.start_tls(tr, proto, client_context),
  375. 0.5)
  376. with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
  377. self.loop.run_until_complete(
  378. asyncio.wait_for(client(srv.addr),
  379. timeout=support.SHORT_TIMEOUT))
  380. def test_start_tls_server_1(self):
  381. HELLO_MSG = b'1' * self.PAYLOAD_SIZE
  382. ANSWER = b'answer'
  383. server_context = test_utils.simple_server_sslcontext()
  384. client_context = test_utils.simple_client_sslcontext()
  385. answer = None
  386. def client(sock, addr):
  387. nonlocal answer
  388. sock.settimeout(self.TIMEOUT)
  389. sock.connect(addr)
  390. data = sock.recv_all(len(HELLO_MSG))
  391. self.assertEqual(len(data), len(HELLO_MSG))
  392. sock.start_tls(client_context)
  393. sock.sendall(HELLO_MSG)
  394. answer = sock.recv_all(len(ANSWER))
  395. sock.close()
  396. class ServerProto(asyncio.Protocol):
  397. def __init__(self, on_con, on_con_lost, on_got_hello):
  398. self.on_con = on_con
  399. self.on_con_lost = on_con_lost
  400. self.on_got_hello = on_got_hello
  401. self.data = b''
  402. self.transport = None
  403. def connection_made(self, tr):
  404. self.transport = tr
  405. self.on_con.set_result(tr)
  406. def replace_transport(self, tr):
  407. self.transport = tr
  408. def data_received(self, data):
  409. self.data += data
  410. if len(self.data) >= len(HELLO_MSG):
  411. self.on_got_hello.set_result(None)
  412. def connection_lost(self, exc):
  413. self.transport = None
  414. if exc is None:
  415. self.on_con_lost.set_result(None)
  416. else:
  417. self.on_con_lost.set_exception(exc)
  418. async def main(proto, on_con, on_con_lost, on_got_hello):
  419. tr = await on_con
  420. tr.write(HELLO_MSG)
  421. self.assertEqual(proto.data, b'')
  422. new_tr = await self.loop.start_tls(
  423. tr, proto, server_context,
  424. server_side=True,
  425. ssl_handshake_timeout=self.TIMEOUT)
  426. proto.replace_transport(new_tr)
  427. await on_got_hello
  428. new_tr.write(ANSWER)
  429. await on_con_lost
  430. self.assertEqual(proto.data, HELLO_MSG)
  431. new_tr.close()
  432. async def run_main():
  433. on_con = self.loop.create_future()
  434. on_con_lost = self.loop.create_future()
  435. on_got_hello = self.loop.create_future()
  436. proto = ServerProto(on_con, on_con_lost, on_got_hello)
  437. server = await self.loop.create_server(
  438. lambda: proto, '127.0.0.1', 0)
  439. addr = server.sockets[0].getsockname()
  440. with self.tcp_client(lambda sock: client(sock, addr),
  441. timeout=self.TIMEOUT):
  442. await asyncio.wait_for(
  443. main(proto, on_con, on_con_lost, on_got_hello),
  444. timeout=self.TIMEOUT)
  445. server.close()
  446. await server.wait_closed()
  447. self.assertEqual(answer, ANSWER)
  448. self.loop.run_until_complete(run_main())
  449. def test_start_tls_wrong_args(self):
  450. async def main():
  451. with self.assertRaisesRegex(TypeError, 'SSLContext, got'):
  452. await self.loop.start_tls(None, None, None)
  453. sslctx = test_utils.simple_server_sslcontext()
  454. with self.assertRaisesRegex(TypeError, 'is not supported'):
  455. await self.loop.start_tls(None, None, sslctx)
  456. self.loop.run_until_complete(main())
  457. def test_handshake_timeout(self):
  458. # bpo-29970: Check that a connection is aborted if handshake is not
  459. # completed in timeout period, instead of remaining open indefinitely
  460. client_sslctx = test_utils.simple_client_sslcontext()
  461. messages = []
  462. self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
  463. server_side_aborted = False
  464. def server(sock):
  465. nonlocal server_side_aborted
  466. try:
  467. sock.recv_all(1024 * 1024)
  468. except ConnectionAbortedError:
  469. server_side_aborted = True
  470. finally:
  471. sock.close()
  472. async def client(addr):
  473. await asyncio.wait_for(
  474. self.loop.create_connection(
  475. asyncio.Protocol,
  476. *addr,
  477. ssl=client_sslctx,
  478. server_hostname='',
  479. ssl_handshake_timeout=support.SHORT_TIMEOUT),
  480. 0.5)
  481. with self.tcp_server(server,
  482. max_clients=1,
  483. backlog=1) as srv:
  484. with self.assertRaises(asyncio.TimeoutError):
  485. self.loop.run_until_complete(client(srv.addr))
  486. self.assertTrue(server_side_aborted)
  487. # Python issue #23197: cancelling a handshake must not raise an
  488. # exception or log an error, even if the handshake failed
  489. self.assertEqual(messages, [])
  490. # The 10s handshake timeout should be cancelled to free related
  491. # objects without really waiting for 10s
  492. client_sslctx = weakref.ref(client_sslctx)
  493. self.assertIsNone(client_sslctx())
  494. def test_create_connection_ssl_slow_handshake(self):
  495. client_sslctx = test_utils.simple_client_sslcontext()
  496. messages = []
  497. self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
  498. def server(sock):
  499. try:
  500. sock.recv_all(1024 * 1024)
  501. except ConnectionAbortedError:
  502. pass
  503. finally:
  504. sock.close()
  505. async def client(addr):
  506. with self.assertWarns(DeprecationWarning):
  507. reader, writer = await asyncio.open_connection(
  508. *addr,
  509. ssl=client_sslctx,
  510. server_hostname='',
  511. loop=self.loop,
  512. ssl_handshake_timeout=1.0)
  513. with self.tcp_server(server,
  514. max_clients=1,
  515. backlog=1) as srv:
  516. with self.assertRaisesRegex(
  517. ConnectionAbortedError,
  518. r'SSL handshake.*is taking longer'):
  519. self.loop.run_until_complete(client(srv.addr))
  520. self.assertEqual(messages, [])
  521. def test_create_connection_ssl_failed_certificate(self):
  522. self.loop.set_exception_handler(lambda loop, ctx: None)
  523. sslctx = test_utils.simple_server_sslcontext()
  524. client_sslctx = test_utils.simple_client_sslcontext(
  525. disable_verify=False)
  526. def server(sock):
  527. try:
  528. sock.start_tls(
  529. sslctx,
  530. server_side=True)
  531. except ssl.SSLError:
  532. pass
  533. except OSError:
  534. pass
  535. finally:
  536. sock.close()
  537. async def client(addr):
  538. with self.assertWarns(DeprecationWarning):
  539. reader, writer = await asyncio.open_connection(
  540. *addr,
  541. ssl=client_sslctx,
  542. server_hostname='',
  543. loop=self.loop,
  544. ssl_handshake_timeout=support.LOOPBACK_TIMEOUT)
  545. with self.tcp_server(server,
  546. max_clients=1,
  547. backlog=1) as srv:
  548. with self.assertRaises(ssl.SSLCertVerificationError):
  549. self.loop.run_until_complete(client(srv.addr))
  550. def test_start_tls_client_corrupted_ssl(self):
  551. self.loop.set_exception_handler(lambda loop, ctx: None)
  552. sslctx = test_utils.simple_server_sslcontext()
  553. client_sslctx = test_utils.simple_client_sslcontext()
  554. def server(sock):
  555. orig_sock = sock.dup()
  556. try:
  557. sock.start_tls(
  558. sslctx,
  559. server_side=True)
  560. sock.sendall(b'A\n')
  561. sock.recv_all(1)
  562. orig_sock.send(b'please corrupt the SSL connection')
  563. except ssl.SSLError:
  564. pass
  565. finally:
  566. orig_sock.close()
  567. sock.close()
  568. async def client(addr):
  569. with self.assertWarns(DeprecationWarning):
  570. reader, writer = await asyncio.open_connection(
  571. *addr,
  572. ssl=client_sslctx,
  573. server_hostname='',
  574. loop=self.loop)
  575. self.assertEqual(await reader.readline(), b'A\n')
  576. writer.write(b'B')
  577. with self.assertRaises(ssl.SSLError):
  578. await reader.readline()
  579. writer.close()
  580. return 'OK'
  581. with self.tcp_server(server,
  582. max_clients=1,
  583. backlog=1) as srv:
  584. res = self.loop.run_until_complete(client(srv.addr))
  585. self.assertEqual(res, 'OK')
  586. @unittest.skipIf(ssl is None, 'No ssl module')
  587. class SelectorStartTLSTests(BaseStartTLS, unittest.TestCase):
  588. def new_loop(self):
  589. return asyncio.SelectorEventLoop()
  590. @unittest.skipIf(ssl is None, 'No ssl module')
  591. @unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only')
  592. class ProactorStartTLSTests(BaseStartTLS, unittest.TestCase):
  593. def new_loop(self):
  594. return asyncio.ProactorEventLoop()
  595. if __name__ == '__main__':
  596. unittest.main()