/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_smart_transport.py

  • Committer: Martin Pool
  • Date: 2010-10-08 04:38:25 UTC
  • mfrom: (5462 +trunk)
  • mto: This revision was merged to the branch mainline in revision 5478.
  • Revision ID: mbp@sourcefrog.net-20101008043825-b181r8bo5r3qwb6j
merge trunk

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006, 2007, 2008, 2009 Canonical Ltd
 
1
# Copyright (C) 2006-2010 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
28
28
        errors,
29
29
        osutils,
30
30
        tests,
 
31
        transport,
31
32
        urlutils,
32
33
        )
33
34
from bzrlib.smart import (
39
40
        server,
40
41
        vfs,
41
42
)
42
 
from bzrlib.tests.test_smart import TestCaseWithSmartMedium
 
43
from bzrlib.tests import (
 
44
    test_smart,
 
45
    test_server,
 
46
    )
43
47
from bzrlib.transport import (
44
 
        get_transport,
 
48
        http,
45
49
        local,
46
50
        memory,
47
51
        remote,
 
52
        ssh,
48
53
        )
49
 
from bzrlib.transport.http import SmartClientHTTPMediumRequest
50
54
 
51
55
 
52
56
class StringIOSSHVendor(object):
63
67
        return StringIOSSHConnection(self)
64
68
 
65
69
 
66
 
class StringIOSSHConnection(object):
 
70
class StringIOSSHConnection(ssh.SSHConnection):
67
71
    """A SSH connection that uses StringIO to buffer writes and answer reads."""
68
72
 
69
73
    def __init__(self, vendor):
71
75
 
72
76
    def close(self):
73
77
        self.vendor.calls.append(('close', ))
 
78
        self.vendor.read_from.close()
 
79
        self.vendor.write_to.close()
74
80
 
75
 
    def get_filelike_channels(self):
76
 
        return self.vendor.read_from, self.vendor.write_to
 
81
    def get_sock_or_pipes(self):
 
82
        return 'pipes', (self.vendor.read_from, self.vendor.write_to)
77
83
 
78
84
 
79
85
class _InvalidHostnameFeature(tests.Feature):
243
249
        unopened_port = sock.getsockname()[1]
244
250
        # having vendor be invalid means that if it tries to connect via the
245
251
        # vendor it will blow up.
246
 
        client_medium = medium.SmartSSHClientMedium('127.0.0.1', unopened_port,
247
 
            username=None, password=None, base='base', vendor="not a vendor",
248
 
            bzr_remote_path='bzr')
 
252
        ssh_params = medium.SSHParams('127.0.0.1', unopened_port, None, None)
 
253
        client_medium = medium.SmartSSHClientMedium(
 
254
            'base', ssh_params, "not a vendor")
249
255
        sock.close()
250
256
 
251
257
    def test_ssh_client_connects_on_first_use(self):
253
259
        # it bytes.
254
260
        output = StringIO()
255
261
        vendor = StringIOSSHVendor(StringIO(), output)
256
 
        client_medium = medium.SmartSSHClientMedium(
257
 
            'a hostname', 'a port', 'a username', 'a password', 'base', vendor,
258
 
            'bzr')
 
262
        ssh_params = medium.SSHParams(
 
263
            'a hostname', 'a port', 'a username', 'a password', 'bzr')
 
264
        client_medium = medium.SmartSSHClientMedium('base', ssh_params, vendor)
259
265
        client_medium._accept_bytes('abc')
260
266
        self.assertEqual('abc', output.getvalue())
