1
# Copyright (C) 2006, 2007, 2008, 2009 Canonical Ltd
1
# Copyright (C) 2006-2010 Canonical Ltd
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
42
from bzrlib.tests.test_smart import TestCaseWithSmartMedium
43
from bzrlib.tests import (
43
47
from bzrlib.transport import (
49
from bzrlib.transport.http import SmartClientHTTPMediumRequest
52
56
class StringIOSSHVendor(object):
63
67
return StringIOSSHConnection(self)
66
class StringIOSSHConnection(object):
70
class StringIOSSHConnection(ssh.SSHConnection):
67
71
"""A SSH connection that uses StringIO to buffer writes and answer reads."""
69
73
def __init__(self, vendor):
73
77
self.vendor.calls.append(('close', ))
78
self.vendor.read_from.close()
79
self.vendor.write_to.close()
75
def get_filelike_channels(self):
76
return self.vendor.read_from, self.vendor.write_to
81
def get_sock_or_pipes(self):
82
return 'pipes', (self.vendor.read_from, self.vendor.write_to)
79
85
class _InvalidHostnameFeature(tests.Feature):
243
249
unopened_port = sock.getsockname()[1]
244
250
# having vendor be invalid means that if it tries to connect via the
245
251
# vendor it will blow up.
246
client_medium = medium.SmartSSHClientMedium('127.0.0.1', unopened_port,
247
username=None, password=None, base='base', vendor="not a vendor",
248
bzr_remote_path='bzr')
252
ssh_params = medium.SSHParams('127.0.0.1', unopened_port, None, None)
253
client_medium = medium.SmartSSHClientMedium(
254
'base', ssh_params, "not a vendor")
251
257
def test_ssh_client_connects_on_first_use(self):
254
260
output = StringIO()
255
261
vendor = StringIOSSHVendor(StringIO(), output)
256
client_medium = medium.SmartSSHClientMedium(
257
'a hostname', 'a port', 'a username', 'a password', 'base', vendor,
262
ssh_params = medium.SSHParams(
263
'a hostname', 'a port', 'a username', 'a password', 'bzr')
264
client_medium = medium.SmartSSHClientMedium('base', ssh_params, vendor)
259
265
client_medium._accept_bytes('abc')
260
266
self.assertEqual('abc', output.getvalue())
261
267
self.assertEqual([('connect_ssh', 'a username', 'a password',
269
275
output = StringIO()
270
276
vendor = StringIOSSHVendor(StringIO(), output)
271
client_medium = medium.SmartSSHClientMedium('a hostname', 'a port',
272
'a username', 'a password', 'base', vendor, bzr_remote_path='fugly')
277
ssh_params = medium.SSHParams(
278
'a hostname', 'a port', 'a username', 'a password',
279
bzr_remote_path='fugly')
280
client_medium = medium.SmartSSHClientMedium('base', ssh_params, vendor)
273
281
client_medium._accept_bytes('abc')
274
282
self.assertEqual('abc', output.getvalue())
275
283
self.assertEqual([('connect_ssh', 'a username', 'a password',
284
292
output = StringIO()
285
293
vendor = StringIOSSHVendor(input, output)
286
294
client_medium = medium.SmartSSHClientMedium(
287
'a hostname', base='base', vendor=vendor, bzr_remote_path='bzr')
295
'base', medium.SSHParams('a hostname'), vendor)
288
296
client_medium._accept_bytes('abc')
289
297
client_medium.disconnect()
290
298
self.assertTrue(input.closed)
305
313
output = StringIO()
306
314
vendor = StringIOSSHVendor(input, output)
307
315
client_medium = medium.SmartSSHClientMedium(
308
'a hostname', base='base', vendor=vendor, bzr_remote_path='bzr')
316
'base', medium.SSHParams('a hostname'), vendor)
309
317
client_medium._accept_bytes('abc')
310
318
client_medium.disconnect()
311
319
# the disconnect has closed output, so we need a new output for the
334
342
# Doing a disconnect on a new (and thus unconnected) SSH medium
335
343
# does not fail. It's ok to disconnect an unconnected medium.
336
344
client_medium = medium.SmartSSHClientMedium(
337
None, base='base', bzr_remote_path='bzr')
345
'base', medium.SSHParams(None))
338
346
client_medium.disconnect()
340
348
def test_ssh_client_raises_on_read_when_not_connected(self):
341
349
# Doing a read on a new (and thus unconnected) SSH medium raises
342
350
# MediumNotConnected.
343
351
client_medium = medium.SmartSSHClientMedium(
344
None, base='base', bzr_remote_path='bzr')
352
'base', medium.SSHParams(None))
345
353
self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes,
347
355
self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes,
359
367
output.flush = logging_flush
360
368
vendor = StringIOSSHVendor(input, output)
361
369
client_medium = medium.SmartSSHClientMedium(
362
'a hostname', base='base', vendor=vendor, bzr_remote_path='bzr')
370
'base', medium.SSHParams('a hostname'), vendor=vendor)
363
371
# this call is here to ensure we only flush once, not on every
364
372
# _accept_bytes call.
365
373
client_medium._accept_bytes('abc')
557
565
self.assertRaises(errors.ReadingCompleted, request.read_bytes, None)
560
class RemoteTransportTests(TestCaseWithSmartMedium):
568
class RemoteTransportTests(test_smart.TestCaseWithSmartMedium):
562
570
def test_plausible_url(self):
563
571
self.assert_(self.get_url().startswith('bzr://'))
968
976
def external_url(self):
970
def get_bytes(self, path):
971
979
raise Exception("some random exception from inside server")
972
smart_server = server.SmartTCPServer(backing_transport=FlakyTransport())
973
smart_server.start_background_thread('-' + self.id())
975
transport = remote.RemoteTCPTransport(smart_server.get_url())
976
err = self.assertRaises(errors.UnknownErrorFromSmartServer,
977
transport.get, 'something')
978
self.assertContainsRe(str(err), 'some random exception')
979
transport.disconnect()
981
smart_server.stop_background_thread()
981
class FlakyServer(test_server.SmartTCPServer_for_testing):
982
def get_backing_transport(self, backing_transport_server):
983
return FlakyTransport()
985
smart_server = FlakyServer()
986
smart_server.start_server()
987
self.addCleanup(smart_server.stop_server)
988
t = remote.RemoteTCPTransport(smart_server.get_url())
989
self.addCleanup(t.disconnect)
990
err = self.assertRaises(errors.UnknownErrorFromSmartServer,
992
self.assertContainsRe(str(err), 'some random exception')
984
995
class SmartTCPTests(tests.TestCase):
985
996
"""Tests for connection/end to end behaviour using the TCP server.
987
All of these tests are run with a server running on another thread serving
998
All of these tests are run with a server running in another thread serving
988
999
a MemoryTransport, and a connection to it already open.
990
the server is obtained by calling self.setUpServer(readonly=False).
1001
the server is obtained by calling self.start_server(readonly=False).
993
def setUpServer(self, readonly=False, backing_transport=None):
1004
def start_server(self, readonly=False, backing_transport=None):
994
1005
"""Setup the server.
996
1007
:param readonly: Create a readonly server.
998
1009
# NB: Tests using this fall into two categories: tests of the server,
999
1010
# tests wanting a server. The latter should be updated to use
1000
1011
# self.vfs_transport_factory etc.
1001
if not backing_transport:
1012
if backing_transport is None:
1002
1013
mem_server = memory.MemoryServer()
1004
self.addCleanup(mem_server.tearDown)
1014
mem_server.start_server()
1015
self.addCleanup(mem_server.stop_server)
1005
1016
self.permit_url(mem_server.get_url())
1006
self.backing_transport = get_transport(mem_server.get_url())
1017
self.backing_transport = transport.get_transport(
1018
mem_server.get_url())
1008
1020
self.backing_transport = backing_transport
1010
1022
self.real_backing_transport = self.backing_transport
1011
self.backing_transport = get_transport("readonly+" + self.backing_transport.abspath('.'))
1023
self.backing_transport = transport.get_transport(
1024
"readonly+" + self.backing_transport.abspath('.'))
1012
1025
self.server = server.SmartTCPServer(self.backing_transport)
1026
self.server.start_server('127.0.0.1', 0)
1013
1027
self.server.start_background_thread('-' + self.id())
1014
1028
self.transport = remote.RemoteTCPTransport(self.server.get_url())
1015
self.addCleanup(self.tearDownServer)
1029
self.addCleanup(self.stop_server)
1016
1030
self.permit_url(self.server.get_url())
1018
def tearDownServer(self):
1032
def stop_server(self):
1033
"""Disconnect the client and stop the server.
1035
This must be re-entrant as some tests will call it explicitly in
1036
addition to the normal cleanup.
1019
1038
if getattr(self, 'transport', None):
1020
1039
self.transport.disconnect()
1021
1040
del self.transport
1027
1046
class TestServerSocketUsage(SmartTCPTests):
1029
def test_server_setup_teardown(self):
1030
"""It should be safe to teardown the server with no requests."""
1032
server = self.server
1033
transport = remote.RemoteTCPTransport(self.server.get_url())
1034
self.tearDownServer()
1035
self.assertRaises(errors.ConnectionError, transport.has, '.')
1048
def test_server_start_stop(self):
1049
"""It should be safe to stop the server with no requests."""
1051
t = remote.RemoteTCPTransport(self.server.get_url())
1053
self.assertRaises(errors.ConnectionError, t.has, '.')
1037
1055
def test_server_closes_listening_sock_on_shutdown_after_request(self):
1038
1056
"""The server should close its listening socket when it's stopped."""
1040
server = self.server
1058
server_url = self.server.get_url()
1041
1059
self.transport.has('.')
1042
self.tearDownServer()
1043
1061
# if the listening socket has closed, we should get a BADFD error
1044
1062
# when connecting, rather than a hang.
1045
transport = remote.RemoteTCPTransport(server.get_url())
1046
self.assertRaises(errors.ConnectionError, transport.has, '.')
1063
t = remote.RemoteTCPTransport(server_url)
1064
self.assertRaises(errors.ConnectionError, t.has, '.')
1049
1067
class WritableEndToEndTests(SmartTCPTests):
1130
1148
def test_mkdir_error_readonly(self):
1131
1149
"""TransportNotPossible should be preserved from the backing transport."""
1132
1150
self._captureVar('BZR_NO_SMART_VFS', None)
1133
self.setUpServer(readonly=True)
1151
self.start_server(readonly=True)
1134
1152
self.assertRaises(errors.TransportNotPossible, self.transport.mkdir,
1146
1164
self.hook_calls = []
1147
1165
server.SmartTCPServer.hooks.install_named_hook('server_started',
1148
1166
self.capture_server_call, None)
1150
1168
# at this point, the server will be starting a thread up.
1151
1169
# there is no indicator at the moment, so bodge it by doing a request.
1152
1170
self.transport.has('.')
1160
1178
self.hook_calls = []
1161
1179
server.SmartTCPServer.hooks.install_named_hook('server_started',
1162
1180
self.capture_server_call, None)
1163
self.setUpServer(backing_transport=get_transport("."))
1181
self.start_server(backing_transport=transport.get_transport("."))
1164
1182
# at this point, the server will be starting a thread up.
1165
1183
# there is no indicator at the moment, so bodge it by doing a request.
1166
1184
self.transport.has('.')
1176
1194
self.hook_calls = []
1177
1195
server.SmartTCPServer.hooks.install_named_hook('server_stopped',
1178
1196
self.capture_server_call, None)
1180
1198
result = [([self.backing_transport.base], self.transport.base)]
1181
1199
# check the stopping message isn't emitted up front.
1182
1200
self.assertEqual([], self.hook_calls)
1193
1211
self.hook_calls = []
1194
1212
server.SmartTCPServer.hooks.install_named_hook('server_stopped',
1195
1213
self.capture_server_call, None)
1196
self.setUpServer(backing_transport=get_transport("."))
1214
self.start_server(backing_transport=transport.get_transport("."))
1198
1216
[self.backing_transport.base, self.backing_transport.external_url()]
1199
1217
, self.transport.base)]
1338
1356
class RemoteTransportRegistration(tests.TestCase):
1340
1358
def test_registration(self):
1341
t = get_transport('bzr+ssh://example.com/path')
1359
t = transport.get_transport('bzr+ssh://example.com/path')
1342
1360
self.assertIsInstance(t, remote.RemoteSSHTransport)
1343
1361
self.assertEqual('example.com', t._host)
1345
1363
def test_bzr_https(self):
1346
1364
# https://bugs.launchpad.net/bzr/+bug/128456
1347
t = get_transport('bzr+https://example.com/path')
1365
t = transport.get_transport('bzr+https://example.com/path')
1348
1366
self.assertIsInstance(t, remote.RemoteHTTPTransport)
1349
1367
self.assertStartsWith(
1350
1368
t._http_transport.base,
2856
2874
self.responder = protocol.ProtocolThreeResponder(self.writes.append)
2858
2876
def assertWriteCount(self, expected_count):
2877
# self.writes can be quite large; don't show the whole thing
2859
2878
self.assertEqual(
2860
2879
expected_count, len(self.writes),
2861
"Too many writes: %r" % (self.writes,))
2880
"Too many writes: %d, expected %d" % (len(self.writes), expected_count))
2863
2882
def test_send_error_writes_just_once(self):
2864
2883
"""An error response is written to the medium all at once."""
2887
2906
response = _mod_request.SuccessfulSmartServerResponse(
2888
2907
('arg', 'arg'), body_stream=['chunk1', 'chunk2'])
2889
2908
self.responder.send_response(response)
2890
# We will write just once, despite the multiple chunks, due to
2892
self.assertWriteCount(1)
2894
def test_send_response_with_body_stream_flushes_buffers_sometimes(self):
2895
"""When there are many bytes (>1MB), multiple writes will occur rather
2896
than buffering indefinitely.
2898
# Construct a response with stream with ~1.5MB in it. This should
2899
# trigger 2 writes, but not 3
2900
onekib = '12345678' * 128
2901
body_stream = [onekib] * (1024 + 512)
2902
response = _mod_request.SuccessfulSmartServerResponse(
2903
('arg', 'arg'), body_stream=body_stream)
2904
self.responder.send_response(response)
2905
# The write buffer is flushed every 100 buffered writes, so we expect 2
2907
self.assertWriteCount(2)
2909
# Per the discussion in bug 590638 we flush once after the header and
2910
# then once after each chunk
2911
self.assertWriteCount(3)
2910
2914
class TestSmartClientUnicode(tests.TestCase):
3518
3522
def test_smart_http_medium_request_accept_bytes(self):
3519
3523
medium = FakeHTTPMedium()
3520
request = SmartClientHTTPMediumRequest(medium)
3524
request = http.SmartClientHTTPMediumRequest(medium)
3521
3525
request.accept_bytes('abc')
3522
3526
request.accept_bytes('def')
3523
3527
self.assertEqual(None, medium.written_request)