/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_smart_transport.py

  • Committer: Martin von Gagern
  • Date: 2010-04-20 08:47:38 UTC
  • mfrom: (5167 +trunk)
  • mto: This revision was merged to the branch mainline in revision 5195.
  • Revision ID: martin.vgagern@gmx.net-20100420084738-ygymnqmdllzrhpfn
merge trunk

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006, 2007 Canonical Ltd
 
1
# Copyright (C) 2006-2010 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
12
12
#
13
13
# You should have received a copy of the GNU General Public License
14
14
# along with this program; if not, write to the Free Software
15
 
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16
16
 
17
17
"""Tests for smart transport"""
18
18
 
28
28
        errors,
29
29
        osutils,
30
30
        tests,
 
31
        transport,
31
32
        urlutils,
32
33
        )
33
34
from bzrlib.smart import (
39
40
        server,
40
41
        vfs,
41
42
)
42
 
from bzrlib.tests.test_smart import TestCaseWithSmartMedium
 
43
from bzrlib.tests import test_smart
43
44
from bzrlib.transport import (
44
 
        get_transport,
 
45
        http,
45
46
        local,
46
47
        memory,
47
48
        remote,
48
49
        )
49
 
from bzrlib.transport.http import SmartClientHTTPMediumRequest
50
50
 
51
51
 
52
52
class StringIOSSHVendor(object):
68
68
 
69
69
    def __init__(self, vendor):
70
70
        self.vendor = vendor
71
 
    
 
71
 
72
72
    def close(self):
73
73
        self.vendor.calls.append(('close', ))
74
 
        
 
74
 
75
75
    def get_filelike_channels(self):
76
76
        return self.vendor.read_from, self.vendor.write_to
77
77
 
78
78
 
79
79
class _InvalidHostnameFeature(tests.Feature):
80
80
    """Does 'non_existent.invalid' fail to resolve?
81
 
    
 
81
 
82
82
    RFC 2606 states that .invalid is reserved for invalid domain names, and
83
83
    also underscores are not a valid character in domain names.  Despite this,
84
84
    it's possible a badly misconfigured name server might decide to always
132
132
        t = threading.Thread(target=_receive_bytes_on_server)
133
133
        t.start()
134
134
        return t
135
 
    
 
135
 
136
136
    def test_construct_smart_simple_pipes_client_medium(self):
137
137
        # the SimplePipes client medium takes two pipes:
138
138
        # readable pipe, writeable pipe.
139
139
        # Constructing one should just save these and do nothing.
140
140
        # We test this by passing in None.
141
141
        client_medium = medium.SmartSimplePipesClientMedium(None, None, None)
142
 
        
 
142
 
143
143
    def test_simple_pipes_client_request_type(self):
144
144
        # SimplePipesClient should use SmartClientStreamMediumRequest's.
145
145
        client_medium = medium.SmartSimplePipesClientMedium(None, None, None)
148
148
 
149
149
    def test_simple_pipes_client_get_concurrent_requests(self):
150
150
        # the simple_pipes client does not support pipelined requests:
151
 
        # but it does support serial requests: we construct one after 
 
151
        # but it does support serial requests: we construct one after
152
152
        # another is finished. This is a smoke test testing the integration
153
153
        # of the SmartClientStreamMediumRequest and the SmartClientStreamMedium
154
154
        # classes - as the sibling classes share this logic, they do not have
170
170
            None, output, 'base')
171
171
        client_medium._accept_bytes('abc')
172
172
        self.assertEqual('abc', output.getvalue())
173
 
    
 
173
 
174
174
    def test_simple_pipes_client_disconnect_does_nothing(self):
175
175
        # calling disconnect does nothing.
176
176
        input = StringIO()
197
197
        self.assertFalse(input.closed)
198
198
        self.assertFalse(output.closed)
199
199
        self.assertEqual('abcabc', output.getvalue())
200
 
    
 
200
 
201
201
    def test_simple_pipes_client_ignores_disconnect_when_not_connected(self):
202
202
        # Doing a disconnect on a new (and thus unconnected) SimplePipes medium
203
203
        # does nothing.
212
212
        self.assertEqual('abc', client_medium.read_bytes(3))
213
213
        client_medium.disconnect()
214
214
        self.assertEqual('def', client_medium.read_bytes(3))
215
 
        
 
215
 
216
216
    def test_simple_pipes_client_supports__flush(self):
217
 
        # invoking _flush on a SimplePipesClient should flush the output 
 
217
        # invoking _flush on a SimplePipesClient should flush the output
218
218
        # pipe. We test this by creating an output pipe that records
219
219
        # flush calls made to it.
220
220
        from StringIO import StringIO # get regular StringIO
262
262
            'a hostname', 'a port',
263
263
            ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes'])],
264
264
            vendor.calls)
265
 
    
266
 
    def test_ssh_client_changes_command_when_BZR_REMOTE_PATH_is_set(self):
267
 
        # The only thing that initiates a connection from the medium is giving
268
 
        # it bytes.
269
 
        output = StringIO()
270
 
        vendor = StringIOSSHVendor(StringIO(), output)
271
 
        orig_bzr_remote_path = os.environ.get('BZR_REMOTE_PATH')
272
 
        def cleanup_environ():
273
 
            osutils.set_or_unset_env('BZR_REMOTE_PATH', orig_bzr_remote_path)
274
 
        self.addCleanup(cleanup_environ)
275
 
        os.environ['BZR_REMOTE_PATH'] = 'fugly'
276
 
        client_medium = self.callDeprecated(
277
 
            ['bzr_remote_path is required as of bzr 0.92'],
278
 
            medium.SmartSSHClientMedium, 'a hostname', 'a port', 'a username',
279
 
            'a password', 'base', vendor)
280
 
        client_medium._accept_bytes('abc')
281
 
        self.assertEqual('abc', output.getvalue())
282
 
        self.assertEqual([('connect_ssh', 'a username', 'a password',
283
 
            'a hostname', 'a port',
284
 
            ['fugly', 'serve', '--inet', '--directory=/', '--allow-writes'])],
285
 
            vendor.calls)
286
 
    
 
265
 
287
266
    def test_ssh_client_changes_command_when_bzr_remote_path_passed(self):
288
267
        # The only thing that initiates a connection from the medium is giving
289
268
        # it bytes.
350
329
            ('close', ),
351
330
            ],
352
331
            vendor.calls)
353
 
    
 
332
 
354
333
    def test_ssh_client_ignores_disconnect_when_not_connected(self):
355
334
        # Doing a disconnect on a new (and thus unconnected) SSH medium
356
335
        # does not fail.  It's ok to disconnect an unconnected medium.
369
348
                          1)
370
349
 
371
350
    def test_ssh_client_supports__flush(self):
372
 
        # invoking _flush on a SSHClientMedium should flush the output 
 
351
        # invoking _flush on a SSHClientMedium should flush the output
373
352
        # pipe. We test this by creating an output pipe that records
374
353
        # flush calls made to it.
375
354
        from StringIO import StringIO # get regular StringIO
387
366
        client_medium._flush()
388
367
        client_medium.disconnect()
389
368
        self.assertEqual(['flush'], flush_calls)
390
 
        
 
369
 
391
370
    def test_construct_smart_tcp_client_medium(self):
392
371
        # the TCP client medium takes a host and a port.  Constructing it won't
393
372
        # connect to anything.
408
387
        t.join()
409
388
        sock.close()
410
389
        self.assertEqual(['abc'], bytes)
411
 
    
 
390
 
412
391
    def test_tcp_client_disconnect_does_so(self):
413
392
        # calling disconnect on the client terminates the connection.
414
393
        # we test this by forcing a short read during a socket.MSG_WAITALL
425
404
        # really did disconnect.
426
405
        medium.disconnect()
427
406
 
428
 
    
 
407
 
429
408
    def test_tcp_client_ignores_disconnect_when_not_connected(self):
430
409
        # Doing a disconnect on a new (and thus unconnected) TCP medium
431
410
        # does not fail.  It's ok to disconnect an unconnected medium.
468
447
 
469
448
class TestSmartClientStreamMediumRequest(tests.TestCase):
470
449
    """Tests the for SmartClientStreamMediumRequest.
