/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_smart_transport.py

  • Committer: Andrew Bennetts
  • Date: 2010-01-12 03:53:21 UTC
  • mfrom: (4948 +trunk)
  • mto: This revision was merged to the branch mainline in revision 4964.
  • Revision ID: andrew.bennetts@canonical.com-20100112035321-hofpz5p10224ryj3
Merge lp:bzr, resolving conflicts.

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006, 2007 Canonical Ltd
 
1
# Copyright (C) 2006, 2007, 2008, 2009 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
12
12
#
13
13
# You should have received a copy of the GNU General Public License
14
14
# along with this program; if not, write to the Free Software
15
 
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16
16
 
17
17
"""Tests for smart transport"""
18
18
 
68
68
 
69
69
    def __init__(self, vendor):
70
70
        self.vendor = vendor
71
 
    
 
71
 
72
72
    def close(self):
73
73
        self.vendor.calls.append(('close', ))
74
 
        
 
74
 
75
75
    def get_filelike_channels(self):
76
76
        return self.vendor.read_from, self.vendor.write_to
77
77
 
78
78
 
79
79
class _InvalidHostnameFeature(tests.Feature):
80
80
    """Does 'non_existent.invalid' fail to resolve?
81
 
    
 
81
 
82
82
    RFC 2606 states that .invalid is reserved for invalid domain names, and
83
83
    also underscores are not a valid character in domain names.  Despite this,
84
84
    it's possible a badly misconfigured name server might decide to always
132
132
        t = threading.Thread(target=_receive_bytes_on_server)
133
133
        t.start()
134
134
        return t
135
 
    
 
135
 
136
136
    def test_construct_smart_simple_pipes_client_medium(self):
137
137
        # the SimplePipes client medium takes two pipes:
138
138
        # readable pipe, writeable pipe.
139
139
        # Constructing one should just save these and do nothing.
140
140
        # We test this by passing in None.
141
141
        client_medium = medium.SmartSimplePipesClientMedium(None, None, None)
142
 
        
 
142
 
143
143
    def test_simple_pipes_client_request_type(self):
144
144
        # SimplePipesClient should use SmartClientStreamMediumRequest's.
145
145
        client_medium = medium.SmartSimplePipesClientMedium(None, None, None)
148
148
 
149
149
    def test_simple_pipes_client_get_concurrent_requests(self):
150
150
        # the simple_pipes client does not support pipelined requests:
151
 
        # but it does support serial requests: we construct one after 
 
151
        # but it does support serial requests: we construct one after
152
152
        # another is finished. This is a smoke test testing the integration
153
153
        # of the SmartClientStreamMediumRequest and the SmartClientStreamMedium
154
154
        # classes - as the sibling classes share this logic, they do not have
170
170
            None, output, 'base')
171
171
        client_medium._accept_bytes('abc')
172
172
        self.assertEqual('abc', output.getvalue())
173
 
    
 
173
 
174
174
    def test_simple_pipes_client_disconnect_does_nothing(self):
175
175
        # calling disconnect does nothing.
176
176
        input = StringIO()
197
197
        self.assertFalse(input.closed)
198
198
        self.assertFalse(output.closed)
199
199
        self.assertEqual('abcabc', output.getvalue())
200
 
    
 
200
 
201
201
    def test_simple_pipes_client_ignores_disconnect_when_not_connected(self):
202
202
        # Doing a disconnect on a new (and thus unconnected) SimplePipes medium
203
203
        # does nothing.
212
212
        self.assertEqual('abc', client_medium.read_bytes(3))
213
213
        client_medium.disconnect()
214
214
        self.assertEqual('def', client_medium.read_bytes(3))
215
 
        
 
215
 
216
216
    def test_simple_pipes_client_supports__flush(self):
217
 
        # invoking _flush on a SimplePipesClient should flush the output 
 
217
        # invoking _flush on a SimplePipesClient should flush the output
218
218
        # pipe. We test this by creating an output pipe that records
219
219
        # flush calls made to it.
220
220
        from StringIO import StringIO # get regular StringIO
262
262
            'a hostname', 'a port',
263
263
            ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes'])],
264
264
            vendor.calls)
265
 
    
266
 
    def test_ssh_client_changes_command_when_BZR_REMOTE_PATH_is_set(self):
267
 
        # The only thing that initiates a connection from the medium is giving
268
 
        # it bytes.
269
 
        output = StringIO()
270
 
        vendor = StringIOSSHVendor(StringIO(), output)
271
 
        orig_bzr_remote_path = os.environ.get('BZR_REMOTE_PATH')
272
 
        def cleanup_environ():
273
 
            osutils.set_or_unset_env('BZR_REMOTE_PATH', orig_bzr_remote_path)
274
 
        self.addCleanup(cleanup_environ)
275
 
        os.environ['BZR_REMOTE_PATH'] = 'fugly'
276
 
        client_medium = self.callDeprecated(
277
 
            ['bzr_remote_path is required as of bzr 0.92'],
278
 
            medium.SmartSSHClientMedium, 'a hostname', 'a port', 'a username',
279
 
            'a password', 'base', vendor)
280
 
        client_medium._accept_bytes('abc')
281
 
        self.assertEqual('abc', output.getvalue())
282
 
        self.assertEqual([('connect_ssh', 'a username', 'a password',
283
 
            'a hostname', 'a port',
284
 
            ['fugly', 'serve', '--inet', '--directory=/', '--allow-writes'])],
285
 
            vendor.calls)
286
 
    
 
265
 
287
266
    def test_ssh_client_changes_command_when_bzr_remote_path_passed(self):
288
267
        # The only thing that initiates a connection from the medium is giving
289
268
        # it bytes.
350
329
            ('close', ),
351
330
            ],
352
331
            vendor.calls)
353
 
    
 
332
 
354
333
    def test_ssh_client_ignores_disconnect_when_not_connected(self):
355
334
        # Doing a disconnect on a new (and thus unconnected) SSH medium
356
335
        # does not fail.  It's ok to disconnect an unconnected medium.
369
348
                          1)
370
349
 
371
350
    def test_ssh_client_supports__flush(self):
372
 
        # invoking _flush on a SSHClientMedium should flush the output 
 
351
        # invoking _flush on a SSHClientMedium should flush the output
373
352
        # pipe. We test this by creating an output pipe that records
374
353
        # flush calls made to it.
375
354
        from StringIO import StringIO # get regular StringIO
387
366
        client_medium._flush()
388
367
        client_medium.disconnect()
389
368
        self.assertEqual(['flush'], flush_calls)
390
 
        
 
369
 
391
370
    def test_construct_smart_tcp_client_medium(self):
392
371
        # the TCP client medium takes a host and a port.  Constructing it won't
393
372
        # connect to anything.
408
387
        t.join()
409
388
        sock.close()
410
389
        self.assertEqual(['abc'], bytes)
411
 
    
 
390
 
412
391
    def test_tcp_client_disconnect_does_so(self):
413
392
        # calling disconnect on the client terminates the connection.
414
393
        # we test this by forcing a short read during a socket.MSG_WAITALL
425
404
        # really did disconnect.
426
405
        medium.disconnect()
427
406
 
428
 
    
 
407
 
429
408
    def test_tcp_client_ignores_disconnect_when_not_connected(self):
430
409
        # Doing a disconnect on a new (and thus unconnected) TCP medium
431
410
        # does not fail.  It's ok to disconnect an unconnected medium.
468
447
 
469
448
class TestSmartClientStreamMediumRequest(tests.TestCase):
470
449
    """Tests the for SmartClientStreamMediumRequest.
