# Copyright (C) 2010, 2011, 2016 Canonical Ltd
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
22
import errno
import socket
try:
    import socketserver
except ImportError:
    import SocketServer as socketserver
import threading

from breezy import (
    osutils,
    tests,
    )
from breezy.tests import test_server
from breezy.tests.scenarios import load_tests_apply_scenarios
34
# Module-level ``load_tests`` hook (unittest load_tests protocol).
# NOTE(review): presumably this expands the per-class ``scenarios`` lists
# declared below -- confirm against breezy.tests.scenarios.
load_tests = load_tests_apply_scenarios
37
def portable_socket_pair():
    """Return a pair of TCP sockets connected to each other.

    Unlike socket.socketpair, this should work on Windows.

    :return: A ``(server_sock, client_sock)`` tuple of connected sockets.
        The caller owns both sockets and is responsible for closing them.
    """
    listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # Port 0 lets the OS pick a free port on the loopback interface.
        listen_sock.bind(('127.0.0.1', 0))
        # The socket has to be listening before the client connects,
        # otherwise the connection is refused.
        listen_sock.listen(1)
        client_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            client_sock.connect(listen_sock.getsockname())
            server_sock, addr = listen_sock.accept()
        except Exception:
            # Don't leak the client socket if the handshake fails.
            client_sock.close()
            raise
    finally:
        # The listening socket is only needed to establish the pair.
        listen_sock.close()
    return server_sock, client_sock
52
class TCPClient(object):
    """Minimal blocking TCP client used to talk to the test servers.

    ``sock`` is None while disconnected and holds the connected socket
    otherwise.
    """

    def __init__(self):
        self.sock = None

    def connect(self, addr):
        """Connect to ``addr`` using osutils.connect_socket.

        :param addr: The address handed to ``osutils.connect_socket``
            (the tests pass a ``(host, port)`` tuple).
        :raises AssertionError: If this client is already connected.
        """
        if self.sock is not None:
            raise AssertionError('Already connected to %r'
                                 % (self.sock.getsockname(),))
        self.sock = osutils.connect_socket(addr)

    def disconnect(self):
        """Shut down and close the socket, if any.

        Calling this on a disconnected client is a no-op, so it is safe to
        register as a cleanup. Errors showing that the socket is already
        down are ignored; any other socket error propagates.
        """
        if self.sock is None:
            return
        # Forget the socket first so the client ends up disconnected even if
        # an unexpected error propagates.
        sock, self.sock = self.sock, None
        try:
            try:
                sock.shutdown(socket.SHUT_RDWR)
            finally:
                # Always release the descriptor, even when shutdown() fails
                # because the peer already went away.
                sock.close()
        except socket.error as e:
            if e.errno not in (errno.EBADF, errno.ENOTCONN,
                               errno.ECONNRESET):
                raise
            # Right, the socket is already down

    def write(self, s):
        """Send all of the bytes ``s``; returns None like ``sendall``."""
        return self.sock.sendall(s)

    def read(self, bufsize=4096):
        """Receive up to ``bufsize`` bytes; b'' means the peer closed."""
        return self.sock.recv(bufsize)
83
class TCPConnectionHandler(socketserver.BaseRequestHandler):
    """Request handler answering ``ping\\n`` with ``pong\\n``.

    The connection is served until the client closes it (an empty read).
    Anything other than a single ``ping\\n`` line raises ValueError.
    """

    def handle(self):
        """Serve requests on this connection until the client is done."""
        self.done = False
        self.handle_connection()
        while not self.done:
            self.handle_connection()

    def readline(self):
        """Read one request from the socket.

        :return: The bytes received: either b'' (connection closed) or a
            single newline-terminated line.
        :raises ValueError: If the data received is not exactly one line.
        """
        # TODO: We should be buffering any extra data sent, etc. However, in
        #       practice, we don't send extra content, so we haven't bothered
        #       to implement it yet.
        req = self.request.recv(4096)
        # An empty string is allowed, to indicate the end of the connection
        if not req or (req.endswith(b'\n') and req.count(b'\n') == 1):
            return req
        raise ValueError('[%r] not a simple line' % (req,))

    def handle_connection(self):
        """Process a single request line.

        :raises ValueError: If the request is not ``ping\\n``.
        """
        req = self.readline()
        if not req:
            # The client closed the connection, stop serving.
            self.done = True
        elif req == b'ping\n':
            self.request.sendall(b'pong\n')
        else:
            raise ValueError('[%s] not understood' % req)