471
 
    
472
 
    SmartClientStreamMediumRequest is a helper for the three stream based 
 
450
 
 
451
    SmartClientStreamMediumRequest is a helper for the three stream based
473
452
    mediums: TCP, SSH, SimplePipes, so we only test it once, and then test that
474
453
    those three mediums implement the interface it expects.
475
454
    """
476
455
 
477
456
    def test_accept_bytes_after_finished_writing_errors(self):
478
 
        # calling accept_bytes after calling finished_writing raises 
 
457
        # calling accept_bytes after calling finished_writing raises
479
458
        # WritingCompleted to prevent bad assumptions on stream environments
480
459
        # breaking the needs of message-based environments.
481
460
        output = StringIO()
537
516
            None, None, 'base')
538
517
        request = medium.SmartClientStreamMediumRequest(client_medium)
539
518
        self.assertRaises(errors.WritingNotComplete, request.finished_reading)
540
 
        
 
519
 
541
520
    def test_read_bytes(self):
542
521
        # read bytes should invoke _read_bytes on the stream medium.
543
522
        # we test this by using the SimplePipes medium - the most trivial one
544
 
        # and checking that the data is supplied. Its possible that a 
 
523
        # and checking that the data is supplied. Its possible that a
545
524
        # faulty implementation could poke at the pipe variables them selves,
546
525
        # but we trust that this will be caught as it will break the integration
547
526
        # smoke tests.
566
545
        self.assertRaises(errors.WritingNotComplete, request.read_bytes, None)
567
546
 
568
547
    def test_read_bytes_after_finished_reading_errors(self):
569
 
        # calling read_bytes after calling finished_reading raises 
 
548
        # calling read_bytes after calling finished_reading raises
570
549
        # ReadingCompleted to prevent bad assumptions on stream environments
571
550
        # breaking the needs of message-based environments.
572
551
        output = StringIO()
578
557
        self.assertRaises(errors.ReadingCompleted, request.read_bytes, None)
579
558
 
580
559
 
581
 
class RemoteTransportTests(TestCaseWithSmartMedium):
 
560
class RemoteTransportTests(test_smart.TestCaseWithSmartMedium):
582
561
 
583
562
    def test_plausible_url(self):
584
563
        self.assert_(self.get_url().startswith('bzr://'))
604
583
 
605
584
 
606
585
class SampleRequest(object):
607
 
    
 
586
 
608
587
    def __init__(self, expected_bytes):
609
588
        self.accepted_bytes = ''
610
589
        self._finished_reading = False
632
611
 
633
612
    def portable_socket_pair(self):
634
613
        """Return a pair of TCP sockets connected to each other.
635
 
        
 
614
 
636
615
        Unlike socket.socketpair, this should work on Windows.
637
616
        """
638
617
        listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
643
622
        server_sock, addr = listen_sock.accept()
644
623
        listen_sock.close()
645
624
        return server_sock, client_sock
646
 
    
 
625
 
647
626
    def test_smart_query_version(self):
648
627
        """Feed a canned query version to a server"""
649
628
        # wire-to-wire, using the whole stack
678
657
        # wire-to-wire, using the whole stack, with a UTF-8 filename.
679
658
        transport = memory.MemoryTransport('memory:///')
680
659
        utf8_filename = u'testfile\N{INTERROBANG}'.encode('utf-8')
 
660
        # VFS requests use filenames, not raw UTF-8.
 
661
        hpss_path = urlutils.escape(utf8_filename)
681
662
        transport.put_bytes(utf8_filename, 'contents\nof\nfile\n')
682
 
        to_server = StringIO('get\001' + utf8_filename + '\n')
 
663
        to_server = StringIO('get\001' + hpss_path + '\n')
683
664
        from_server = StringIO()
684
665
        server = medium.SmartServerPipeStreamMedium(
685
666
            to_server, from_server, transport)
723
704
        server = medium.SmartServerPipeStreamMedium(to_server, from_server, None)
724
705
        server._serve_one_request(SampleRequest('x'))
725
706
        self.assertTrue(server.finished)
726
 
        
 
707
 
727
708
    def test_socket_stream_shutdown_detection(self):
728
709
        server_sock, client_sock = self.portable_socket_pair()
729
710
        client_sock.close()
731
712
            server_sock, None)
732
713
        server._serve_one_request(SampleRequest('x'))
733
714
        self.assertTrue(server.finished)
734
 
        
 
715
 
735
716
    def test_socket_stream_incomplete_request(self):
736
717
        """The medium should still construct the right protocol version even if
737
718
        the initial read only reads part of the request.
753
734
        client_sock.sendall(rest_of_request_bytes)
754
735
        server._serve_one_request(server_protocol)
755
736
        server_sock.close()