261
267
        self.assertEqual([('connect_ssh', 'a username', 'a password',
268
274
        # it bytes.
269
275
        output = StringIO()
270
276
        vendor = StringIOSSHVendor(StringIO(), output)
271
 
        client_medium = medium.SmartSSHClientMedium('a hostname', 'a port',
272
 
            'a username', 'a password', 'base', vendor, bzr_remote_path='fugly')
 
277
        ssh_params = medium.SSHParams(
 
278
            'a hostname', 'a port', 'a username', 'a password',
 
279
            bzr_remote_path='fugly')
 
280
        client_medium = medium.SmartSSHClientMedium('base', ssh_params, vendor)
273
281
        client_medium._accept_bytes('abc')
274
282
        self.assertEqual('abc', output.getvalue())
275
283
        self.assertEqual([('connect_ssh', 'a username', 'a password',
284
292
        output = StringIO()
285
293
        vendor = StringIOSSHVendor(input, output)
286
294
        client_medium = medium.SmartSSHClientMedium(
287
 
            'a hostname', base='base', vendor=vendor, bzr_remote_path='bzr')
 
295
            'base', medium.SSHParams('a hostname'), vendor)
288
296
        client_medium._accept_bytes('abc')
289
297
        client_medium.disconnect()
290
298
        self.assertTrue(input.closed)
305
313
        output = StringIO()
306
314
        vendor = StringIOSSHVendor(input, output)
307
315
        client_medium = medium.SmartSSHClientMedium(
308
 
            'a hostname', base='base', vendor=vendor, bzr_remote_path='bzr')
 
316
            'base', medium.SSHParams('a hostname'), vendor)
309
317
        client_medium._accept_bytes('abc')
310
318
        client_medium.disconnect()
311
319
        # the disconnect has closed output, so we need a new output for the
334
342
        # Doing a disconnect on a new (and thus unconnected) SSH medium
335
343
        # does not fail.  It's ok to disconnect an unconnected medium.
336
344
        client_medium = medium.SmartSSHClientMedium(
337
 
            None, base='base', bzr_remote_path='bzr')
 
345
            'base', medium.SSHParams(None))
338
346
        client_medium.disconnect()
339
347
 
340
348
    def test_ssh_client_raises_on_read_when_not_connected(self):
341
349
        # Doing a read on a new (and thus unconnected) SSH medium raises
342
350
        # MediumNotConnected.
343
351
        client_medium = medium.SmartSSHClientMedium(
344
 
            None, base='base', bzr_remote_path='bzr')
 
352
            'base', medium.SSHParams(None))
345
353
        self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes,
346
354
                          0)
347
355
        self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes,
359
367
        output.flush = logging_flush
360
368
        vendor = StringIOSSHVendor(input, output)
361
369
        client_medium = medium.SmartSSHClientMedium(
362
 
            'a hostname', base='base', vendor=vendor, bzr_remote_path='bzr')
 
370
            'base', medium.SSHParams('a hostname'), vendor=vendor)
363
371
        # this call is here to ensure we only flush once, not on every
364
372
        # _accept_bytes call.
365
373
        client_medium._accept_bytes('abc')
557
565
        self.assertRaises(errors.ReadingCompleted, request.read_bytes, None)
558
566
 
559
567
 
560
 
class RemoteTransportTests(TestCaseWithSmartMedium):
 
568
class RemoteTransportTests(test_smart.TestCaseWithSmartMedium):
561
569
 
562
570
    def test_plausible_url(self):
563
571
        self.assert_(self.get_url().startswith('bzr://'))
967
975
            base = 'a_url'
968
976
            def external_url(self):
969
977
                return self.base
970
 
            def get_bytes(self, path):
 
978
            def get(self, path):
971
979
                raise Exception("some random exception from inside server")
972
 
        smart_server = server.SmartTCPServer(backing_transport=FlakyTransport())
973
 
        smart_server.start_background_thread('-' + self.id())
974
 
        try:
975
 
            transport = remote.RemoteTCPTransport(smart_server.get_url())
976
 
            err = self.assertRaises(errors.UnknownErrorFromSmartServer,
977
 
                transport.get, 'something')
978
 
            self.assertContainsRe(str(err), 'some random exception')
979
 
            transport.disconnect()
980
 
        finally:
981
 
            smart_server.stop_background_thread()
 
980
 
 
981
        class FlakyServer(test_server.SmartTCPServer_for_testing):
 
982
            def get_backing_transport(self, backing_transport_server):
 