111
class TestTCPServerInAThread(tests.TestCase):
    """Tests for test_server.TestingTCPServerInAThread.

    ``self.server_class`` is read by get_server(); it comes from the
    ``scenarios`` entries below unless a test passes its own class.
    """

    # One scenario per server implementation exported by test_server.
    scenarios = [
        (name, {'server_class': getattr(test_server, name)})
        for name in
        ('TestingTCPServer', 'TestingThreadingTCPServer')]

    def get_server(self, server_class=None, connection_handler_class=None):
        """Start a server in a thread and schedule its shutdown.

        :param server_class: Overrides ``self.server_class`` when given.
        :param connection_handler_class: Request handler class, defaults to
            TCPConnectionHandler.
        :return: The started TestingTCPServerInAThread.
        """
        if server_class is not None:
            self.server_class = server_class
        if connection_handler_class is None:
            connection_handler_class = TCPConnectionHandler
        server = test_server.TestingTCPServerInAThread(
            ('localhost', 0), self.server_class, connection_handler_class)
        server.start_server()
        self.addCleanup(server.stop_server)
        return server

    def get_client(self):
        """Return a new TCPClient that is disconnected at cleanup time."""
        client = TCPClient()
        self.addCleanup(client.disconnect)
        return client

    def get_server_connection(self, server, conn_rank):
        """Return the ``conn_rank``-th client entry recorded by the server."""
        return server.server.clients[conn_rank]

    def assertClientAddr(self, client, server, conn_rank):
        """Check the server recorded the client's local address."""
        conn = self.get_server_connection(server, conn_rank)
        self.assertEqual(client.sock.getsockname(), conn[1])

    def test_start_stop(self):
        server = self.get_server()
        client = self.get_client()
        server.stop_server()
        # since the server doesn't accept connections anymore attempting to
        # connect should fail
        client = self.get_client()
        self.assertRaises(socket.error,
                          client.connect, (server.host, server.port))

    def test_client_talks_server_respond(self):
        server = self.get_server()
        client = self.get_client()
        client.connect((server.host, server.port))
        self.assertIs(None, client.write(b'ping\n'))
        resp = client.read()
        self.assertClientAddr(client, server, 0)
        self.assertEqual(b'pong\n', resp)

    def test_server_fails_to_start(self):
        class CantStart(Exception):
            pass

        class CantStartServer(test_server.TestingTCPServer):

            def server_bind(self):
                raise CantStart()

        # The exception is raised in the main thread
        self.assertRaises(CantStart,
                          self.get_server, server_class=CantStartServer)

    def test_server_fails_while_serving_or_stopping(self):
        class CantConnect(Exception):
            pass

        class FailingConnectionHandler(TCPConnectionHandler):

            def handle(self):
                raise CantConnect()

        server = self.get_server(
            connection_handler_class=FailingConnectionHandler)
        # The server won't fail until a client connect
        client = self.get_client()
        client.connect((server.host, server.port))
        # We make sure the server wants to handle a request, but the request is
        # guaranteed to fail. However, the server should make sure that the
        # connection gets closed, and stop_server should then raise the
        # original exception.
        client.write(b'ping\n')
        try:
            self.assertEqual(b'', client.read())
        except socket.error as e:
            # On Windows, failing during 'handle' means we get
            # 'forced-close-of-connection'. Possibly because we haven't
            # processed the write request before we close the socket.
            WSAECONNRESET = 10054
            # Only a connection reset is expected here; anything else is a
            # genuine failure and must not be silently swallowed.
            if e.errno not in (WSAECONNRESET, errno.ECONNRESET):
                raise
        # Now the server has raised the exception in its own thread
        self.assertRaises(CantConnect, server.stop_server)

    def test_server_crash_while_responding(self):
        # We want to ensure the exception has been caught
        caught = threading.Event()
        # The thread that will serve the client, this needs to be an attribute
        # so the handler below can modify it when it's executed (it's
        # instantiated when the request is processed)
        self.connection_thread = None

        class FailToRespond(Exception):
            pass

        class FailingDuringResponseHandler(TCPConnectionHandler):

            # We use 'request' instead of 'self' below because the test matters
            # more and we need a container to properly set connection_thread.
            def handle_connection(request):
                request.readline()
                # Capture the thread and make it use 'caught' so we can wait on
                # the event that will be set when the exception is caught. We
                # also capture the thread to know where to look.
                self.connection_thread = threading.current_thread()
                self.connection_thread.set_sync_event(caught)
                raise FailToRespond()

        server = self.get_server(
            connection_handler_class=FailingDuringResponseHandler)
        client = self.get_client()
        client.connect((server.host, server.port))
        client.write(b'ping\n')
        # Wait for the exception to be caught
        caught.wait()
        self.assertEqual(b'', client.read())  # connection closed
        # Check that the connection thread did catch the exception,
        # http://pad.lv/869366 was wrongly checking the server thread which
        # works for TestingTCPServer where the connection is handled in the
        # same thread than the server one but was racy for
        # TestingThreadingTCPServer. Since the connection thread detaches
        # itself before handling the request, we are guaranteed that the
        # exception won't leak into the server thread anymore.
        self.assertRaises(FailToRespond,
                          self.connection_thread.pending_exception)

    def test_exception_swallowed_while_serving(self):
        # We need to ensure the exception has been caught
        caught = threading.Event()
        # The thread that will serve the client, this needs to be an attribute
        # so the handler below can access it when it's executed (it's
        # instantiated when the request is processed)
        self.connection_thread = None

        class CantServe(Exception):
            pass

        class FailingWhileServingConnectionHandler(TCPConnectionHandler):

            # We use 'request' instead of 'self' below because the test matters
            # more and we need a container to properly set connection_thread.
            def handle(request):
                # Capture the thread and make it use 'caught' so we can wait on
                # the event that will be set when the exception is caught. We
                # also capture the thread to know where to look.
                self.connection_thread = threading.current_thread()
                self.connection_thread.set_sync_event(caught)
                raise CantServe()

        server = self.get_server(
            connection_handler_class=FailingWhileServingConnectionHandler)
        self.assertEqual(True, server.server.serving)
        # Install the exception swallower
        server.set_ignored_exceptions(CantServe)
        client = self.get_client()
        # Connect to the server so the exception is raised there
        client.connect((server.host, server.port))
        # Wait for the exception to be caught
        caught.wait()
        self.assertEqual(b'', client.read())  # connection closed
        # The connection wasn't served properly but the exception should have
        # been swallowed (see test_server_crash_while_responding remark about
        # http://pad.lv/869366 explaining why we can't check the server thread
        # here). More precisely, the exception *has* been caught and captured
        # but it is cleared when joining the thread (or trying to acquire the
        # exception) and as such won't propagate to the server thread.
        self.assertIs(None, self.connection_thread.pending_exception())
        self.assertIs(None, server.pending_exception())

    def test_handle_request_closes_if_it_doesnt_process(self):
        server = self.get_server()
        client = self.get_client()
        server.server.serving = False
        try:
            client.connect((server.host, server.port))
            self.assertEqual(b'', client.read())
        except socket.error as e:
            # A reset is an acceptable way for the server to drop us.
            if e.errno != errno.ECONNRESET:
                raise