756
 
        self.assertEqual(expected_response, client_sock.recv(50),
 
737
        self.assertEqual(expected_response, osutils.recv_all(client_sock, 50),
757
738
                         "Not a version 2 response to 'hello' request.")
758
739
        self.assertEqual('', client_sock.recv(1))
759
740
 
797
778
    def test_pipe_like_stream_with_two_requests(self):
798
779
        # If two requests are read in one go, then two calls to
799
780
        # _serve_one_request should still process both of them as if they had
800
 
        # been received seperately.
 
781
        # been received separately.
801
782
        sample_request_bytes = 'command\n'
802
783
        to_server = StringIO(sample_request_bytes * 2)
803
784
        from_server = StringIO()
815
796
        self.assertEqual('', from_server.getvalue())
816
797
        self.assertEqual(sample_request_bytes, second_protocol.accepted_bytes)
817
798
        self.assertFalse(server.finished)
818
 
        
 
799
 
819
800
    def test_socket_stream_with_two_requests(self):
820
801
        # If two requests are read in one go, then two calls to
821
802
        # _serve_one_request should still process both of them as if they had
822
 
        # been received seperately.
 
803
        # been received separately.
823
804
        sample_request_bytes = 'command\n'
824
805
        server_sock, client_sock = self.portable_socket_pair()
825
806
        server = medium.SmartServerSocketStreamMedium(
856
837
        self.assertEqual('', from_server.getvalue())
857
838
        self.assertTrue(self.closed)
858
839
        self.assertTrue(server.finished)
859
 
        
 
840
 
860
841
    def test_socket_stream_error_handling(self):
861
842
        server_sock, client_sock = self.portable_socket_pair()
862
843
        server = medium.SmartServerSocketStreamMedium(
867
848
        # closed.
868
849
        self.assertEqual('', client_sock.recv(1))
869
850
        self.assertTrue(server.finished)
870
 
        
 
851
 
871
852
    def test_pipe_like_stream_keyboard_interrupt_handling(self):
872
853
        to_server = StringIO('')
873
854
        from_server = StringIO()
918
899
        # Any empty request (i.e. no bytes) is detected as protocol version one.
919
900
        server_protocol = self.build_protocol_pipe_like('')
920
901
        self.assertProtocolOne(server_protocol)
921
 
        
 
902
 
922
903
    def test_socket_like_build_protocol_empty_bytes(self):
923
904
        # Any empty request (i.e. no bytes) is detected as protocol version one.
924
905
        server_protocol = self.build_protocol_socket('')
959
940
        self.assertEqual(
960
941
            protocol.build_server_protocol_three, protocol_factory)
961
942
        self.assertEqual('extra bytes', remainder)
962
 
        
 
943
 
963
944
    def test_version_two(self):
964
945
        result = medium._get_protocol_factory_for_bytes(
965
946
            'bzr request 2\nextra bytes')
967
948
        self.assertEqual(
968
949
            protocol.SmartServerRequestProtocolTwo, protocol_factory)
969
950
        self.assertEqual('extra bytes', remainder)
970
 
        
 
951
 
971
952
    def test_version_one(self):
972
953
        """Version one requests have no version markers."""
973
954
        result = medium._get_protocol_factory_for_bytes('anything\n')
975
956
        self.assertEqual(
976
957
            protocol.SmartServerRequestProtocolOne, protocol_factory)
977
958
        self.assertEqual('anything\n', remainder)
978
 
        
 
959
 
979
960
 
980
961
class TestSmartTCPServer(tests.TestCase):
981
962
 
992
973
        smart_server.start_background_thread('-' + self.id())
993
974
        try:
994
975
            transport = remote.RemoteTCPTransport(smart_server.get_url())
995
 
            try:
996
 
                transport.get('something')
997
 
            except errors.TransportError, e:
998
 
                self.assertContainsRe(str(e), 'some random exception')
999
 
            else:
1000
 
                self.fail("get did not raise expected error")
 
976
            err = self.assertRaises(errors.UnknownErrorFromSmartServer,
 
977
                transport.get, 'something')
 
978
            self.assertContainsRe(str(err), 'some random exception')
1001
979
            transport.disconnect()
1002
980
        finally:
1003
981
            smart_server.stop_background_thread()
1009
987
    All of these tests are run with a server running on another thread serving
1010
988
    a MemoryTransport, and a connection to it already open.
1011
989
 
1012
 
    the server is obtained by calling self.setUpServer(readonly=False).
 
990
    the server is obtained by calling self.start_server(readonly=False).
1013
991
    """
1014
992
 
1015
 
    def setUpServer(self, readonly=False, backing_transport=None):
 
993
    def start_server(self, readonly=False, backing_transport=None):
1016
994
        """Setup the server.
1017
995
 
1018
996
        :param readonly: Create a readonly server.
1019
997
        """
 
998
        # NB: Tests using this fall into two categories: tests of the server,
 
999
        # tests wanting a server. The latter should be updated to use
 
1000
        # self.vfs_transport_factory etc.
1020
1001
        if not backing_transport:
1021
 
            self.backing_transport = memory.MemoryTransport()
 
1002
            mem_server = memory.MemoryServer()
 
1003
            mem_server.start_server()
 
1004
            self.addCleanup(mem_server.stop_server)
 
1005
            self.permit_url(mem_server.get_url())
 
1006
            self.backing_transport = transport.get_transport(
 
1007
                mem_server.get_url())
1022
1008
        else:
1023
1009
            self.backing_transport = backing_transport
1024
1010
        if readonly:
1025
1011
            self.real_backing_transport = self.backing_transport
1026
 
            self.backing_transport = get_transport("readonly+" + self.backing_transport.abspath('.'))
 
1012
            self.backing_transport = transport.get_transport(
 
1013
                "readonly+" + self.backing_transport.abspath('.'))
1027
1014
        self.server = server.SmartTCPServer(self.backing_transport)
1028
1015
        self.server.start_background_thread('-' + self.id())
1029
1016
        self.transport = remote.RemoteTCPTransport(self.server.get_url())
1030
1017
        self.addCleanup(self.tearDownServer)
 
1018
        self.permit_url(self.server.get_url())
1031
1019
 
1032
1020
    def tearDownServer(self):
1033
1021
        if getattr(self, 'transport', None):
1035
1023
            del self.transport
1036
1024
        if getattr(self, 'server', None):
1037
1025
            self.server.stop_background_thread()
 
1026
            # XXX: why not .stop_server() -- mbp 20100106
1038
1027
            del self.server
1039
1028
 
1040
1029
 
1042
1031
 
1043
1032
    def test_server_setup_teardown(self):
1044
1033
        """It should be safe to teardown the server with no requests."""
1045
 
        self.setUpServer()
 
1034
        self.start_server()
1046
1035
        server = self.server
1047
1036
        transport = remote.RemoteTCPTransport(self.server.get_url())
1048
1037
        self.tearDownServer()
1050
1039
 
1051
1040
    def test_server_closes_listening_sock_on_shutdown_after_request(self):
1052
1041
        """The server should close its listening socket when it's stopped."""
1053
 
        self.setUpServer()
 
1042
        self.start_server()
1054
1043
        server = self.server
1055
1044
        self.transport.has('.')
1056
1045
        self.tearDownServer()
1065
1054
 
1066
1055
    def setUp(self):
1067
1056
        super(WritableEndToEndTests, self).setUp()
1068
 
        self.setUpServer()
 
1057
        self.start_server()
1069
1058
 
1070
1059
    def test_start_tcp_server(self):
1071
1060
        url = self.server.get_url()
1091
1080
        # asked for by the client. This gives meaningful and unsurprising errors
1092
1081
        # for users.
1093
1082
        self._captureVar('BZR_NO_SMART_VFS', None)
1094
 
        try:
1095
 
            self.transport.get('not%20a%20file')
1096
 
        except errors.NoSuchFile, e:
1097
 
            self.assertEqual('not%20a%20file', e.path)
1098
 
        else:
1099
 
            self.fail("get did not raise expected error")
 
1083
        err = self.assertRaises(
 
1084
            errors.NoSuchFile, self.transport.get, 'not%20a%20file')
 
1085
        self.assertSubset([err.path], ['not%20a%20file', './not%20a%20file'])
1100
1086
 
1101
1087
    def test_simple_clone_conn(self):
1102
1088
        """Test that cloning reuses the same connection."""
1147
1133
    def test_mkdir_error_readonly(self):
1148
1134
        """TransportNotPossible should be preserved from the backing transport."""
1149
1135
        self._captureVar('BZR_NO_SMART_VFS', None)
1150
 
        self.setUpServer(readonly=True)
 
1136
        self.start_server(readonly=True)
1151
1137
        self.assertRaises(errors.TransportNotPossible, self.transport.mkdir,
1152
1138
            'foo')
1153
1139
 
1163
1149
        self.hook_calls = []
1164
1150
        server.SmartTCPServer.hooks.install_named_hook('server_started',
1165
1151
            self.capture_server_call, None)
1166
 
        self.setUpServer()
 
1152
        self.start_server()
1167
1153
        # at this point, the server will be starting a thread up.
1168
1154
        # there is no indicator at the moment, so bodge it by doing a request.
1169
1155
        self.transport.has('.')
1177
1163
        self.hook_calls = []
1178
1164
        server.SmartTCPServer.hooks.install_named_hook('server_started',
1179
1165
            self.capture_server_call, None)
1180
 
        self.setUpServer(backing_transport=get_transport("."))
 
1166
        self.start_server(backing_transport=transport.get_transport("."))
1181
1167
        # at this point, the server will be starting a thread up.
1182
1168
        # there is no indicator at the moment, so bodge it by doing a request.
1183
1169
        self.transport.has('.')
1193
1179
        self.hook_calls = []
1194
1180
        server.SmartTCPServer.hooks.install_named_hook('server_stopped',
1195
1181
            self.capture_server_call, None)
1196
 
        self.setUpServer()
 
1182
        self.start_server()
1197
1183
        result = [([self.backing_transport.base], self.transport.base)]
1198
1184
        # check the stopping message isn't emitted up front.
1199
1185
        self.assertEqual([], self.hook_calls)
1210
1196
        self.hook_calls = []
1211
1197
        server.SmartTCPServer.hooks.install_named_hook('server_stopped',
1212
1198
            self.capture_server_call, None)
1213
 
        self.setUpServer(backing_transport=get_transport("."))
 
1199
        self.start_server(backing_transport=transport.get_transport("."))
1214
1200
        result = [(
1215
1201
            [self.backing_transport.base, self.backing_transport.external_url()]
1216
1202
            , self.transport.base)]
1235
1221
    Note: these tests are rudimentary versions of the command object tests in
1236
1222
    test_smart.py.
1237
1223
    """
1238
 
        
 
1224
 
1239
1225
    def test_hello(self):
1240
1226
        cmd = _mod_request.HelloRequest(None, '/')
1241
1227
        response = cmd.execute()
1242
1228
        self.assertEqual(('ok', '2'), response.args)
1243
1229
        self.assertEqual(None, response.body)
1244
 
        
 
1230
 
1245
1231
    def test_get_bundle(self):
1246
1232
        from bzrlib.bundle import serializer
1247
1233
        wt = self.make_branch_and_tree('.')
1248
1234
        self.build_tree_contents([('hello', 'hello world')])
1249
1235
        wt.add('hello')
1250
1236
        rev_id = wt.commit('add hello')
1251
 
        
 
1237
 
1252
1238
        cmd = _mod_request.GetBundleRequest(self.get_transport(), '/')
1253
1239
        response = cmd.execute('.', rev_id)
1254
1240
        bundle = serializer.read_bundle(StringIO(response.body))
1275
1261
 
1276
1262
    def test_hello(self):
1277
1263
        handler = self.build_handler(None)
1278
 
        handler.dispatch_command('hello', ())
 
1264
        handler.args_received(('hello',))
1279
1265
        self.assertEqual(('ok', '2'), handler.response.args)
1280
1266
        self.assertEqual(None, handler.response.body)
1281
 
        
 
1267
 
1282
1268
    def test_disable_vfs_handler_classes_via_environment(self):
1283
1269
        # VFS handler classes will raise an error from "execute" if
1284
1270
        # BZR_NO_SMART_VFS is set.
1295
1281
        """The response for a read-only error is ('ReadOnlyError')."""
1296
1282
        handler = self.build_handler(self.get_readonly_transport())
1297
1283
        # send a mkdir for foo, with no explicit mode - should fail.
1298
 
        handler.dispatch_command('mkdir', ('foo', ''))
 
1284
        handler.args_received(('mkdir', 'foo', ''))
1299
1285
        # and the failure should be an explicit ReadOnlyError
1300
1286
        self.assertEqual(("ReadOnlyError", ), handler.response.args)
1301
1287
        # XXX: TODO: test that other TransportNotPossible errors are
1306
1292
    def test_hello_has_finished_body_on_dispatch(self):
1307
1293
        """The 'hello' command should set finished_reading."""
1308
1294
        handler = self.build_handler(None)
1309
 
        handler.dispatch_command('hello', ())
 
1295
        handler.args_received(('hello',))
1310
1296
        self.assertTrue(handler.finished_reading)
1311
1297
        self.assertNotEqual(None, handler.response)
1312
1298
 
1313
1299
    def test_put_bytes_non_atomic(self):
1314
1300
        """'put_...' should set finished_reading after reading the bytes."""
1315
1301
        handler = self.build_handler(self.get_transport())
1316
 
        handler.dispatch_command('put_non_atomic', ('a-file', '', 'F', ''))
 
1302
        handler.args_received(('put_non_atomic', 'a-file', '', 'F', ''))
1317
1303
        self.assertFalse(handler.finished_reading)
1318
1304
        handler.accept_body('1234')
1319
1305
        self.assertFalse(handler.finished_reading)
1322
1308
        self.assertTrue(handler.finished_reading)
1323
1309
        self.assertEqual(('ok', ), handler.response.args)
1324
1310
        self.assertEqual(None, handler.response.body)
1325
 
        
 
1311
 
1326
1312
    def test_readv_accept_body(self):
1327
1313
        """'readv' should set finished_reading after reading offsets."""
1328
1314
        self.build_tree(['a-file'])
1329
1315
        handler = self.build_handler(self.get_readonly_transport())
1330
 
        handler.dispatch_command('readv', ('a-file', ))
 
1316
        handler.args_received(('readv', 'a-file'))
1331
1317
        self.assertFalse(handler.finished_reading)
1332
1318
        handler.accept_body('2,')
1333
1319
        self.assertFalse(handler.finished_reading)
1342
1328
        """'readv' when a short read occurs sets the response appropriately."""
1343
1329
        self.build_tree(['a-file'])
1344
1330
        handler = self.build_handler(self.get_readonly_transport())
1345
 
        handler.dispatch_command('readv', ('a-file', ))
 
1331
        handler.args_received(('readv', 'a-file'))
1346
1332
        # read beyond the end of the file.
1347
1333
        handler.accept_body('100,1')
1348
1334
        handler.end_of_body()
1355
1341
class RemoteTransportRegistration(tests.TestCase):
1356
1342
 
1357
1343
    def test_registration(self):
1358
 
        t = get_transport('bzr+ssh://example.com/path')
 
1344
        t = transport.get_transport('bzr+ssh://example.com/path')
1359
1345
        self.assertIsInstance(t, remote.RemoteSSHTransport)
1360
1346
        self.assertEqual('example.com', t._host)
1361
1347
 
1362
1348
    def test_bzr_https(self):
1363
1349
        # https://bugs.launchpad.net/bzr/+bug/128456
1364
 
        t = get_transport('bzr+https://example.com/path')
 
1350
        t = transport.get_transport('bzr+https://example.com/path')
1365
1351
        self.assertIsInstance(t, remote.RemoteHTTPTransport)
1366
1352
        self.assertStartsWith(
1367
1353
            t._http_transport.base,
1369
1355
 
1370
1356
 
1371
1357
class TestRemoteTransport(tests.TestCase):
1372
 
        
 
1358
 
1373
1359
    def test_use_connection_factory(self):
1374
1360
        # We want to be able to pass a client as a parameter to RemoteTransport.
1375
1361
        input = StringIO('ok\n3\nbardone\n')
1400
1386
        client_medium = medium.SmartSimplePipesClientMedium(None, None, 'base')
1401
1387
        transport = remote.RemoteTransport(
1402
1388
            'bzr://localhost/', medium=client_medium)
 
1389
        err = errors.ErrorFromSmartServer(("ReadOnlyError", ))
1403
1390
        self.assertRaises(errors.TransportNotPossible,
1404
 
            transport._translate_error, ("ReadOnlyError", ))
 
1391
            transport._translate_error, err)
1405
1392
 
1406
1393
 
1407
1394
class TestSmartProtocol(tests.TestCase):
1473
1460
    def assertOffsetSerialisation(self, expected_offsets, expected_serialised,
1474
1461
        requester):
1475
1462
        """Check that smart (de)serialises offsets as expected.
1476
 
        
 
1463
 
1477
1464
        We check both serialisation and deserialisation at the same time
1478
1465
        to ensure that the round tripping cannot skew: both directions should
1479
1466
        be as expected.
1480
 
        
 
1467
 
1481
1468
        :param expected_offsets: a readv offset list.
1482
1469
        :param expected_seralised: an expected serial form of the offsets.
1483
1470
        """
1494
1481
        smart_protocol._has_dispatched = True
1495
1482
        smart_protocol.request = _mod_request.SmartServerRequestHandler(
1496
1483
            None, _mod_request.request_handlers, '/')
1497
 
        class FakeCommand(object):
1498
 
            def do_body(cmd, body_bytes):
 
1484
        class FakeCommand(_mod_request.SmartServerRequest):
 
1485
            def do_body(self_cmd, body_bytes):
1499
1486
                self.end_received = True
1500
1487
                self.assertEqual('abcdefg', body_bytes)
1501
1488
                return _mod_request.SuccessfulSmartServerResponse(('ok', ))
1502
 
        smart_protocol.request._command = FakeCommand()
 
1489
        smart_protocol.request._command = FakeCommand(None)
1503
1490
        # Call accept_bytes to make sure that internal state like _body_decoder
1504
1491
        # is initialised.  This test should probably be given a clearer
1505
1492
        # interface to work with that will not cause this inconsistency.
1533
1520
        ex = self.assertRaises(errors.ConnectionReset,
1534
1521
            response_handler.read_response_tuple)
1535
1522
        self.assertEqual("Connection closed: "
1536
 
            "please check connectivity and permissions "
1537
 
            "(and try -Dhpss if further diagnosis is required)", str(ex))
 
1523
            "Unexpected end of message. Please check connectivity "
 
1524
            "and permissions, and report a bug if problems persist. ",
 
1525
            str(ex))
1538
1526
 
1539
1527
    def test_server_offset_serialisation(self):
1540
1528
        """The Smart protocol serialises offsets as a comma and \n string.
1659
1647
 
1660
1648
    def test_query_version(self):
1661
1649
        """query_version on a SmartClientProtocolOne should return a number.
1662
 
        
 
1650
 
1663
1651
        The protocol provides the query_version because the domain level clients
1664
1652
        may all need to be able to probe for capabilities.
1665
1653
        """
1666
1654
        # What we really want to test here is that SmartClientProtocolOne calls
1667
1655
        # accept_bytes(tuple_based_encoding_of_hello) and reads and parses the
1668
 
        # response of tuple-encoded (ok, 1).  Also, seperately we should test
 
1656
        # response of tuple-encoded (ok, 1).  Also, separately we should test
1669
1657
        # the error if the response is a non-understood version.
1670
1658
        input = StringIO('ok\x012\n')
1671
1659
        output = StringIO()
1930
1918
 
1931
1919
    def test_query_version(self):
1932
1920
        """query_version on a SmartClientProtocolTwo should return a number.
1933
 
        
 
1921
 
1934
1922
        The protocol provides the query_version because the domain level clients
1935
1923
        may all need to be able to probe for capabilities.
1936
1924
        """
1937
1925
        # What we really want to test here is that SmartClientProtocolTwo calls
1938
1926
        # accept_bytes(tuple_based_encoding_of_hello) and reads and parses the
1939
 
        # response of tuple-encoded (ok, 1).  Also, seperately we should test
 
1927
        # response of tuple-encoded (ok, 1).  Also, separately we should test
1940
1928
        # the error if the response is a non-understood version.
1941
1929
        input = StringIO(self.response_marker + 'success\nok\x012\n')
1942
1930
        output = StringIO()
2275
2263
        self.assertEqual(4, smart_protocol.next_read_size())
2276
2264
 
2277
2265
 
2278
 
class NoOpRequest(_mod_request.SmartServerRequest):
2279
 
 
2280
 
    def do(self):
2281
 
        return _mod_request.SuccessfulSmartServerResponse(())
2282
 
 
2283
 
dummy_registry = {'ARG': NoOpRequest}
2284
 
 
2285
 
 
2286
2266
class LoggingMessageHandler(object):
2287
2267
 
2288
2268
    def __init__(self):
2330
2310
        self.assertEqual(0, smart_protocol.next_read_size())
2331
2311
        self.assertEqual('', smart_protocol.unused_data)
2332
2312
 
 
2313
    def test_repeated_excess(self):
 
2314
        """Repeated calls to accept_bytes after the message end has been parsed
 
2315
        accumlates the bytes in the unused_data attribute.
 
2316
        """
 
2317
        output = StringIO()
 
2318
        headers = '\0\0\0\x02de'  # length-prefixed, bencoded empty dict
 
2319
        end = 'e'
 
2320
        request_bytes = headers + end
 
2321
        smart_protocol = self.server_protocol_class(LoggingMessageHandler())
 
2322
        smart_protocol.accept_bytes(request_bytes)
 
2323
        self.assertEqual('', smart_protocol.unused_data)
 
2324
        smart_protocol.accept_bytes('aaa')
 
2325
        self.assertEqual('aaa', smart_protocol.unused_data)
 
2326
        smart_protocol.accept_bytes('bbb')
 
2327
        self.assertEqual('aaabbb', smart_protocol.unused_data)
 
2328
        self.assertEqual(0, smart_protocol.next_read_size())
 
2329
 
2333
2330
    def make_protocol_expecting_message_part(self):
2334
2331
        headers = '\0\0\0\x02de'  # length-prefixed, bencoded empty dict
2335
2332
        message_handler = LoggingMessageHandler()
2363
2360
            '\0\0\0\x07' # length prefix
2364
2361
            'l3:ARGe' # ['ARG']
2365
2362
            )
2366
 
        self.assertEqual([('structure', ['ARG'])], event_log)
 
2363
        self.assertEqual([('structure', ('ARG',))], event_log)
2367
2364
 
2368
2365
    def test_decode_multiple_bytes(self):
2369
2366
        """The protocol can decode a multiple 'bytes' message parts."""
2380
2377
            [('bytes', 'first'), ('bytes', 'second')], event_log)
2381
2378
 
2382
2379
 
2383
 
class TestConventionalResponseHandler(tests.TestCase):
 
2380
class TestConventionalResponseHandlerBodyStream(tests.TestCase):
2384
2381
 
2385
2382
    def make_response_handler(self, response_bytes):
2386
2383
        from bzrlib.smart.message import ConventionalResponseHandler
2397
2394
            protocol_decoder, medium_request)
