617
659
super(TestSmartServerStreamMedium, self).setUp()
618
660
self.overrideEnv('BZR_NO_SMART_VFS', None)
620
def portable_socket_pair(self):
    """Return a pair of TCP sockets connected to each other.

    Unlike socket.socketpair, this should work on Windows.

    A temporary listening socket is bound to an ephemeral port on
    127.0.0.1, a client socket connects to it, and the accepted connection
    is returned together with that client.

    :return: (server_sock, client_sock)
    """
    listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        listen_sock.bind(('127.0.0.1', 0))
        listen_sock.listen(1)
        client_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            client_sock.connect(listen_sock.getsockname())
            server_sock, addr = listen_sock.accept()
        except socket.error:
            # Don't leak the half-built client if the handshake fails.
            client_sock.close()
            raise
    finally:
        # The listener is only needed to establish this one connection.
        listen_sock.close()
    return server_sock, client_sock
662
def create_pipe_medium(self, to_server, from_server, transport,
                       timeout=4.0):
    """Create a new SmartServerPipeStreamMedium.

    All arguments are handed straight to the medium's constructor.

    :param to_server: first positional constructor argument.
    :param from_server: second positional constructor argument.
    :param transport: third positional constructor argument.
    :param timeout: forwarded as the ``timeout`` keyword.
    """
    # NOTE(review): the tail of the signature was not visible in this
    # extract; the 4.0 default mirrors create_socket_medium -- confirm.
    return medium.SmartServerPipeStreamMedium(to_server, from_server,
                                              transport, timeout=timeout)
668
def create_pipe_context(self, to_server_bytes, transport):
    """Create a pipe medium together with the buffer it responds into.

    This differs from create_pipe_medium, in that we initialize the
    request that is sent to the server, and return the StringIO object that
    will hold the response.

    :param to_server_bytes: content of the request stream given to the
        medium.
    :param transport: passed through to create_pipe_medium.
    :return: (medium, from_server) where from_server is the StringIO
        handed to the medium as its output.
    """
    to_server = StringIO(to_server_bytes)
    from_server = StringIO()
    m = self.create_pipe_medium(to_server, from_server, transport)
    return m, from_server
680
def create_socket_medium(self, server_sock, transport, timeout=4.0):
    """Initialize a new medium.SmartServerSocketStreamMedium.

    :param server_sock: first positional constructor argument.
    :param transport: second positional constructor argument.
    :param timeout: forwarded as the ``timeout`` keyword.
    """
    # NOTE(review): the continuation of this call was not visible in the
    # extract; forwarding timeout= is the only unused parameter -- confirm
    # no further arguments were passed.
    return medium.SmartServerSocketStreamMedium(server_sock, transport,
                                                timeout=timeout)
685
def create_socket_context(self, transport, timeout=4.0):
    """Create a new SmartServerSocketStreamMedium with default context.

    This will call portable_socket_pair and pass the server side to
    create_socket_medium along with transport.
    It then returns the server and the client_sock.

    :param transport: passed through to create_socket_medium.
    :param timeout: passed through to create_socket_medium.
    :return: (server, client_sock)
    """
    # NOTE(review): portable_socket_pair is called as a plain function here;
    # confirm it is available at module scope.
    server_sock, client_sock = portable_socket_pair()
    server = self.create_socket_medium(server_sock, transport,
                                       timeout=timeout)
    return server, client_sock