303
class TestTestingSmartServer(tests.TestCase):
    """Tests for test_server.TestingSmartServer."""

    def test_sets_client_timeout(self):
        """Both the server and the handlers it makes use the test timeout."""
        server = test_server.TestingSmartServer(
            ('localhost', 0), None, None,
            root_client_path='/no-such-client/path')
        self.assertEqual(test_server._DEFAULT_TESTING_CLIENT_TIMEOUT,
                         server._client_timeout)
        sock = socket.socket()
        # The socket is only a placeholder for the handler; make sure it is
        # released when the test ends.
        self.addCleanup(sock.close)
        h = server._make_handler(sock)
        self.assertEqual(test_server._DEFAULT_TESTING_CLIENT_TIMEOUT,
                         h._client_timeout)
317
class FakeServer(object):
    """Minimal implementation to pass to TestingSmartConnectionHandler"""

    # Only these two class attributes are provided.
    # NOTE(review): presumably they are what the handler reads from its
    # server -- confirm against test_server.TestingSmartConnectionHandler.
    backing_transport = None
    root_client_path = '/'
323
class TestTestingSmartConnectionHandler(tests.TestCase):
    """Tests for test_server.TestingSmartConnectionHandler."""

    def _socket_pair(self):
        """Return a connected (server, client) pair closed at cleanup."""
        server_sock, client_sock = portable_socket_pair()
        self.addCleanup(server_sock.close)
        self.addCleanup(client_sock.close)
        return server_sock, client_sock

    def test_connection_timeout_suppressed(self):
        self.overrideAttr(test_server, '_DEFAULT_TESTING_CLIENT_TIMEOUT', 0.01)
        s = FakeServer()
        server_sock, client_sock = self._socket_pair()
        # This should timeout quickly, but not generate an exception.
        test_server.TestingSmartConnectionHandler(
            server_sock, server_sock.getpeername(), s)

    def test_connection_shutdown_while_serving_no_error(self):
        s = FakeServer()
        server_sock, client_sock = self._socket_pair()

        class ShutdownConnectionHandler(
                test_server.TestingSmartConnectionHandler):

            def _build_protocol(self):
                # Flag the handler as finished before building the protocol.
                self.finished = True
                return super(ShutdownConnectionHandler, self)._build_protocol()
        # This should trigger shutdown after the entering _build_protocol, and
        # we should exit cleanly, without raising an exception.
        ShutdownConnectionHandler(server_sock, server_sock.getpeername(), s)