2398
2395
        return response_handler
2399
2396
 
2400
 
    def test_body_stream_interrupted_by_error(self):
2401
 
        interrupted_body_stream = (
2402
 
            'oS' # successful response
2403
 
            's\0\0\0\x02le' # empty args
2404
 
            'b\0\0\0\x09chunk one' # first chunk
2405
 
            'b\0\0\0\x09chunk two' # second chunk
2406
 
            'oE' # error flag
2407
 
            's\0\0\0\x0el5:error3:abce' # bencoded error
2408
 
            'e' # message end
2409
 
            )
 
2397
    def test_interrupted_by_error(self):
2410
2398
        response_handler = self.make_response_handler(interrupted_body_stream)
2411
2399
        stream = response_handler.read_streamed_body()
2412
 
        self.assertEqual('chunk one', stream.next())
2413
 
        self.assertEqual('chunk two', stream.next())
 
2400
        self.assertEqual('aaa', stream.next())
 
2401
        self.assertEqual('bbb', stream.next())
2414
2402
        exc = self.assertRaises(errors.ErrorFromSmartServer, stream.next)
2415
 
        self.assertEqual(('error', 'abc'), exc.error_tuple)
 
2403
        self.assertEqual(('error', 'Boom!'), exc.error_tuple)
2416
2404
 