983
                return FlakyTransport()
 
984
 
 
985
        smart_server = FlakyServer()
 
986
        smart_server.start_server()
 
987
        self.addCleanup(smart_server.stop_server)
 
988
        t = remote.RemoteTCPTransport(smart_server.get_url())
 
989
        self.addCleanup(t.disconnect)
 
990
        err = self.assertRaises(errors.UnknownErrorFromSmartServer,
 
991
                                t.get, 'something')
 
992
        self.assertContainsRe(str(err), 'some random exception')
982
993
 
983
994
 
984
995
class SmartTCPTests(tests.TestCase):
985
996
    """Tests for connection/end to end behaviour using the TCP server.
986
997
 
987
 
    All of these tests are run with a server running on another thread serving
 
998
    All of these tests are run with a server running in another thread serving
988
999
    a MemoryTransport, and a connection to it already open.
989
1000
 
990
 
    the server is obtained by calling self.setUpServer(readonly=False).
 
1001
    the server is obtained by calling self.start_server(readonly=False).
991
1002
    """
992
1003
 
993
 
    def setUpServer(self, readonly=False, backing_transport=None):
 
1004
    def start_server(self, readonly=False, backing_transport=None):
994
1005
        """Setup the server.
995
1006
 
996
1007
        :param readonly: Create a readonly server.
998
1009
        # NB: Tests using this fall into two categories: tests of the server,
999
1010
        # tests wanting a server. The latter should be updated to use
1000
1011
        # self.vfs_transport_factory etc.
1001
 
        if not backing_transport:
 
1012
        if backing_transport is None:
1002
1013
            mem_server = memory.MemoryServer()
1003
 
            mem_server.setUp()
1004
 
            self.addCleanup(mem_server.tearDown)
 
1014
            mem_server.start_server()
 
1015
            self.addCleanup(mem_server.stop_server)
1005
1016
            self.permit_url(mem_server.get_url())
1006
 
            self.backing_transport = get_transport(mem_server.get_url())
 
1017
            self.backing_transport = transport.get_transport(
 
1018
                mem_server.get_url())
1007
1019
        else:
1008
1020
            self.backing_transport = backing_transport
1009
1021
        if readonly:
1010
1022
            self.real_backing_transport = self.backing_transport
1011
 
            self.backing_transport = get_transport("readonly+" + self.backing_transport.abspath('.'))
 
1023
            self.backing_transport = transport.get_transport(
 
1024
                "readonly+" + self.backing_transport.abspath('.'))
1012
1025
        self.server = server.SmartTCPServer(self.backing_transport)
 
1026
        self.server.start_server('127.0.0.1', 0)
1013
1027
        self.server.start_background_thread('-' + self.id())
1014
1028
        self.transport = remote.RemoteTCPTransport(self.server.get_url())
1015
 
        self.addCleanup(self.tearDownServer)
 
1029
        self.addCleanup(self.stop_server)
1016
1030
        self.permit_url(self.server.get_url())
1017
1031
 
1018
 
    def tearDownServer(self):
 
1032
    def stop_server(self):
 
1033
        """Disconnect the client and stop the server.
 
1034
 
 
1035
        This must be re-entrant as some tests will call it explicitly in
 
1036
        addition to the normal cleanup.
 
1037
        """
1019
1038
        if getattr(self, 'transport', None):
1020
1039
            self.transport.disconnect()
1021
1040
            del self.transport
1026
1045
 
1027
1046
class TestServerSocketUsage(SmartTCPTests):
1028
1047
 
1029
 
    def test_server_setup_teardown(self):
1030
 
        """It should be safe to teardown the server with no requests."""
1031
 
        self.setUpServer()
1032
 
        server = self.server
1033
 
        transport = remote.RemoteTCPTransport(self.server.get_url())
1034
 
        self.tearDownServer()
1035
 
        self.assertRaises(errors.ConnectionError, transport.has, '.')
 
1048
    def test_server_start_stop(self):
 
1049
        """It should be safe to stop the server with no requests."""
 
1050
        self.start_server()
 
1051
        t = remote.RemoteTCPTransport(self.server.get_url())
 
1052
        self.stop_server()
 
1053
        self.assertRaises(errors.ConnectionError, t.has, '.')
1036
1054
 
1037
1055
    def test_server_closes_listening_sock_on_shutdown_after_request(self):
1038
1056
        """The server should close its listening socket when it's stopped."""
