|
|
|
@ -2,6 +2,8 @@ |
|
|
|
|
|
|
|
import errno |
|
|
|
import socket |
|
|
|
import threading |
|
|
|
import time |
|
|
|
import unittest |
|
|
|
from unittest import mock |
|
|
|
try: |
|
|
|
@ -337,18 +339,6 @@ class BaseSelectorEventLoopTests(test_utils.TestCase): |
|
|
|
(10, self.loop._sock_sendall, f, True, sock, b'data'), |
|
|
|
self.loop.add_writer.call_args[0]) |
|
|
|
|
|
|
|
def test_sock_connect(self): |
|
|
|
sock = test_utils.mock_nonblocking_socket() |
|
|
|
self.loop._sock_connect = mock.Mock() |
|
|
|
|
|
|
|
f = self.loop.sock_connect(sock, ('127.0.0.1', 8080)) |
|
|
|
self.assertIsInstance(f, asyncio.Future) |
|
|
|
self.loop._run_once() |
|
|
|
future_in, sock_in, address_in = self.loop._sock_connect.call_args[0] |
|
|
|
self.assertEqual(future_in, f) |
|
|
|
self.assertEqual(sock_in, sock) |
|
|
|
self.assertEqual(address_in, ('127.0.0.1', 8080)) |
|
|
|
|
|
|
|
def test_sock_connect_timeout(self): |
|
|
|
# asyncio issue #205: sock_connect() must unregister the socket on |
|
|
|
# timeout error |
|
|
|
@ -360,29 +350,34 @@ class BaseSelectorEventLoopTests(test_utils.TestCase): |
|
|
|
sock.connect.side_effect = BlockingIOError |
|
|
|
|
|
|
|
# first call to sock_connect() registers the socket |
|
|
|
fut = self.loop.sock_connect(sock, ('127.0.0.1', 80)) |
|
|
|
fut = self.loop.create_task( |
|
|
|
self.loop.sock_connect(sock, ('127.0.0.1', 80))) |
|
|
|
self.loop._run_once() |
|
|
|
self.assertTrue(sock.connect.called) |
|
|
|
self.assertTrue(self.loop.add_writer.called) |
|
|
|
self.assertEqual(len(fut._callbacks), 1) |
|
|
|
|
|
|
|
# on timeout, the socket must be unregistered |
|
|
|
sock.connect.reset_mock() |
|
|
|
fut.set_exception(asyncio.TimeoutError) |
|
|
|
with self.assertRaises(asyncio.TimeoutError): |
|
|
|
fut.cancel() |
|
|
|
with self.assertRaises(asyncio.CancelledError): |
|
|
|
self.loop.run_until_complete(fut) |
|
|
|
self.assertTrue(self.loop.remove_writer.called) |
|
|
|
|
|
|
|
def test_sock_connect_resolve_using_socket_params(self): |
|
|
|
@mock.patch('socket.getaddrinfo') |
|
|
|
def test_sock_connect_resolve_using_socket_params(self, m_gai): |
|
|
|
addr = ('need-resolution.com', 8080) |
|
|
|
sock = test_utils.mock_nonblocking_socket() |
|
|
|
self.loop.getaddrinfo = mock.Mock() |
|
|
|
self.loop.sock_connect(sock, addr) |
|
|
|
while not self.loop.getaddrinfo.called: |
|
|
|
m_gai.side_effect = (None, None, None, None, ('127.0.0.1', 0)) |
|
|
|
m_gai._is_coroutine = False |
|
|
|
con = self.loop.create_task(self.loop.sock_connect(sock, addr)) |
|
|
|
while not m_gai.called: |
|
|
|
self.loop._run_once() |
|
|
|
self.loop.getaddrinfo.assert_called_with( |
|
|
|
*addr, type=sock.type, family=sock.family, proto=sock.proto, |
|
|
|
flags=0) |
|
|
|
m_gai.assert_called_with( |
|
|
|
addr[0], addr[1], sock.family, sock.type, sock.proto, 0) |
|
|
|
|
|
|
|
con.cancel() |
|
|
|
with self.assertRaises(asyncio.CancelledError): |
|
|
|
self.loop.run_until_complete(con) |
|
|
|
|
|
|
|
def test__sock_connect(self): |
|
|
|
f = asyncio.Future(loop=self.loop) |
|
|
|
@ -1792,5 +1787,88 @@ class SelectorDatagramTransportTests(test_utils.TestCase): |
|
|
|
exc_info=(ConnectionRefusedError, MOCK_ANY, MOCK_ANY)) |
|
|
|
|
|
|
|
|
|
|
|
class SelectorLoopFunctionalTests(unittest.TestCase): |
|
|
|
|
|
|
|
def setUp(self): |
|
|
|
self.loop = asyncio.new_event_loop() |
|
|
|
asyncio.set_event_loop(None) |
|
|
|
|
|
|
|
def tearDown(self): |
|
|
|
self.loop.close() |
|
|
|
|
|
|
|
@asyncio.coroutine |
|
|
|
def recv_all(self, sock, nbytes): |
|
|
|
buf = b'' |
|
|
|
while len(buf) < nbytes: |
|
|
|
buf += yield from self.loop.sock_recv(sock, nbytes - len(buf)) |
|
|
|
return buf |
|
|
|
|
|
|
|
def test_sock_connect_sock_write_race(self): |
|
|
|
TIMEOUT = 3.0 |
|
|
|
PAYLOAD = b'DATA' * 1024 * 1024 |
|
|
|
|
|
|
|
class Server(threading.Thread): |
|
|
|
def __init__(self, *args, srv_sock, **kwargs): |
|
|
|
super().__init__(*args, **kwargs) |
|
|
|
self.srv_sock = srv_sock |
|
|
|
|
|
|
|
def run(self): |
|
|
|
with self.srv_sock: |
|
|
|
srv_sock.listen(100) |
|
|
|
|
|
|
|
sock, addr = self.srv_sock.accept() |
|
|
|
sock.settimeout(TIMEOUT) |
|
|
|
|
|
|
|
with sock: |
|
|
|
sock.sendall(b'helo') |
|
|
|
|
|
|
|
buf = bytearray() |
|
|
|
while len(buf) < len(PAYLOAD): |
|
|
|
pack = sock.recv(1024 * 65) |
|
|
|
if not pack: |
|
|
|
break |
|
|
|
buf.extend(pack) |
|
|
|
|
|
|
|
@asyncio.coroutine |
|
|
|
def client(addr): |
|
|
|
sock = socket.socket() |
|
|
|
with sock: |
|
|
|
sock.setblocking(False) |
|
|
|
|
|
|
|
started = time.monotonic() |
|
|
|
while True: |
|
|
|
if time.monotonic() - started > TIMEOUT: |
|
|
|
self.fail('unable to connect to the socket') |
|
|
|
return |
|
|
|
try: |
|
|
|
yield from self.loop.sock_connect(sock, addr) |
|
|
|
except OSError: |
|
|
|
yield from asyncio.sleep(0.05, loop=self.loop) |
|
|
|
else: |
|
|
|
break |
|
|
|
|
|
|
|
# Give 'Server' thread a chance to accept and send b'helo' |
|
|
|
time.sleep(0.1) |
|
|
|
|
|
|
|
data = yield from self.recv_all(sock, 4) |
|
|
|
self.assertEqual(data, b'helo') |
|
|
|
yield from self.loop.sock_sendall(sock, PAYLOAD) |
|
|
|
|
|
|
|
srv_sock = socket.socket() |
|
|
|
srv_sock.settimeout(TIMEOUT) |
|
|
|
srv_sock.bind(('127.0.0.1', 0)) |
|
|
|
srv_addr = srv_sock.getsockname() |
|
|
|
|
|
|
|
srv = Server(srv_sock=srv_sock, daemon=True) |
|
|
|
srv.start() |
|
|
|
|
|
|
|
try: |
|
|
|
self.loop.run_until_complete( |
|
|
|
asyncio.wait_for(client(srv_addr), loop=self.loop, |
|
|
|
timeout=TIMEOUT)) |
|
|
|
finally: |
|
|
|
srv.join() |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
unittest.main() |