2417
 
    def test_body_stream_interrupted_by_connection_lost(self):
 
2405
    def test_interrupted_by_connection_lost(self):
2418
2406
        interrupted_body_stream = (
2419
2407
            'oS' # successful response
2420
2408
            's\0\0\0\x02le' # empty args
2432
2420
        self.assertRaises(
2433
2421
            errors.ConnectionReset, response_handler.read_body_bytes)
2434
2422
 
 
2423
    def test_multiple_bytes_parts(self):
 
2424
        multiple_bytes_parts = (
 
2425
            'oS' # successful response
 
2426
            's\0\0\0\x02le' # empty args
 
2427
            'b\0\0\0\x0bSome bytes\n' # some bytes
 
2428
            'b\0\0\0\x0aMore bytes' # more bytes
 
2429
            'e' # message end
 
2430
            )
 
2431
        response_handler = self.make_response_handler(multiple_bytes_parts)
 
2432
        self.assertEqual(
 
2433
            'Some bytes\nMore bytes', response_handler.read_body_bytes())
 
2434
        response_handler = self.make_response_handler(multiple_bytes_parts)
 
2435
        self.assertEqual(
 
2436
            ['Some bytes\n', 'More bytes'],
 
2437
            list(response_handler.read_streamed_body()))
 
2438
 
 
2439
 
 
2440
class FakeResponder(object):
 
2441
 
 
2442
    response_sent = False
 
2443
 
 
2444
    def send_error(self, exc):
 
2445
        raise exc
 
2446
 
 
2447
    def send_response(self, response):
 
2448
        pass
 
2449
 
 
2450
 
 
2451
class TestConventionalRequestHandlerBodyStream(tests.TestCase):
 
2452
    """Tests for ConventionalRequestHandler's handling of request bodies."""
 
2453
 
 
2454
    def make_request_handler(self, request_bytes):
 
2455
        """Make a ConventionalRequestHandler for the given bytes using test
 
2456
        doubles for the request_handler and the responder.
 
2457
        """
 
2458
        from bzrlib.smart.message import ConventionalRequestHandler
 
2459
        request_handler = InstrumentedRequestHandler()
 
2460
        request_handler.response = _mod_request.SuccessfulSmartServerResponse(('arg', 'arg'))
 
2461
        responder = FakeResponder()
 
2462
        message_handler = ConventionalRequestHandler(request_handler, responder)
 
2463
        protocol_decoder = protocol.ProtocolThreeDecoder(message_handler)
 
2464
        # put decoder in desired state (waiting for message parts)
 
2465
        protocol_decoder.state_accept = protocol_decoder._state_accept_expecting_message_part
 
2466
        protocol_decoder.accept_bytes(request_bytes)
 
2467
        return request_handler
 
2468
 
 
2469
    def test_multiple_bytes_parts(self):
 
2470
        """Each bytes part triggers a call to the request_handler's
 
2471
        accept_body method.
 
2472
        """
 
2473
        multiple_bytes_parts = (
 
2474
            's\0\0\0\x07l3:fooe' # args
 
2475
            'b\0\0\0\x0bSome bytes\n' # some bytes
 
2476
            'b\0\0\0\x0aMore bytes' # more bytes
 
2477
            'e' # message end
 
2478
            )
 
2479
        request_handler = self.make_request_handler(multiple_bytes_parts)
 
2480
        accept_body_calls = [
 
2481
            call_info[1] for call_info in request_handler.calls
 
2482
            if call_info[0] == 'accept_body']
 
2483
        self.assertEqual(
 
2484
            ['Some bytes\n', 'More bytes'], accept_body_calls)
 
2485
 
 
2486
    def test_error_flag_after_body(self):
 
2487
        body_then_error = (
 
2488
            's\0\0\0\x07l3:fooe' # request args
 
2489
            'b\0\0\0\x0bSome bytes\n' # some bytes
 
2490
            'b\0\0\0\x0aMore bytes' # more bytes
 
2491
            'oE' # error flag
 
2492
            's\0\0\0\x07l3:bare' # error args
 
2493
            'e' # message end
 
2494
            )
 
2495
        request_handler = self.make_request_handler(body_then_error)
 
2496
        self.assertEqual(
 
2497
            [('post_body_error_received', ('bar',)), ('end_received',)],
 
2498
            request_handler.calls[-2:])
 
2499
 
2435
2500
 
2436
2501
class TestMessageHandlerErrors(tests.TestCase):
2437
2502
    """Tests for v3 that unrecognised (but well-formed) requests/responses are
2481
2546
 
2482
2547
    def __init__(self):
2483
2548
        self.calls = []
2484
 
 
2485
 
    def body_chunk_received(self, chunk_bytes):
2486
 
        self.calls.append(('body_chunk_received', chunk_bytes))
 
2549
        self.finished_reading = False
2487
2550
 
2488
2551
    def no_body_received(self):
2489
2552
        self.calls.append(('no_body_received',))
2490
2553
 
2491
 
    def prefixed_body_received(self, body_bytes):
2492
 
        self.calls.append(('prefixed_body_received', body_bytes))
2493
 
 
2494
2554
    def end_received(self):
2495
2555
        self.calls.append(('end_received',))
 
2556
        self.finished_reading = True
 
2557
 
 
2558
    def args_received(self, args):
 
2559
        self.calls.append(('args_received', args))
 
2560
 
 
2561
    def accept_body(self, bytes):
 
2562
        self.calls.append(('accept_body', bytes))
 
2563
 
 
2564
    def end_of_body(self):
 
2565
        self.calls.append(('end_of_body',))
 
2566
        self.finished_reading = True
 
2567
 
 
2568
    def post_body_error_received(self, error_args):
 
2569
        self.calls.append(('post_body_error_received', error_args))
2496
2570
 
2497
2571
 
2498
2572
class StubRequest(object):
2534
2608
        # The message handler has been invoked with all the parts of the
2535
2609
        # trivial response: empty headers, status byte, no args, end.
2536
2610
        self.assertEqual(
2537
 
            [('headers', {}), ('byte', 'S'), ('structure', []), ('end',)],
 
2611
            [('headers', {}), ('byte', 'S'), ('structure', ()), ('end',)],
2538
2612
            response_handler.event_log)
2539
2613
 
2540
2614
    def test_incomplete_message(self):
2648
2722
        self.assertEqual(
2649
2723
            ['accept_bytes', 'finished_writing'], medium_request.calls)
2650
2724
 
 
2725
    def test_call_with_body_stream_smoke_test(self):
 
2726
        """A smoke test for ProtocolThreeRequester.call_with_body_stream.
 
2727
 
 
2728
        This test checks that a particular simple invocation of
 
2729
        call_with_body_stream emits the correct bytes for that invocation.
 
2730
        """
 
2731
        requester, output = self.make_client_encoder_and_output()
 
2732
        requester.set_headers({'header name': 'header value'})
 
2733
        stream = ['chunk 1', 'chunk two']
 
2734
        requester.call_with_body_stream(('one arg',), stream)
 
2735
        self.assertEquals(
 
2736
            'bzr message 3 (bzr 1.6)\n' # protocol version
 
2737
            '\x00\x00\x00\x1fd11:header name12:header valuee' # headers
 
2738
            's\x00\x00\x00\x0bl7:one arge' # args
 
2739
            'b\x00\x00\x00\x07chunk 1' # a prefixed body chunk
 
2740
            'b\x00\x00\x00\x09chunk two' # a prefixed body chunk
 
2741
            'e', # end
 
2742
            output.getvalue())
 
2743
 
 
2744
    def test_call_with_body_stream_empty_stream(self):
 
2745
        """call_with_body_stream with an empty stream."""
 
2746
        requester, output = self.make_client_encoder_and_output()
 
2747
        requester.set_headers({})
 
2748
        stream = []
 
2749
        requester.call_with_body_stream(('one arg',), stream)
 
2750
        self.assertEquals(
 
2751
            'bzr message 3 (bzr 1.6)\n' # protocol version
 
2752
            '\x00\x00\x00\x02de' # headers
 
2753
            's\x00\x00\x00\x0bl7:one arge' # args
 
2754
            # no body chunks
 
2755
            'e', # end
 
2756
            output.getvalue())
 
2757
 
 
2758
    def test_call_with_body_stream_error(self):
 
2759
        """call_with_body_stream will abort the streamed body with an
 
2760
        error if the stream raises an error during iteration.
 
2761
 
 
2762
        The resulting request will still be a complete message.
 
2763
        """
 
2764
        requester, output = self.make_client_encoder_and_output()
 
2765
        requester.set_headers({})
 
2766
        def stream_that_fails():
 
2767
            yield 'aaa'
 
2768
            yield 'bbb'
 
2769
            raise Exception('Boom!')
 
2770
        self.assertRaises(Exception, requester.call_with_body_stream,
 
2771
            ('one arg',), stream_that_fails())
 
2772
        self.assertEquals(
 
2773
            'bzr message 3 (bzr 1.6)\n' # protocol version
 
2774
            '\x00\x00\x00\x02de' # headers
 
2775
            's\x00\x00\x00\x0bl7:one arge' # args
 
2776
            'b\x00\x00\x00\x03aaa' # body
 
2777
            'b\x00\x00\x00\x03bbb' # more body
 
2778
            'oE' # error flag
 
2779
            's\x00\x00\x00\x09l5:errore' # error args: ('error',)
 
2780
            'e', # end
 
2781
            output.getvalue())
 
2782
 
2651
2783
 
2652
2784
class StubMediumRequest(object):
2653
2785
    """A stub medium request that tracks the number of times accept_bytes is
2665
2797
        self.calls.append('finished_writing')
2666
2798
 
2667
2799
 
 
2800
interrupted_body_stream = (
 
2801
    'oS' # status flag (success)
 
2802
    's\x00\x00\x00\x08l4:argse' # args struct ('args,')
 
2803
    'b\x00\x00\x00\x03aaa' # body part ('aaa')
 
2804
    'b\x00\x00\x00\x03bbb' # body part ('bbb')
 
2805
    'oE' # status flag (error)
 
2806
    's\x00\x00\x00\x10l5:error5:Boom!e' # err struct ('error', 'Boom!')
 
2807
    'e' # EOM
 
2808
    )
 
2809
 
 
2810
 
2668
2811
class TestResponseEncodingProtocolThree(tests.TestCase):
2669
2812
 
2670
2813
    def make_response_encoder(self):
2686
2829
            # end of message
2687
2830
            'e')