1039
 
        self.setUpServer()
1040
 
        server = self.server
 
1057
        self.start_server()
 
1058
        server_url = self.server.get_url()
1041
1059
        self.transport.has('.')
1042
 
        self.tearDownServer()
 
1060
        self.stop_server()
1043
1061
        # if the listening socket has closed, we should get a BADFD error
1044
1062
        # when connecting, rather than a hang.
1045
 
        transport = remote.RemoteTCPTransport(server.get_url())
1046
 
        self.assertRaises(errors.ConnectionError, transport.has, '.')
 
1063
        t = remote.RemoteTCPTransport(server_url)
 
1064
        self.assertRaises(errors.ConnectionError, t.has, '.')
1047
1065
 
1048
1066
 
1049
1067
class WritableEndToEndTests(SmartTCPTests):
1051
1069
 
1052
1070
    def setUp(self):
1053
1071
        super(WritableEndToEndTests, self).setUp()
1054
 
        self.setUpServer()
 
1072
        self.start_server()
1055
1073
 
1056
1074
    def test_start_tcp_server(self):
1057
1075
        url = self.server.get_url()
1130
1148
    def test_mkdir_error_readonly(self):
1131
1149
        """TransportNotPossible should be preserved from the backing transport."""
1132
1150
        self._captureVar('BZR_NO_SMART_VFS', None)
1133
 
        self.setUpServer(readonly=True)
 
1151
        self.start_server(readonly=True)
1134
1152
        self.assertRaises(errors.TransportNotPossible, self.transport.mkdir,
1135
1153
            'foo')
1136
1154
 
1146
1164
        self.hook_calls = []
1147
1165
        server.SmartTCPServer.hooks.install_named_hook('server_started',
1148
1166
            self.capture_server_call, None)
1149
 
        self.setUpServer()
 
1167
        self.start_server()
1150
1168
        # at this point, the server will be starting a thread up.
1151
1169
        # there is no indicator at the moment, so bodge it by doing a request.
1152
1170
        self.transport.has('.')
1160
1178
        self.hook_calls = []
1161
1179
        server.SmartTCPServer.hooks.install_named_hook('server_started',
1162
1180
            self.capture_server_call, None)
1163
 
        self.setUpServer(backing_transport=get_transport("."))
 
1181
        self.start_server(backing_transport=transport.get_transport("."))
1164
1182
        # at this point, the server will be starting a thread up.
1165
1183
        # there is no indicator at the moment, so bodge it by doing a request.
1166
1184
        self.transport.has('.')
1176
1194
        self.hook_calls = []
1177
1195
        server.SmartTCPServer.hooks.install_named_hook('server_stopped',
1178
1196
            self.capture_server_call, None)
1179
 
        self.setUpServer()
 
1197
        self.start_server()
1180
1198
        result = [([self.backing_transport.base], self.transport.base)]
1181
1199
        # check the stopping message isn't emitted up front.
1182
1200
        self.assertEqual([], self.hook_calls)
1184
1202
        self.transport.has('.')
1185
1203
        self.assertEqual([], self.hook_calls)
1186
1204
        # clean up the server
1187
 
        self.tearDownServer()
 
1205
        self.stop_server()
1188
1206
        # now it should have fired.
1189
1207
        self.assertEqual(result, self.hook_calls)
1190
1208
 
1193
1211
        self.hook_calls = []
1194
1212
        server.SmartTCPServer.hooks.install_named_hook('server_stopped',
1195
1213
            self.capture_server_call, None)
1196
 
        self.setUpServer(backing_transport=get_transport("."))
 