634
697
def test_smart_query_version(self):
635
698
"""Feed a canned query version to a server"""
636
699
# wire-to-wire, using the whole stack
637
to_server = StringIO('hello\n')
638
from_server = StringIO()
639
700
transport = local.LocalTransport(urlutils.local_path_to_url('/'))
640
server = medium.SmartServerPipeStreamMedium(
641
to_server, from_server, transport)
701
server, from_server = self.create_pipe_context('hello\n', transport)
642
702
smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
643
703
from_server.write)
644
704
server._serve_one_request(smart_protocol)
696
750
def test_socket_stream_with_bulk_data(self):
697
751
sample_request_bytes = 'command\n9\nbulk datadone\n'
698
server_sock, client_sock = self.portable_socket_pair()
699
server = medium.SmartServerSocketStreamMedium(
752
server, client_sock = self.create_socket_context(None)
701
753
sample_protocol = SampleRequest(expected_bytes=sample_request_bytes)
702
754
client_sock.sendall(sample_request_bytes)
703
755
server._serve_one_request(sample_protocol)
756
server._disconnect_client()
705
757
self.assertEqual('', client_sock.recv(1))
706
758
self.assertEqual(sample_request_bytes, sample_protocol.accepted_bytes)
707
759
self.assertFalse(server.finished)
709
761
def test_pipe_like_stream_shutdown_detection(self):
710
to_server = StringIO('')
711
from_server = StringIO()
712
server = medium.SmartServerPipeStreamMedium(to_server, from_server, None)
762
server, _ = self.create_pipe_context('', None)
713
763
server._serve_one_request(SampleRequest('x'))
714
764
self.assertTrue(server.finished)
716
766
def test_socket_stream_shutdown_detection(self):
717
server_sock, client_sock = self.portable_socket_pair()
767
server, client_sock = self.create_socket_context(None)
718
768
client_sock.close()
719
server = medium.SmartServerSocketStreamMedium(
721
769
server._serve_one_request(SampleRequest('x'))
722
770
self.assertTrue(server.finished)
858
897
self.assertTrue(server.finished)
860
899
def test_pipe_like_stream_keyboard_interrupt_handling(self):
861
to_server = StringIO('')
862
from_server = StringIO()
863
server = medium.SmartServerPipeStreamMedium(
864
to_server, from_server, None)
900
server, from_server = self.create_pipe_context('', None)
865
901
fake_protocol = ErrorRaisingProtocol(KeyboardInterrupt('boom'))
866
902
self.assertRaises(
867
903
KeyboardInterrupt, server._serve_one_request, fake_protocol)
868
904
self.assertEqual('', from_server.getvalue())
870
906
def test_socket_stream_keyboard_interrupt_handling(self):
871
server_sock, client_sock = self.portable_socket_pair()
872
server = medium.SmartServerSocketStreamMedium(
907
server, client_sock = self.create_socket_context(None)
874
908
fake_protocol = ErrorRaisingProtocol(KeyboardInterrupt('boom'))
875
909
self.assertRaises(
876
910
KeyboardInterrupt, server._serve_one_request, fake_protocol)
911
server._disconnect_client()
878
912
self.assertEqual('', client_sock.recv(1))
880
914
def build_protocol_pipe_like(self, bytes):
881
to_server = StringIO(bytes)
882
from_server = StringIO()
883
server = medium.SmartServerPipeStreamMedium(
884
to_server, from_server, None)
915
server, _ = self.create_pipe_context(bytes, None)
885
916
return server._build_protocol()
887
918
def build_protocol_socket(self, bytes):
888
server_sock, client_sock = self.portable_socket_pair()
889
server = medium.SmartServerSocketStreamMedium(
919
server, client_sock = self.create_socket_context(None)
891
920
client_sock.sendall(bytes)
892
921
client_sock.close()
893
922
return server._build_protocol()
933
962
server_protocol = self.build_protocol_socket('bzr request 2\n')
934
963
self.assertProtocolTwo(server_protocol)
965
def test__build_protocol_returns_if_stopping(self):
    """_build_protocol gives back None once a graceful stop is requested."""
    # After being told to stop, _build_protocol should notice and return
    # right away instead of waiting for bytes from the client.
    stopping_medium, _client_sock = self.create_socket_context(None)
    stopping_medium._stop_gracefully()
    built = stopping_medium._build_protocol()
    self.assertIs(None, built)
972
def test_socket_set_timeout(self):
    """A timeout given to create_socket_context shows up on the medium."""
    context = self.create_socket_context(None, timeout=1.23)
    socket_medium, _ = context
    self.assertEqual(1.23, socket_medium._client_timeout)
976
def test_pipe_set_timeout(self):
    """A timeout given to create_pipe_medium shows up on the medium."""
    # NOTE(review): the continuation of this call was not visible in the
    # extract; timeout=1.23 is inferred from the assertion below -- confirm.
    server = self.create_pipe_medium(None, None, None,
                                     timeout=1.23)
    self.assertEqual(1.23, server._client_timeout)
981
def test_socket_wait_for_bytes_with_timeout_with_data(self):
    """Waiting with data pending is falsy and leaves the data readable."""
    medium_under_test, peer_sock = self.create_socket_context(None)
    peer_sock.sendall('data\n')
    # The wait must neither block nor swallow any of the pending content.
    wait_result = medium_under_test._wait_for_bytes_with_timeout(0.1)
    self.assertFalse(wait_result)
    received = medium_under_test.read_bytes(5)
    self.assertEqual('data\n', received)
989
def test_socket_wait_for_bytes_with_timeout_no_data(self):
990
server, client_sock = self.create_socket_context(None)
991
# This should timeout quickly, reporting that there wasn't any data
992
self.assertRaises(errors.ConnectionTimeout,
993
server._wait_for_bytes_with_timeout, 0.01)
995
data = server.read_bytes(1)
996
self.assertEqual('', data)
998
def test_socket_wait_for_bytes_with_timeout_closed(self):
999
server, client_sock = self.create_socket_context(None)
1000
# With the socket closed, this should return right away.
1001
# It seems select.select() returns that you *can* read on the socket,
1002
# even though it closed. Presumably as a way to tell it is closed?
1003
# Testing shows that without sock.close() this times-out failing the
1004
# test, but with it, it returns False immediately.
1006
self.assertFalse(server._wait_for_bytes_with_timeout(10))
1007
data = server.read_bytes(1)
1008
self.assertEqual('', data)
1010
def test_socket_wait_for_bytes_with_shutdown(self):
1011
server, client_sock = self.create_socket_context(None)
1013
# Override the _timer functionality, so that time never increments,
1014
# this way, we can be sure we stopped because of the flag, and not
1015
# because of a timeout, etc.
1016
server._timer = lambda: t
1017
server._client_poll_timeout = 0.1
1018
server._stop_gracefully()
1019
server._wait_for_bytes_with_timeout(1.0)
1021
def test_socket_serve_timeout_closes_socket(self):
1022
server, client_sock = self.create_socket_context(None, timeout=0.1)
1023
# This should timeout quickly, and then close the connection so that
1024
# client_sock recv doesn't block.
1026
self.assertEqual('', client_sock.recv(1))
1028
def test_pipe_wait_for_bytes_with_timeout_with_data(self):
1029
# We intentionally use a real pipe here, so that we can 'select' on it.
1030
# You can't select() on a StringIO
1031
(r_server, w_client) = os.pipe()
1032
self.addCleanup(os.close, w_client)
1033
with os.fdopen(r_server, 'rb') as rf_server:
1034
server = self.create_pipe_medium(
1035
rf_server, None, None)
1036
os.write(w_client, 'data\n')
1037
# This should not block or consume any actual content
1038
server._wait_for_bytes_with_timeout(0.1)
1039
data = server.read_bytes(5)
1040
self.assertEqual('data\n', data)
1042
def test_pipe_wait_for_bytes_with_timeout_no_data(self):
1043
# We intentionally use a real pipe here, so that we can 'select' on it.
1044
# You can't select() on a StringIO
1045
(r_server, w_client) = os.pipe()
1046
# We can't add an os.close cleanup here, because we need to control
1047
# when the file handle gets closed ourselves.
1048
with os.fdopen(r_server, 'rb') as rf_server:
1049
server = self.create_pipe_medium(
1050
rf_server, None, None)
1051
if sys.platform == 'win32':
1052
# Windows cannot select() on a pipe, so we just always return
1053
server._wait_for_bytes_with_timeout(0.01)
1055
self.assertRaises(errors.ConnectionTimeout,
1056
server._wait_for_bytes_with_timeout, 0.01)
1058
data = server.read_bytes(5)
1059
self.assertEqual('', data)
1061
def test_pipe_wait_for_bytes_no_fileno(self):
    """Waiting on a pipe medium built by create_pipe_context just returns."""
    pipe_medium, _from_server = self.create_pipe_context('', None)
    # The underlying file can't be polled, so the wait should always come
    # straight back as though there were data to consume.
    pipe_medium._wait_for_bytes_with_timeout(0.01)
937
1068
class TestGetProtocolFactoryForBytes(tests.TestCase):
938
1069
"""_get_protocol_factory_for_bytes identifies the protocol factory a server
969
1100
class TestSmartTCPServer(tests.TestCase):
1102
def make_server(self):
1103
"""Create a SmartTCPServer that we can exercise.
1105
Note: we don't use SmartTCPServer_for_testing because the testing
1106
version overrides lots of functionality like 'serve', and we want to
1107
test the raw service.
1109
This will start the server in another thread, and wait for it to
1110
indicate it has finished starting up.
1112
:return: (server, server_thread)
1114
t = _mod_transport.get_transport_from_url('memory:///')
1115
server = _mod_server.SmartTCPServer(t, client_timeout=4.0)
1116
server._ACCEPT_TIMEOUT = 0.1
1117
# We don't use 'localhost' because that might be an IPv6 address.
1118
server.start_server('127.0.0.1', 0)
1119
server_thread = threading.Thread(target=server.serve,
1121
server_thread.start()
1122
# Ensure this gets called at some point
1123
self.addCleanup(server._stop_gracefully)
1124
server._started.wait()
1125
return server, server_thread
1127
def ensure_client_disconnected(self, client_sock):
    """Ensure that a socket is closed, discarding any socket error.

    :param client_sock: the socket to close.
    """
    try:
        client_sock.close()
    except socket.error:
        # Best effort only: this runs as a cleanup and must not fail the
        # test when the socket has already gone away.
        pass
1134
def connect_to_server(self, server):
    """Create a client socket that can talk to the server.

    :param server: object whose ``_server_socket`` is the listening socket
        to connect to.
    :return: the connected client socket. A cleanup calling
        ensure_client_disconnected on it is registered.
    :raises socket.error: if the connection cannot be established.
    """
    client_sock = socket.socket()
    try:
        server_info = server._server_socket.getsockname()
        client_sock.connect(server_info)
    except socket.error:
        # Nothing else will close this socket if we never hand it out.
        client_sock.close()
        raise
    self.addCleanup(self.ensure_client_disconnected, client_sock)
    # Callers talk to the server through this socket.
    return client_sock
1142
def connect_to_server_and_hangup(self, server):
    """Connect to the server, and then hang up.

    That way it doesn't sit waiting for 'accept()' to timeout.

    :param server: object exposing a ``_stopped`` event; also passed to
        connect_to_server.
    """
    # If the server has already signaled that the socket is closed, we
    # don't need to try to connect to it. Not being set, though, the server
    # might still close the socket while we try to connect to it. So we
    # still have to catch the exception.
    if server._stopped.isSet():
        return
    try:
        client_sock = self.connect_to_server(server)
        client_sock.close()
    except socket.error:
        # If the server has hung up already, that is fine.
        pass
1159
def say_hello(self, client_sock):
    """Send the 'hello' smart RPC, and expect the response.

    :param client_sock: connected socket to send the request on and read
        the 5-byte reply from.
    """
    # sendall rather than send: send may write only part of the request.
    client_sock.sendall('hello\n')
    self.assertEqual('ok\x012\n', client_sock.recv(5))
1164
def shutdown_server_cleanly(self, server, server_thread):
    """Ask *server* to stop gracefully and wait until it is fully done.

    :param server: the server to stop; its ``_stopped`` and
        ``_fully_stopped`` events are waited on, in that order.
    :param server_thread: thread running the server; joined last.
    """
    server._stop_gracefully()
    # Connect once and hang up so the server isn't left sitting in accept().
    self.connect_to_server_and_hangup(server)
    for event_name in ('_stopped', '_fully_stopped'):
        getattr(server, event_name).wait()
    server_thread.join()
971
1171
def test_get_error_unexpected(self):
972
1172
"""Error reported by server with no specific representation"""
973
1173
self.overrideEnv('BZR_NO_SMART_VFS', None)
991
1191
t.get, 'something')
992
1192
self.assertContainsRe(str(err), 'some random exception')
1194
def test_propagates_timeout(self):
    """A handler made by the server carries the server's client_timeout."""
    tcp_server = _mod_server.SmartTCPServer(None, client_timeout=1.23)
    # Only the server side is handed over; the client end is just held.
    sock_for_handler, client_sock = portable_socket_pair()
    made_handler = tcp_server._make_handler(sock_for_handler)
    self.assertEqual(1.23, made_handler._client_timeout)
1200
def test_serve_conn_tracks_connections(self):
    """serve_conn registers a connection; polling drops it once closed."""
    server = _mod_server.SmartTCPServer(None, client_timeout=4.0)
    server_sock, client_sock = portable_socket_pair()
    server.serve_conn(server_sock, '-%s' % (self.id(),))
    self.assertEqual(1, len(server._active_connections))
    # We still want to talk on the connection. Polling should indicate it
    # is still active.
    server._poll_active_connections()
    self.assertEqual(1, len(server._active_connections))
    # Closing the socket will end the active thread, and polling will
    # notice and remove it from the active set.
    # NOTE(review): the close call was not visible in this extract; it is
    # restored from the comment above -- confirm against the original.
    client_sock.close()
    server._poll_active_connections(0.1)
    self.assertEqual(0, len(server._active_connections))
1215
def test_serve_closes_out_finished_connections(self):
1216
server, server_thread = self.make_server()
1217
# The server is started, connect to it.
1218
client_sock = self.connect_to_server(server)
1219
# We send and receive on the connection, so that we know the
1220
# server-side has seen the connect, and started handling the
1222
self.say_hello(client_sock)
1223
self.assertEqual(1, len(server._active_connections))
1224
# Grab a handle to the thread that is processing our request
1225
_, server_side_thread = server._active_connections[0]
1226
# Close the connection, ask the server to stop, and wait for the
1227
# server to stop, as well as the thread that was servicing the
1230
# Wait for the server-side request thread to notice we are closed.
1231
server_side_thread.join()
1232
# Stop the server, it should notice the connection has finished.
1233
self.shutdown_server_cleanly(server, server_thread)
1234
# The server should have noticed that all clients are gone before
1236
self.assertEqual(0, len(server._active_connections))
1238
def test_serve_reaps_finished_connections(self):
1239
server, server_thread = self.make_server()
1240
client_sock1 = self.connect_to_server(server)
1241
# We send and receive on the connection, so that we know the
1242
# server-side has seen the connect, and started handling the
1244
self.say_hello(client_sock1)
1245
server_handler1, server_side_thread1 = server._active_connections[0]
1246
client_sock1.close()
1247
server_side_thread1.join()
1248
# By waiting until the first connection is fully done, the server
1249
# should notice after another connection that the first has finished.
1250
client_sock2 = self.connect_to_server(server)
1251
self.say_hello(client_sock2)
1252
server_handler2, server_side_thread2 = server._active_connections[-1]
1253
# There is a race condition. We know that client_sock2 has been
1254
# registered, but not that _poll_active_connections has been called. We
1255
# know that it will be called before the server will accept a new
1256
# connection, however. So connect one more time, and assert that we
1257
# either have 1 or 2 active connections (never 3), and that the 'first'
1258
# connection is not connection 1
1259
client_sock3 = self.connect_to_server(server)
1260
self.say_hello(client_sock3)
1261
# Copy the list, so we don't have it mutating behind our back
1262
conns = list(server._active_connections)
1263
self.assertEqual(2, len(conns))
1264
self.assertNotEqual((server_handler1, server_side_thread1), conns[0])
1265
self.assertEqual((server_handler2, server_side_thread2), conns[0])
1266
client_sock2.close()
1267
client_sock3.close()
1268
self.shutdown_server_cleanly(server, server_thread)
1270
def test_graceful_shutdown_waits_for_clients_to_stop(self):
1271
server, server_thread = self.make_server()
1272
# We need something big enough that it won't fit in a single recv. So
1273
# the server thread gets blocked writing content to the client until we
1274
# finish reading on the client.
1275
server.backing_transport.put_bytes('bigfile',
1277
client_sock = self.connect_to_server(server)
1278
self.say_hello(client_sock)
1279
_, server_side_thread = server._active_connections[0]
1280
# Start the RPC, but don't finish reading the response
1281
client_medium = medium.SmartClientAlreadyConnectedSocketMedium(
1282
'base', client_sock)
1283
client_client = client._SmartClient(client_medium)
1284
resp, response_handler = client_client.call_expecting_body('get',
1286
self.assertEqual(('ok',), resp)
1287
# Ask the server to stop gracefully, and wait for it.
1288
server._stop_gracefully()
1289
self.connect_to_server_and_hangup(server)
1290
server._stopped.wait()
1291
# It should not be accepting another connection.
1292
self.assertRaises(socket.error, self.connect_to_server, server)
1293
# It should also not be fully stopped
1294
server._fully_stopped.wait(0.01)
1295
self.assertFalse(server._fully_stopped.isSet())
1296
response_handler.read_body_bytes()
1298
server_side_thread.join()
1299
server_thread.join()
1300
self.assertTrue(server._fully_stopped.isSet())
1301
log = self.get_log()
1302
self.assertThat(log, DocTestMatches("""\
1303
INFO Requested to stop gracefully
1304
... Stopping SmartServerSocketStreamMedium(client=('127.0.0.1', ...
1305
INFO Waiting for 1 client(s) to finish
1306
""", flags=doctest.ELLIPSIS|doctest.REPORT_UDIFF))
1308
def test_stop_gracefully_tells_handlers_to_stop(self):
1309
server, server_thread = self.make_server()
1310
client_sock = self.connect_to_server(server)
1311
self.say_hello(client_sock)
1312
server_handler, server_side_thread = server._active_connections[0]
1313
self.assertFalse(server_handler.finished)
1314
server._stop_gracefully()
1315
self.assertTrue(server_handler.finished)
1317
self.connect_to_server_and_hangup(server)
1318
server_thread.join()
995
1321
class SmartTCPTests(tests.TestCase):
996
1322
"""Tests for connection/end to end behaviour using the TCP server.