2688
2831
 
 
2832
    def test_send_broken_body_stream(self):
 
2833
        encoder, out_stream = self.make_response_encoder()
 
2834
        encoder._headers = {}
 
2835
        def stream_that_fails():
 
2836
            yield 'aaa'
 
2837
            yield 'bbb'
 
2838
            raise Exception('Boom!')
 
2839
        response = _mod_request.SuccessfulSmartServerResponse(
 
2840
            ('args',), body_stream=stream_that_fails())
 
2841
        encoder.send_response(response)
 
2842
        expected_response = (
 
2843
            'bzr message 3 (bzr 1.6)\n'  # protocol marker
 
2844
            '\x00\x00\x00\x02de' # headers dict (empty)
 
2845
            + interrupted_body_stream)
 
2846
        self.assertEqual(expected_response, out_stream.getvalue())
 
2847
 
2689
2848
 
2690
2849
class TestResponseEncoderBufferingProtocolThree(tests.TestCase):
2691
2850
    """Tests for buffering of responses.
2695
2854
    """
2696
2855
 
2697
2856
    def setUp(self):
 
2857
        tests.TestCase.setUp(self)
2698
2858
        self.writes = []
2699
2859
        self.responder = protocol.ProtocolThreeResponder(self.writes.append)
2700
2860
 