1214
        self.start_server(backing_transport=transport.get_transport("."))
1197
1215
        result = [(
1198
1216
            [self.backing_transport.base, self.backing_transport.external_url()]
1199
1217
            , self.transport.base)]
1203
1221
        self.transport.has('.')
1204
1222
        self.assertEqual([], self.hook_calls)
1205
1223
        # clean up the server
1206
 
        self.tearDownServer()
 
1224
        self.stop_server()
1207
1225
        # now it should have fired.
1208
1226
        self.assertEqual(result, self.hook_calls)
1209
1227
 
1338
1356
class RemoteTransportRegistration(tests.TestCase):
1339
1357
 
1340
1358
    def test_registration(self):
1341
 
        t = get_transport('bzr+ssh://example.com/path')
 
1359
        t = transport.get_transport('bzr+ssh://example.com/path')
1342
1360
        self.assertIsInstance(t, remote.RemoteSSHTransport)
1343
1361
        self.assertEqual('example.com', t._host)
1344
1362
 
1345
1363
    def test_bzr_https(self):
1346
1364
        # https://bugs.launchpad.net/bzr/+bug/128456
1347
 
        t = get_transport('bzr+https://example.com/path')
 
1365
        t = transport.get_transport('bzr+https://example.com/path')
1348
1366
        self.assertIsInstance(t, remote.RemoteHTTPTransport)
1349
1367
        self.assertStartsWith(
1350
1368
            t._http_transport.base,
2856
2874
        self.responder = protocol.ProtocolThreeResponder(self.writes.append)
2857
2875
 
2858
2876
    def assertWriteCount(self, expected_count):
 
2877
        # self.writes can be quite large; don't show the whole thing
2859
2878
        self.assertEqual(
2860
2879
            expected_count, len(self.writes),
2861
 
            "Too many writes: %r" % (self.writes,))
 
2880
            "Too many writes: %d, expected %d" % (len(self.writes), expected_count))
2862
2881
 
2863
2882
    def test_send_error_writes_just_once(self):
2864
2883
        """An error response is written to the medium all at once."""
2887
2906
        response = _mod_request.SuccessfulSmartServerResponse(
2888
2907
            ('arg', 'arg'), body_stream=['chunk1', 'chunk2'])
2889
2908
        self.responder.send_response(response)
2890
 
        # We will write just once, despite the multiple chunks, due to
2891
 
        # buffering.
2892
 
        self.assertWriteCount(1)
2893
 
 
2894
 
    def test_send_response_with_body_stream_flushes_buffers_sometimes(self):
2895
 
        """When there are many bytes (>1MB), multiple writes will occur rather
2896
 
        than buffering indefinitely.
2897
 
        """
2898
 
        # Construct a response with stream with ~1.5MB in it. This should
2899
 
        # trigger 2 writes, but not 3
2900
 
        onekib = '12345678' * 128
2901
 
        body_stream = [onekib] * (1024 + 512)
2902
 
        response = _mod_request.SuccessfulSmartServerResponse(
2903
 
            ('arg', 'arg'), body_stream=body_stream)
2904
 
        self.responder.send_response(response)
2905
 
        # The write buffer is flushed every 100 buffered writes, so we expect 2
2906
 
        # actual writes.
2907
 
        self.assertWriteCount(2)
 
2909
        # Per the discussion in bug 590638 we flush once after the header and
 
2910
        # then once after each chunk
 
2911
        self.assertWriteCount(3)
2908
2912
 
2909
2913
 
2910
2914
class TestSmartClientUnicode(tests.TestCase):
3517
3521
 
3518
3522
    def test_smart_http_medium_request_accept_bytes(self):
3519
3523
        medium = FakeHTTPMedium()
3520
 
        request = SmartClientHTTPMediumRequest(medium)
 
3524
        request = http.SmartClientHTTPMediumRequest(medium)
3521
3525
        request.accept_bytes('abc')
3522
3526
        request.accept_bytes('def')
3523
3527
        self.assertEqual(None, medium.written_request)