471
 
    
472
 
    SmartClientStreamMediumRequest is a helper for the three stream based 
 
450
 
 
451
    SmartClientStreamMediumRequest is a helper for the three stream based
473
452
    mediums: TCP, SSH, SimplePipes, so we only test it once, and then test that
474
453
    those three mediums implement the interface it expects.
475
454
    """
476
455
 
477
456
    def test_accept_bytes_after_finished_writing_errors(self):
478
 
        # calling accept_bytes after calling finished_writing raises 
 
457
        # calling accept_bytes after calling finished_writing raises
479
458
        # WritingCompleted to prevent bad assumptions on stream environments
480
459
        # breaking the needs of message-based environments.
481
460
        output = StringIO()
537
516
            None, None, 'base')
538
517
        request = medium.SmartClientStreamMediumRequest(client_medium)
539
518
        self.assertRaises(errors.WritingNotComplete, request.finished_reading)
540
 
        
 
519
 
541
520
    def test_read_bytes(self):
542
521
        # read bytes should invoke _read_bytes on the stream medium.
543
522
        # we test this by using the SimplePipes medium - the most trivial one
544
 
        # and checking that the data is supplied. Its possible that a 
 
523
        # and checking that the data is supplied. Its possible that a
545
524
        # faulty implementation could poke at the pipe variables them selves,
546
525
        # but we trust that this will be caught as it will break the integration
547
526
        # smoke tests.
566
545
        self.assertRaises(errors.WritingNotComplete, request.read_bytes, None)
567
546
 
568
547
    def test_read_bytes_after_finished_reading_errors(self):
569
 
        # calling read_bytes after calling finished_reading raises 
 
548
        # calling read_bytes after calling finished_reading raises
570
549
        # ReadingCompleted to prevent bad assumptions on stream environments
571
550
        # breaking the needs of message-based environments.
572
551
        output = StringIO()
604
583
 
605
584
 
606
585
class SampleRequest(object):
607
 
    
 
586
 
608
587
    def __init__(self, expected_bytes):
609
588
        self.accepted_bytes = ''
610
589
        self._finished_reading = False
632
611
 
633
612
    def portable_socket_pair(self):
634
613
        """Return a pair of TCP sockets connected to each other.
635
 
        
 
614
 
636
615
        Unlike socket.socketpair, this should work on Windows.
637
616
        """
638
617
        listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
643
622
        server_sock, addr = listen_sock.accept()
644
623
        listen_sock.close()
645
624
        return server_sock, client_sock
646
 
    
 
625
 
647
626
    def test_smart_query_version(self):
648
627
        """Feed a canned query version to a server"""
649
628
        # wire-to-wire, using the whole stack
678
657
        # wire-to-wire, using the whole stack, with a UTF-8 filename.
679
658
        transport = memory.MemoryTransport('memory:///')
680
659
        utf8_filename = u'testfile\N{INTERROBANG}'.encode('utf-8')
 
660
        # VFS requests use filenames, not raw UTF-8.
 
661
        hpss_path = urlutils.escape(utf8_filename)
681
662
        transport.put_bytes(utf8_filename, 'contents\nof\nfile\n')
682
 
        to_server = StringIO('get\001' + utf8_filename + '\n')
 
663
        to_server = StringIO('get\001' + hpss_path + '\n')
683
664
        from_server = StringIO()
684
665
        server = medium.SmartServerPipeStreamMedium(
685
666
            to_server, from_server, transport)
723
704
        server = medium.SmartServerPipeStreamMedium(to_server, from_server, None)
724
705
        server._serve_one_request(SampleRequest('x'))
725
706
        self.assertTrue(server.finished)
726
 
        
 
707
 
727
708
    def test_socket_stream_shutdown_detection(self):
728
709
        server_sock, client_sock = self.portable_socket_pair()
729
710
        client_sock.close()
731
712
            server_sock, None)
732
713
        server._serve_one_request(SampleRequest('x'))
733
714
        self.assertTrue(server.finished)
734
 
        
 
715
 
735
716
    def test_socket_stream_incomplete_request(self):
736
717
        """The medium should still construct the right protocol version even if