2702
2862
        self.assertEqual(
2703
2863
            expected_count, len(self.writes),
2704
2864
            "Too many writes: %r" % (self.writes,))
2705
 
        
 
2865
 
2706
2866
    def test_send_error_writes_just_once(self):
2707
2867
        """An error response is written to the medium all at once."""
2708
2868
        self.responder.send_error(Exception('An exception string.'))
2724
2884
        self.responder.send_response(response)
2725
2885
        self.assertWriteCount(1)
2726
2886
 
2727
 
    def test_send_response_with_body_stream_writes_once_per_chunk(self):
2728
 
        """A normal response with a stream body is written to the medium
2729
 
        writes to the medium once per chunk.
2730
 
        """
 
2887
    def test_send_response_with_body_stream_buffers_writes(self):
 
2888
        """A normal response with a stream body writes to the medium once."""
2731
2889
        # Construct a response with stream with 2 chunks in it.
2732
2890
        response = _mod_request.SuccessfulSmartServerResponse(
2733
2891
            ('arg', 'arg'), body_stream=['chunk1', 'chunk2'])
2734
2892
        self.responder.send_response(response)
2735
 
        # We will write 3 times: exactly once for each chunk, plus a final
2736
 
        # write to end the response.
2737
 
        self.assertWriteCount(3)
 
2893
        # We will write just once, despite the multiple chunks, due to
 
2894
        # buffering.
 
2895
        self.assertWriteCount(1)
 
2896
 
 
2897
    def test_send_response_with_body_stream_flushes_buffers_sometimes(self):
 
2898
        """When there are many bytes (>1MB), multiple writes will occur rather
 
2899
        than buffering indefinitely.
 
2900
        """
 
2901
        # Construct a response with stream with ~1.5MB in it. This should
 
2902
        # trigger 2 writes, but not 3
 
2903
        onekib = '12345678' * 128
 
2904
        body_stream = [onekib] * (1024 + 512)
 
2905
        response = _mod_request.SuccessfulSmartServerResponse(
 
2906
            ('arg', 'arg'), body_stream=body_stream)
 
2907
        self.responder.send_response(response)
 
2908
        self.assertWriteCount(2)
2738
2909
 
2739
2910
 
2740
2911
class TestSmartClientUnicode(tests.TestCase):
2777
2948
 
2778
2949
class MockMedium(medium.SmartClientMedium):
2779
2950
    """A mock medium that can be used to test _SmartClient.
2780
 
    
 
2951
 
2781
2952
    It can be given a series of requests to expect (and responses it should
2782
2953
    return for them).  It can also be told when the client is expected to
2783
2954
    disconnect a medium.  Expectations must be satisfied in the order they are
2795
2966
        super(MockMedium, self).__init__('dummy base')
2796
2967
        self._mock_request = _MockMediumRequest(self)
2797
2968
        self._expected_events = []
2798
 
        
 
2969
 
2799
2970
    def expect_request(self, request_bytes, response_bytes,
2800
2971
                       allow_partial_read=False):
2801
2972
        """Expect 'request_bytes' to be sent, and reply with 'response_bytes'.
2804
2975
        called to send the request.  Similarly, no assumption is made about how
2805
2976
        many times read_bytes/read_line are called by protocol code to read a
2806
2977
        response.  e.g.::
2807
 
        
 
2978
 
2808
2979
            request.accept_bytes('ab')
2809
2980
            request.accept_bytes('cd')
2810
2981
            request.finished_writing()
2811
2982
 
2812
2983
        and::
2813
 
        
 
2984
 
2814
2985
            request.accept_bytes('abcd')
2815
2986
            request.finished_writing()
2816
2987
 
3001
3172
    def test_first_response_is_error(self):
3002
3173
        """If the server replies with an error, then the version detection
3003
3174
        should be complete.
3004
 
        
 
3175
 
3005
3176
        This test is very similar to test_version_two_server, but catches a bug
3006
3177
        we had in the case where the first reply was an error response.
3007
3178
        """
3047
3218
 
3048
3219
class LengthPrefixedBodyDecoder(tests.TestCase):
3049
3220
 
3050
 
    # XXX: TODO: make accept_reading_trailer invoke translate_response or 
 
3221
    # XXX: TODO: make accept_reading_trailer invoke translate_response or
3051
3222
    # something similar to the ProtocolBase method.
3052
3223
 
3053
3224
    def test_construct(self):
3089
3260
        self.assertEqual(1, decoder.next_read_size())
3090
3261
        self.assertEqual('', decoder.read_pending_data())
3091
3262
        self.assertEqual('blarg', decoder.unused_data)
3092
 
        
 
3263
 
3093
3264
    def test_accept_bytes_all_at_once_with_excess(self):
3094
3265
        decoder = protocol.LengthPrefixedBodyDecoder()
3095
3266
        decoder.accept_bytes('1\nadone\nunused')
3114
3285
 
3115
3286
class TestChunkedBodyDecoder(tests.TestCase):
3116
3287
    """Tests for ChunkedBodyDecoder.
3117
 
    
 
3288
 
3118
3289
    This is the body decoder used for protocol version two.
3119
3290
    """
3120
3291
 
3146
3317
        self.assertTrue(decoder.finished_reading)
3147
3318
        self.assertEqual(chunk_content, decoder.read_next_chunk())
3148
3319
        self.assertEqual('', decoder.unused_data)
3149
 
        
 
3320
 
3150
3321
    def test_incomplete_chunk(self):
3151
3322
        """When there are less bytes in the chunk than declared by the length,
3152
3323
        then we haven't finished reading yet.
3347
3518
 
3348
3519
    def test_smart_http_medium_request_accept_bytes(self):
3349
3520
        medium = FakeHTTPMedium()
3350
 
        request = SmartClientHTTPMediumRequest(medium)
 
3521
        request = http.SmartClientHTTPMediumRequest(medium)
3351
3522
        request.accept_bytes('abc')
3352
3523
        request.accept_bytes('def')
3353
3524
        self.assertEqual(None, medium.written_request)
3362
3533
        # requests for child URLs of that to the original URL.  i.e., we want to
3363
3534
        # POST to "bzr+http://host/foo/.bzr/smart" and never something like
3364
3535
        # "bzr+http://host/foo/.bzr/branch/.bzr/smart".  So, a cloned
3365
 
        # RemoteHTTPTransport remembers the initial URL, and adjusts the relpaths
3366
 
        # it sends in smart requests accordingly.
 
3536
        # RemoteHTTPTransport remembers the initial URL, and adjusts the
 
3537
        # relpaths it sends in smart requests accordingly.
3367
3538
        base_transport = remote.RemoteHTTPTransport('bzr+http://host/path')
3368
3539
        new_transport = base_transport.clone('child_dir')
3369
3540
        self.assertEqual(base_transport._http_transport,
3384
3555
        # still work correctly.
3385
3556
        base_transport = remote.RemoteHTTPTransport('bzr+http://host/%7Ea/b')
3386
3557
        new_transport = base_transport.clone('c')
3387
 
        self.assertEqual('bzr+http://host/%7Ea/b/c/', new_transport.base)
 
3558
        self.assertEqual('bzr+http://host/~a/b/c/', new_transport.base)
3388
3559
        self.assertEqual(
3389
3560
            'c/',
3390
3561
            new_transport._client.remote_path_from_transport(new_transport))
3391
3562
 
3392
 
        
3393
 
# TODO: Client feature that does get_bundle and then installs that into a
3394
 
# branch; this can be used in place of the regular pull/fetch operation when
3395
 
# coming from a smart server.
3396
 
#
3397
 
# TODO: Eventually, want to do a 'branch' command by fetching the whole
3398
 
# history as one big bundle.  How?  
3399
 
#
3400
 
# The branch command does 'br_from.sprout', which tries to preserve the same
3401
 
# format.  We don't necessarily even want that.  
3402
 
#
3403
 
# It might be simpler to handle cmd_pull first, which does a simpler fetch()
3404
 
# operation from one branch into another.  It already has some code for
3405
 
# pulling from a bundle, which it does by trying to see if the destination is
3406
 
# a bundle file.  So it seems the logic for pull ought to be:
3407
 
3408
 
#  - if it's a smart server, get a bundle from there and install that
3409
 
#  - if it's a bundle, install that
3410
 
#  - if it's a branch, pull from there
3411
 
#
3412
 
# Getting a bundle from a smart server is a bit different from reading a
3413
 
# bundle from a URL:
3414
 
#
3415
 
#  - we can reasonably remember the URL we last read from 
3416
 
#  - you can specify a revision number to pull, and we need to pass it across
3417
 
#    to the server as a limit on what will be requested
3418
 
#
3419
 
# TODO: Given a URL, determine whether it is a smart server or not (or perhaps
3420
 
# otherwise whether it's a bundle?)  Should this be a property or method of
3421
 
# the transport?  For the ssh protocol, we always know it's a smart server.
3422
 
# For http, we potentially need to probe.  But if we're explicitly given
3423
 
# bzr+http:// then we can skip that for now. 
 
3563
    def test__redirect_to(self):
 
3564
        t = remote.RemoteHTTPTransport('bzr+http://www.example.com/foo')
 
3565
        r = t._redirected_to('http://www.example.com/foo',
 
3566
                             'http://www.example.com/bar')
 
3567
        self.assertEquals(type(r), type(t))
 
3568
 
 
3569
    def test__redirect_sibling_protocol(self):
 
3570
        t = remote.RemoteHTTPTransport('bzr+http://www.example.com/foo')
 
3571
        r = t._redirected_to('http://www.example.com/foo',
 
3572
                             'https://www.example.com/bar')
 
3573
        self.assertEquals(type(r), type(t))
 
3574
        self.assertStartsWith(r.base, 'bzr+https')
 
3575
 
 
3576
    def test__redirect_to_with_user(self):
 
3577
        t = remote.RemoteHTTPTransport('bzr+http://joe@www.example.com/foo')
 
3578
        r = t._redirected_to('http://www.example.com/foo',
 
3579
                             'http://www.example.com/bar')
 
3580
        self.assertEquals(type(r), type(t))
 
3581
        self.assertEquals('joe', t._user)
 
3582
        self.assertEquals(t._user, r._user)
 
3583
 
 
3584
    def test_redirected_to_same_host_different_protocol(self):
 
3585
        t = remote.RemoteHTTPTransport('bzr+http://joe@www.example.com/foo')
 
3586
        r = t._redirected_to('http://www.example.com/foo',
 
3587
                             'ftp://www.example.com/foo')
 
3588
        self.assertNotEquals(type(r), type(t))
 
3589
 
 
3590