737
718
        the initial read only reads part of the request.
753
734
        client_sock.sendall(rest_of_request_bytes)
754
735
        server._serve_one_request(server_protocol)
755
736
        server_sock.close()
756
 
        self.assertEqual(expected_response, client_sock.recv(50),
 
737
        self.assertEqual(expected_response, osutils.recv_all(client_sock, 50),
757
738
                         "Not a version 2 response to 'hello' request.")
758
739
        self.assertEqual('', client_sock.recv(1))
759
740
 
797
778
    def test_pipe_like_stream_with_two_requests(self):
798
779
        # If two requests are read in one go, then two calls to
799
780
        # _serve_one_request should still process both of them as if they had
800
 
        # been received seperately.
 
781
        # been received separately.
801
782
        sample_request_bytes = 'command\n'
802
783
        to_server = StringIO(sample_request_bytes * 2)
803
784
        from_server = StringIO()
815
796
        self.assertEqual('', from_server.getvalue())
816
797
        self.assertEqual(sample_request_bytes, second_protocol.accepted_bytes)
817
798
        self.assertFalse(server.finished)
818
 
        
 
799
 
819
800
    def test_socket_stream_with_two_requests(self):
820
801
        # If two requests are read in one go, then two calls to
821
802
        # _serve_one_request should still process both of them as if they had
822
 
        # been received seperately.
 
803
        # been received separately.
823
804
        sample_request_bytes = 'command\n'
824
805
        server_sock, client_sock = self.portable_socket_pair()
825
806
        server = medium.SmartServerSocketStreamMedium(
856
837
        self.assertEqual('', from_server.getvalue())
857
838
        self.assertTrue(self.closed)
858
839
        self.assertTrue(server.finished)
859
 
        
 
840
 
860
841
    def test_socket_stream_error_handling(self):
861
842
        server_sock, client_sock = self.portable_socket_pair()
862
843
        server = medium.SmartServerSocketStreamMedium(
867
848
        # closed.
868
849
        self.assertEqual('', client_sock.recv(1))
869
850
        self.assertTrue(server.finished)
870
 
        
 
851
 
871
852
    def test_pipe_like_stream_keyboard_interrupt_handling(self):
872
853
        to_server = StringIO('')
873
854
        from_server = StringIO()
918
899
        # Any empty request (i.e. no bytes) is detected as protocol version one.
919
900
        server_protocol = self.build_protocol_pipe_like('')
920
901
        self.assertProtocolOne(server_protocol)
921
 
        
 
902
 
922
903
    def test_socket_like_build_protocol_empty_bytes(self):
923
904
        # Any empty request (i.e. no bytes) is detected as protocol version one.
924
905
        server_protocol = self.build_protocol_socket('')
959
940
        self.assertEqual(
960
941
            protocol.build_server_protocol_three, protocol_factory)
961
942
        self.assertEqual('extra bytes', remainder)
962
 
        
 
943
 
963
944
    def test_version_two(self):
964
945
        result = medium._get_protocol_factory_for_bytes(
965
946
            'bzr request 2\nextra bytes')
967
948
        self.assertEqual(
968
949
            protocol.SmartServerRequestProtocolTwo, protocol_factory)
969
950
        self.assertEqual('extra bytes', remainder)
970
 
        
 
951
 
971
952
    def test_version_one(self):
972
953
        """Version one requests have no version markers."""
973
954
        result = medium._get_protocol_factory_for_bytes('anything\n')
975
956
        self.assertEqual(
976
957
            protocol.SmartServerRequestProtocolOne, protocol_factory)
977
958
        self.assertEqual('anything\n', remainder)
978
 
        
 
959
 
979
960
 
980
961
class TestSmartTCPServer(tests.TestCase):
981
962
 
1006
987
    All of these tests are run with a server running on another thread serving
1007
988
    a MemoryTransport, and a connection to it already open.
1008
989
 
1009
 
    the server is obtained by calling self.setUpServer(readonly=False).
 
990
    the server is obtained by calling self.start_server(readonly=False).
1010
991
    """
1011
992
 
1012
 
    def setUpServer(self, readonly=False, backing_transport=None):
 
993
    def start_server(self, readonly=False, backing_transport=None):
1013
994
        """Setup the server.
1014
995
 
1015
996
        :param readonly: Create a readonly server.
1016
997
        """
 
998
        # NB: Tests using this fall into two categories: tests of the server,
 
999
        # tests wanting a server. The latter should be updated to use
 
1000
        # self.vfs_transport_factory etc.
1017
1001
        if not backing_transport:
1018
 
            self.backing_transport = memory.MemoryTransport()
 
1002
            mem_server = memory.MemoryServer()
 
1003
            mem_server.start_server()
 
1004
            self.addCleanup(mem_server.stop_server)
 
1005
            self.permit_url(mem_server.get_url())
 
1006
            self.backing_transport = get_transport(mem_server.get_url())
1019
1007
        else:
1020
1008
            self.backing_transport = backing_transport
1021
1009
        if readonly:
1025
1013
        self.server.start_background_thread('-' + self.id())
1026
1014
        self.transport = remote.RemoteTCPTransport(self.server.get_url())
1027
1015
        self.addCleanup(self.tearDownServer)
 
1016
        self.permit_url(self.server.get_url())
1028
1017
 
1029
1018
    def tearDownServer(self):
1030
1019
        if getattr(self, 'transport', None):
1032
1021
            del self.transport
1033
1022
        if getattr(self, 'server', None):
1034
1023
            self.server.stop_background_thread()
 
1024
            # XXX: why not .stop_server() -- mbp 20100106
1035
1025
            del self.server
1036
1026
 
1037
1027
 
1039
1029
 
1040
1030
    def test_server_setup_teardown(self):
1041
1031
        """It should be safe to teardown the server with no requests."""
1042
 
        self.setUpServer()
 
1032
        self.start_server()
1043
1033
        server = self.server
1044
1034
        transport = remote.RemoteTCPTransport(self.server.get_url())
1045
1035
        self.tearDownServer()
1047
1037
 
1048
1038
    def test_server_closes_listening_sock_on_shutdown_after_request(self):
1049
1039
        """The server should close its listening socket when it's stopped."""
1050
 
        self.setUpServer()
 
1040
        self.start_server()
1051
1041
        server = self.server
1052
1042
        self.transport.has('.')
1053
1043
        self.tearDownServer()
1062
1052
 
1063
1053
    def setUp(self):
1064
1054
        super(WritableEndToEndTests, self).setUp()
1065
 
        self.setUpServer()
 
1055
        self.start_server()
1066
1056
 
1067
1057
    def test_start_tcp_server(self):
1068
1058
        url = self.server.get_url()
1090
1080
        self._captureVar('BZR_NO_SMART_VFS', None)
1091
1081
        err = self.assertRaises(
1092
1082
            errors.NoSuchFile, self.transport.get, 'not%20a%20file')
1093
 
        self.assertEqual('not%20a%20file', err.path)
 
1083
        self.assertSubset([err.path], ['not%20a%20file', './not%20a%20file'])
1094
1084
 
1095
1085
    def test_simple_clone_conn(self):
1096
1086
        """Test that cloning reuses the same connection."""
1141
1131
    def test_mkdir_error_readonly(self):
1142
1132
        """TransportNotPossible should be preserved from the backing transport."""
1143
1133
        self._captureVar('BZR_NO_SMART_VFS', None)
1144
 
        self.setUpServer(readonly=True)
 
1134
        self.start_server(readonly=True)
1145
1135
        self.assertRaises(errors.TransportNotPossible, self.transport.mkdir,
1146
1136
            'foo')
1147
1137
 
1157
1147
        self.hook_calls = []
1158
1148
        server.SmartTCPServer.hooks.install_named_hook('server_started',
1159
1149
            self.capture_server_call, None)
1160
 
        self.setUpServer()
 
1150
        self.start_server()
1161
1151
        # at this point, the server will be starting a thread up.
1162
1152
        # there is no indicator at the moment, so bodge it by doing a request.
1163
1153
        self.transport.has('.')
1171
1161
        self.hook_calls = []
1172
1162
        server.SmartTCPServer.hooks.install_named_hook('server_started',
1173
1163
            self.capture_server_call, None)
1174
 
        self.setUpServer(backing_transport=get_transport("."))
 
1164
        self.start_server(backing_transport=get_transport("."))
1175
1165
        # at this point, the server will be starting a thread up.
1176
1166
        # there is no indicator at the moment, so bodge it by doing a request.
1177
1167
        self.transport.has('.')
1187
1177
        self.hook_calls = []
1188
1178
        server.SmartTCPServer.hooks.install_named_hook('server_stopped',
1189
1179
            self.capture_server_call, None)
1190
 
        self.setUpServer()
 
1180
        self.start_server()
1191
1181
        result = [([self.backing_transport.base], self.transport.base)]
1192
1182
        # check the stopping message isn't emitted up front.
1193
1183
        self.assertEqual([], self.hook_calls)
1204
1194
        self.hook_calls = []
1205
1195
        server.SmartTCPServer.hooks.install_named_hook('server_stopped',
1206
1196
            self.capture_server_call, None)
1207
 
        self.setUpServer(backing_transport=get_transport("."))
 
1197
        self.start_server(backing_transport=get_transport("."))
1208
1198
        result = [(
1209
1199
            [self.backing_transport.base, self.backing_transport.external_url()]
1210
1200
            , self.transport.base)]
1229
1219
    Note: these tests are rudimentary versions of the command object tests in
1230
1220
    test_smart.py.
1231
1221
    """
1232
 
        
 
1222
 
1233
1223
    def test_hello(self):
1234
1224
        cmd = _mod_request.HelloRequest(None, '/')
1235
1225
        response = cmd.execute()
1236
1226
        self.assertEqual(('ok', '2'), response.args)
1237
1227
        self.assertEqual(None, response.body)
1238
 
        
 
1228
 
1239
1229
    def test_get_bundle(self):
1240
1230
        from bzrlib.bundle import serializer
1241
1231
        wt = self.make_branch_and_tree('.')
1242
1232
        self.build_tree_contents([('hello', 'hello world')])
1243
1233
        wt.add('hello')
1244
1234
        rev_id = wt.commit('add hello')
1245
 
        
 
1235
 
1246
1236
        cmd = _mod_request.GetBundleRequest(self.get_transport(), '/')
1247
1237
        response = cmd.execute('.', rev_id)
1248
1238
        bundle = serializer.read_bundle(StringIO(response.body))
1269
1259
 
1270
1260
    def test_hello(self):
1271
1261
        handler = self.build_handler(None)
1272
 
        handler.dispatch_command('hello', ())
 
1262
        handler.args_received(('hello',))
1273
1263
        self.assertEqual(('ok', '2'), handler.response.args)
1274
1264
        self.assertEqual(None, handler.response.body)
1275
 
        
 
1265
 
1276
1266
    def test_disable_vfs_handler_classes_via_environment(self):
1277
1267
        # VFS handler classes will raise an error from "execute" if
1278
1268
        # BZR_NO_SMART_VFS is set.
1289
1279
        """The response for a read-only error is ('ReadOnlyError')."""
1290
1280
        handler = self.build_handler(self.get_readonly_transport())
1291
1281
        # send a mkdir for foo, with no explicit mode - should fail.
1292
 
        handler.dispatch_command('mkdir', ('foo', ''))
 
1282
        handler.args_received(('mkdir', 'foo', ''))
1293
1283
        # and the failure should be an explicit ReadOnlyError
1294
1284
        self.assertEqual(("ReadOnlyError", ), handler.response.args)
1295
1285
        # XXX: TODO: test that other TransportNotPossible errors are
1300
1290
    def test_hello_has_finished_body_on_dispatch(self):
1301
1291
        """The 'hello' command should set finished_reading."""
1302
1292
        handler = self.build_handler(None)
1303
 
        handler.dispatch_command('hello', ())
 
1293
        handler.args_received(('hello',))
1304
1294
        self.assertTrue(handler.finished_reading)
1305
1295
        self.assertNotEqual(None, handler.response)
1306
1296
 
1307
1297
    def test_put_bytes_non_atomic(self):
1308
1298
        """'put_...' should set finished_reading after reading the bytes."""
1309
1299
        handler = self.build_handler(self.get_transport())
1310
 
        handler.dispatch_command('put_non_atomic', ('a-file', '', 'F', ''))
 
1300
        handler.args_received(('put_non_atomic', 'a-file', '', 'F', ''))
1311
1301
        self.assertFalse(handler.finished_reading)
1312
1302
        handler.accept_body('1234')
1313
1303
        self.assertFalse(handler.finished_reading)
1316
1306
        self.assertTrue(handler.finished_reading)
1317
1307
        self.assertEqual(('ok', ), handler.response.args)
1318
1308
        self.assertEqual(None, handler.response.body)
1319
 
        
 
1309
 
1320
1310
    def test_readv_accept_body(self):
1321
1311
        """'readv' should set finished_reading after reading offsets."""
1322
1312
        self.build_tree(['a-file'])
1323
1313
        handler = self.build_handler(self.get_readonly_transport())
1324
 
        handler.dispatch_command('readv', ('a-file', ))
 
1314
        handler.args_received(('readv', 'a-file'))
1325
1315
        self.assertFalse(handler.finished_reading)
1326
1316
        handler.accept_body('2,')
1327
1317
        self.assertFalse(handler.finished_reading)
1336
1326
        """'readv' when a short read occurs sets the response appropriately."""
1337
1327
        self.build_tree(['a-file'])
1338
1328
        handler = self.build_handler(self.get_readonly_transport())
1339
 
        handler.dispatch_command('readv', ('a-file', ))
 
1329
        handler.args_received(('readv', 'a-file'))
1340
1330
        # read beyond the end of the file.
1341
1331
        handler.accept_body('100,1')
1342
1332
        handler.end_of_body()
1363
1353
 
1364
1354
 
1365
1355
class TestRemoteTransport(tests.TestCase):
1366
 
        
 
1356
 
1367
1357
    def test_use_connection_factory(self):
1368
1358
        # We want to be able to pass a client as a parameter to RemoteTransport.
1369
1359
        input = StringIO('ok\n3\nbardone\n')
1468
1458
    def assertOffsetSerialisation(self, expected_offsets, expected_serialised,
1469
1459
        requester):
1470
1460
        """Check that smart (de)serialises offsets as expected.
1471
 
        
 
1461
 
1472
1462
        We check both serialisation and deserialisation at the same time
1473
1463
        to ensure that the round tripping cannot skew: both directions should
1474
1464
        be as expected.
1475
 
        
 
1465
 
1476
1466
        :param expected_offsets: a readv offset list.
1477
1467
        :param expected_seralised: an expected serial form of the offsets.
1478
1468
        """
1528
1518
        ex = self.assertRaises(errors.ConnectionReset,
1529
1519
            response_handler.read_response_tuple)
1530
1520
        self.assertEqual("Connection closed: "
1531
 
            "please check connectivity and permissions "
1532
 
            "(and try -Dhpss if further diagnosis is required)", str(ex))
 
1521
            "Unexpected end of message. Please check connectivity "
 
1522
            "and permissions, and report a bug if problems persist. ",
 
1523
            str(ex))
1533
1524
 
1534
1525
    def test_server_offset_serialisation(self):
1535
1526
        """The Smart protocol serialises offsets as a comma and \n string.
1654
1645
 
1655
1646
    def test_query_version(self):
1656
1647
        """query_version on a SmartClientProtocolOne should return a number.
1657
 
        
 
1648
 
1658
1649
        The protocol provides the query_version because the domain level clients
1659
1650
        may all need to be able to probe for capabilities.
1660
1651
        """
1661
1652
        # What we really want to test here is that SmartClientProtocolOne calls
1662
1653
        # accept_bytes(tuple_based_encoding_of_hello) and reads and parses the
1663
 
        # response of tuple-encoded (ok, 1).  Also, seperately we should test
 
1654
        # response of tuple-encoded (ok, 1).  Also, separately we should test
1664
1655
        # the error if the response is a non-understood version.
1665
1656
        input = StringIO('ok\x012\n')
1666
1657
        output = StringIO()
1925
1916
 
1926
1917
    def test_query_version(self):
1927
1918
        """query_version on a SmartClientProtocolTwo should return a number.
1928
 
        
 
1919
 
1929
1920
        The protocol provides the query_version because the domain level clients
1930
1921
        may all need to be able to probe for capabilities.
1931
1922
        """
1932
1923
        # What we really want to test here is that SmartClientProtocolTwo calls
1933
1924
        # accept_bytes(tuple_based_encoding_of_hello) and reads and parses the
1934
 
        # response of tuple-encoded (ok, 1).  Also, seperately we should test
 
1925
        # response of tuple-encoded (ok, 1).  Also, separately we should test
1935
1926
        # the error if the response is a non-understood version.
1936
1927
        input = StringIO(self.response_marker + 'success\nok\x012\n')
1937
1928
        output = StringIO()
2270
2261
        self.assertEqual(4, smart_protocol.next_read_size())
2271
2262
 
2272
2263
 
2273
 
class NoOpRequest(_mod_request.SmartServerRequest):
2274
 
 
2275
 
    def do(self):
2276
 
        return _mod_request.SuccessfulSmartServerResponse(())
2277
 
 
2278
 
dummy_registry = {'ARG': NoOpRequest}
2279
 
 
2280
 
 
2281
2264
class LoggingMessageHandler(object):
2282
2265
 
2283
2266
    def __init__(self):
2325
2308
        self.assertEqual(0, smart_protocol.next_read_size())
2326
2309
        self.assertEqual('', smart_protocol.unused_data)
2327
2310
 
 
2311
    def test_repeated_excess(self):
 
2312
        """Repeated calls to accept_bytes after the message end has been parsed
 
2313
        accumlates the bytes in the unused_data attribute.
 
2314
        """
 
2315
        output = StringIO()
 
2316
        headers = '\0\0\0\x02de'  # length-prefixed, bencoded empty dict
 
2317
        end = 'e'
 
2318
        request_bytes = headers + end
 
2319
        smart_protocol = self.server_protocol_class(LoggingMessageHandler())
 
2320
        smart_protocol.accept_bytes(request_bytes)
 
2321
        self.assertEqual('', smart_protocol.unused_data)
 
2322
        smart_protocol.accept_bytes('aaa')
 
2323
        self.assertEqual('aaa', smart_protocol.unused_data)
 
2324
        smart_protocol.accept_bytes('bbb')
 
2325
        self.assertEqual('aaabbb', smart_protocol.unused_data)
 
2326
        self.assertEqual(0, smart_protocol.next_read_size())
 
2327
 
2328
2328
    def make_protocol_expecting_message_part(self):
2329
2329
        headers = '\0\0\0\x02de'  # length-prefixed, bencoded empty dict
2330
2330
        message_handler = LoggingMessageHandler()
2393
2393
        return response_handler
2394
2394
 
2395
2395
    def test_interrupted_by_error(self):
2396
 
        interrupted_body_stream = (
2397
 
            'oS' # successful response
2398
 
            's\0\0\0\x02le' # empty args
2399
 
            'b\0\0\0\x09chunk one' # first chunk
2400
 
            'b\0\0\0\x09chunk two' # second chunk
2401
 
            'oE' # error flag
2402
 
            's\0\0\0\x0el5:error3:abce' # bencoded error
2403
 
            'e' # message end
2404
 
            )
2405
2396
        response_handler = self.make_response_handler(interrupted_body_stream)
2406
2397
        stream = response_handler.read_streamed_body()
2407
 
        self.assertEqual('chunk one', stream.next())
2408
 
        self.assertEqual('chunk two', stream.next())
 
2398
        self.assertEqual('aaa', stream.next())
 
2399
        self.assertEqual('bbb', stream.next())
2409
2400
        exc = self.assertRaises(errors.ErrorFromSmartServer, stream.next)
2410
 
        self.assertEqual(('error', 'abc'), exc.error_tuple)
 
2401
        self.assertEqual(('error', 'Boom!'), exc.error_tuple)
2411
2402
 
2412
2403
    def test_interrupted_by_connection_lost(self):
2413
2404
        interrupted_body_stream = (
2562
2553
        self.calls.append(('end_received',))
2563
2554
        self.finished_reading = True
2564
2555
 
2565
 
    def dispatch_command(self, cmd, args):
2566
 
        self.calls.append(('dispatch_command', cmd, args))
 
2556
    def args_received(self, args):
 
2557
        self.calls.append(('args_received', args))
2567
2558
 
2568
2559
    def accept_body(self, bytes):
2569
2560
        self.calls.append(('accept_body', bytes))
2765
2756
    def test_call_with_body_stream_error(self):
2766
2757
        """call_with_body_stream will abort the streamed body with an
2767
2758
        error if the stream raises an error during iteration.
2768
 
        
 
2759
 
2769
2760
        The resulting request will still be a complete message.
2770
2761
        """
2771
2762
        requester, output = self.make_client_encoder_and_output()
2774
2765
            yield 'aaa'
2775
2766
            yield 'bbb'
2776
2767
            raise Exception('Boom!')
2777
 
        requester.call_with_body_stream(('one arg',), stream_that_fails())
 
2768
        self.assertRaises(Exception, requester.call_with_body_stream,
 
2769
            ('one arg',), stream_that_fails())
2778
2770
        self.assertEquals(
2779
2771
            'bzr message 3 (bzr 1.6)\n' # protocol version
2780
2772
            '\x00\x00\x00\x02de' # headers
2803
2795
        self.calls.append('finished_writing')
2804
2796
 
2805
2797
 
 
2798
interrupted_body_stream = (
 
2799
    'oS' # status flag (success)
 
2800
    's\x00\x00\x00\x08l4:argse' # args struct ('args,')
 
2801
    'b\x00\x00\x00\x03aaa' # body part ('aaa')
 
2802
    'b\x00\x00\x00\x03bbb' # body part ('bbb')
 
2803
    'oE' # status flag (error)
 
2804
    's\x00\x00\x00\x10l5:error5:Boom!e' # err struct ('error', 'Boom!')
 
2805
    'e' # EOM
 
2806
    )
 
2807
 
 
2808
 
2806
2809
class TestResponseEncodingProtocolThree(tests.TestCase):
2807
2810
 
2808
2811
    def make_response_encoder(self):
2824
2827
            # end of message
2825
2828
            'e')
2826
2829
 
 
2830
    def test_send_broken_body_stream(self):
 
2831
        encoder, out_stream = self.make_response_encoder()
 
2832
        encoder._headers = {}
 
2833
        def stream_that_fails():
 
2834
            yield 'aaa'
 
2835
            yield 'bbb'
 
2836
            raise Exception('Boom!')
 
2837
        response = _mod_request.SuccessfulSmartServerResponse(
 
2838
            ('args',), body_stream=stream_that_fails())
 
2839
        encoder.send_response(response)
 
2840
        expected_response = (
 
2841
            'bzr message 3 (bzr 1.6)\n'  # protocol marker
 
2842
            '\x00\x00\x00\x02de' # headers dict (empty)
 
2843
            + interrupted_body_stream)
 
2844
        self.assertEqual(expected_response, out_stream.getvalue())
 
2845
 
2827
2846
 
2828
2847
class TestResponseEncoderBufferingProtocolThree(tests.TestCase):
2829
2848
    """Tests for buffering of responses.
2833
2852
    """
2834
2853
 
2835
2854
    def setUp(self):
 
2855
        tests.TestCase.setUp(self)
2836
2856
        self.writes = []
2837
2857
        self.responder = protocol.ProtocolThreeResponder(self.writes.append)
2838
2858
 
2840
2860
        self.assertEqual(
2841
2861
            expected_count, len(self.writes),
2842
2862
            "Too many writes: %r" % (self.writes,))
2843
 
        
 
2863
 
2844
2864
    def test_send_error_writes_just_once(self):
2845
2865
        """An error response is written to the medium all at once."""
2846
2866
        self.responder.send_error(Exception('An exception string.'))
2862
2882
        self.responder.send_response(response)
2863
2883
        self.assertWriteCount(1)
2864
2884
 
2865
 
    def test_send_response_with_body_stream_writes_once_per_chunk(self):
2866
 
        """A normal response with a stream body is written to the medium
2867
 
        writes to the medium once per chunk.
2868
 
        """
 
2885
    def test_send_response_with_body_stream_buffers_writes(self):
 
2886
        """A normal response with a stream body writes to the medium once."""
2869
2887
        # Construct a response with stream with 2 chunks in it.
2870
2888
        response = _mod_request.SuccessfulSmartServerResponse(
2871
2889
            ('arg', 'arg'), body_stream=['chunk1', 'chunk2'])
2872
2890
        self.responder.send_response(response)
2873
 
        # We will write 3 times: exactly once for each chunk, plus a final
2874
 
        # write to end the response.
2875
 
        self.assertWriteCount(3)
 
2891
        # We will write just once, despite the multiple chunks, due to
 
2892
        # buffering.
 
2893
        self.assertWriteCount(1)
 
2894
 
 
2895
    def test_send_response_with_body_stream_flushes_buffers_sometimes(self):
 
2896
        """When there are many bytes (>1MB), multiple writes will occur rather
 
2897
        than buffering indefinitely.
 
2898
        """
 
2899
        # Construct a response with stream with ~1.5MB in it. This should
 
2900
        # trigger 2 writes, but not 3
 
2901
        onekib = '12345678' * 128
 
2902
        body_stream = [onekib] * (1024 + 512)
 
2903
        response = _mod_request.SuccessfulSmartServerResponse(
 
2904
            ('arg', 'arg'), body_stream=body_stream)
 
2905
        self.responder.send_response(response)
 
2906
        self.assertWriteCount(2)
2876
2907
 
2877
2908
 
2878
2909
class TestSmartClientUnicode(tests.TestCase):
2915
2946
 
2916
2947
class MockMedium(medium.SmartClientMedium):
2917
2948
    """A mock medium that can be used to test _SmartClient.
2918
 
    
 
2949
 
2919
2950
    It can be given a series of requests to expect (and responses it should
2920
2951
    return for them).  It can also be told when the client is expected to
2921
2952
    disconnect a medium.  Expectations must be satisfied in the order they are
2933
2964
        super(MockMedium, self).__init__('dummy base')
2934
2965
        self._mock_request = _MockMediumRequest(self)
2935
2966
        self._expected_events = []
2936
 
        
 
2967
 
2937
2968
    def expect_request(self, request_bytes, response_bytes,
2938
2969
                       allow_partial_read=False):
2939
2970
        """Expect 'request_bytes' to be sent, and reply with 'response_bytes'.
2942
2973
        called to send the request.  Similarly, no assumption is made about how
2943
2974
        many times read_bytes/read_line are called by protocol code to read a
2944
2975
        response.  e.g.::
2945
 
        
 
2976
 
2946
2977
            request.accept_bytes('ab')
2947
2978
            request.accept_bytes('cd')
2948
2979
            request.finished_writing()
2949
2980
 
2950
2981
        and::
2951
 
        
 
2982
 
2952
2983
            request.accept_bytes('abcd')
2953
2984
            request.finished_writing()
2954
2985
 
3139
3170
    def test_first_response_is_error(self):
3140
3171
        """If the server replies with an error, then the version detection
3141
3172
        should be complete.
3142
 
        
 
3173
 
3143
3174
        This test is very similar to test_version_two_server, but catches a bug
3144
3175
        we had in the case where the first reply was an error response.
3145
3176
        """
3185
3216
 
3186
3217
class LengthPrefixedBodyDecoder(tests.TestCase):
3187
3218
 
3188
 
    # XXX: TODO: make accept_reading_trailer invoke translate_response or 
 
3219
    # XXX: TODO: make accept_reading_trailer invoke translate_response or
3189
3220
    # something similar to the ProtocolBase method.
3190
3221
 
3191
3222
    def test_construct(self):
3227
3258
        self.assertEqual(1, decoder.next_read_size())
3228
3259
        self.assertEqual('', decoder.read_pending_data())
3229
3260
        self.assertEqual('blarg', decoder.unused_data)
3230
 
        
 
3261
 
3231
3262
    def test_accept_bytes_all_at_once_with_excess(self):
3232
3263
        decoder = protocol.LengthPrefixedBodyDecoder()
3233
3264
        decoder.accept_bytes('1\nadone\nunused')
3252
3283
 
3253
3284
class TestChunkedBodyDecoder(tests.TestCase):
3254
3285
    """Tests for ChunkedBodyDecoder.
3255
 
    
 
3286
 
3256
3287
    This is the body decoder used for protocol version two.
3257
3288
    """
3258
3289
 
3284
3315
        self.assertTrue(decoder.finished_reading)
3285
3316
        self.assertEqual(chunk_content, decoder.read_next_chunk())
3286
3317
        self.assertEqual('', decoder.unused_data)
3287
 
        
 
3318
 
3288
3319
    def test_incomplete_chunk(self):
3289
3320
        """When there are less bytes in the chunk than declared by the length,
3290
3321
        then we haven't finished reading yet.
3522
3553
        # still work correctly.
3523
3554
        base_transport = remote.RemoteHTTPTransport('bzr+http://host/%7Ea/b')
3524
3555
        new_transport = base_transport.clone('c')
3525
 
        self.assertEqual('bzr+http://host/%7Ea/b/c/', new_transport.base)
 
3556
        self.assertEqual('bzr+http://host/~a/b/c/', new_transport.base)
3526
3557
        self.assertEqual(
3527
3558
            'c/',
3528
3559
            new_transport._client.remote_path_from_transport(new_transport))
3555
3586
        self.assertNotEquals(type(r), type(t))
3556
3587
 
3557
3588
 
3558
 
# TODO: Client feature that does get_bundle and then installs that into a
3559
 
# branch; this can be used in place of the regular pull/fetch operation when
3560
 
# coming from a smart server.
3561
 
#
3562
 
# TODO: Eventually, want to do a 'branch' command by fetching the whole
3563
 
# history as one big bundle.  How?  
3564
 
#
3565
 
# The branch command does 'br_from.sprout', which tries to preserve the same
3566
 
# format.  We don't necessarily even want that.  
3567
 
#
3568
 
# It might be simpler to handle cmd_pull first, which does a simpler fetch()
3569
 
# operation from one branch into another.  It already has some code for
3570
 
# pulling from a bundle, which it does by trying to see if the destination is
3571
 
# a bundle file.  So it seems the logic for pull ought to be:
3572
 
3573
 
#  - if it's a smart server, get a bundle from there and install that
3574
 
#  - if it's a bundle, install that
3575
 
#  - if it's a branch, pull from there
3576
 
#
3577
 
# Getting a bundle from a smart server is a bit different from reading a
3578
 
# bundle from a URL:
3579
 
#
3580
 
#  - we can reasonably remember the URL we last read from 
3581
 
#  - you can specify a revision number to pull, and we need to pass it across
3582
 
#    to the server as a limit on what will be requested
3583
 
#
3584
 
# TODO: Given a URL, determine whether it is a smart server or not (or perhaps
3585
 
# otherwise whether it's a bundle?)  Should this be a property or method of
3586
 
# the transport?  For the ssh protocol, we always know it's a smart server.
3587
 
# For http, we potentially need to probe.  But if we're explicitly given
3588
 
# bzr+http:// then we can skip that for now.