/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to breezy/tests/test_smart_transport.py

  • Committer: Jelmer Vernooij
  • Date: 2018-11-16 23:19:12 UTC
  • mfrom: (7180 work)
  • mto: This revision was merged to the branch mainline in revision 7294.
  • Revision ID: jelmer@jelmer.uk-20181116231912-e043vpq22bdkxa6q
Merge trunk.

Show diffs side-by-side

added

removed

Lines of Context:
101
101
 
102
102
    def connect_ssh(self, username, password, host, port, command):
103
103
        self.calls.append(('connect_ssh', username, password, host, port,
104
 
            command))
 
104
                           command))
105
105
        return BytesIOSSHConnection(self)
106
106
 
107
107
 
113
113
 
114
114
    def __init__(self, read_from, write_to, fail_at_write=True):
115
115
        super(FirstRejectedBytesIOSSHVendor, self).__init__(read_from,
116
 
            write_to)
 
116
                                                            write_to)
117
117
        self.fail_at_write = fail_at_write
118
118
        self._first = True
119
119
 
120
120
    def connect_ssh(self, username, password, host, port, command):
121
121
        self.calls.append(('connect_ssh', username, password, host, port,
122
 
            command))
 
122
                           command))
123
123
        if self._first:
124
124
            self._first = False
125
125
            return ClosedSSHConnection(self)
186
186
    def feature_name(self):
187
187
        return 'invalid hostname'
188
188
 
 
189
 
189
190
InvalidHostnameFeature = _InvalidHostnameFeature()
190
191
 
191
192
 
265
266
        # On Windows, if you use os.pipe() and close the write side,
266
267
        # read.read() hangs. On Linux, read.read() returns the empty string.
267
268
        p = subprocess.Popen([sys.executable, '-c',
268
 
            'import sys\n'
269
 
            'sys.stdout.write(sys.stdin.read(4))\n'
270
 
            'sys.stdout.close()\n'],
271
 
            stdout=subprocess.PIPE, stdin=subprocess.PIPE, bufsize=0)
 
269
                              'import sys\n'
 
270
                              'sys.stdout.write(sys.stdin.read(4))\n'
 
271
                              'sys.stdout.close()\n'],
 
272
                             stdout=subprocess.PIPE, stdin=subprocess.PIPE, bufsize=0)
272
273
        client_medium = medium.SmartSimplePipesClientMedium(
273
274
            p.stdout, p.stdin, 'base')
274
275
        client_medium._accept_bytes(b'abc\n')
307
308
 
308
309
    def test_simple_pipes__flush_subprocess_closed(self):
309
310
        p = subprocess.Popen([sys.executable, '-c',
310
 
            'import sys\n'
311
 
            'sys.stdout.write(sys.stdin.read(4))\n'
312
 
            'sys.stdout.close()\n'],
313
 
            stdout=subprocess.PIPE, stdin=subprocess.PIPE, bufsize=0)
 
311
                              'import sys\n'
 
312
                              'sys.stdout.write(sys.stdin.read(4))\n'
 
313
                              'sys.stdout.close()\n'],
 
314
                             stdout=subprocess.PIPE, stdin=subprocess.PIPE, bufsize=0)
314
315
        client_medium = medium.SmartSimplePipesClientMedium(
315
316
            p.stdout, p.stdin, 'base')
316
317
        client_medium._accept_bytes(b'abc\n')
329
330
 
330
331
    def test_simple_pipes__read_bytes_subprocess_closed(self):
331
332
        p = subprocess.Popen([sys.executable, '-c',
332
 
            'import sys\n'
333
 
            'if sys.platform == "win32":\n'
334
 
            '    import msvcrt, os\n'
335
 
            '    msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)\n'
336
 
            '    msvcrt.setmode(sys.stdin.fileno(), os.O_BINARY)\n'
337
 
            'sys.stdout.write(sys.stdin.read(4))\n'
338
 
            'sys.stdout.close()\n'],
339
 
            stdout=subprocess.PIPE, stdin=subprocess.PIPE, bufsize=0)
 
333
                              'import sys\n'
 
334
                              'if sys.platform == "win32":\n'
 
335
                              '    import msvcrt, os\n'
 
336
                              '    msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)\n'
 
337
                              '    msvcrt.setmode(sys.stdin.fileno(), os.O_BINARY)\n'
 
338
                              'sys.stdout.write(sys.stdin.read(4))\n'
 
339
                              'sys.stdout.close()\n'],
 
340
                             stdout=subprocess.PIPE, stdin=subprocess.PIPE, bufsize=0)
340
341
        client_medium = medium.SmartSimplePipesClientMedium(
341
342
            p.stdout, p.stdin, 'base')
342
343
        client_medium._accept_bytes(b'abc\n')
381
382
        # SmartSimplePipesClientMedium is never disconnected, so read_bytes
382
383
        # always tries to read from the underlying pipe.
383
384
        input = BytesIO(b'abcdef')
384
 
        client_medium = medium.SmartSimplePipesClientMedium(input, None, 'base')
 
385
        client_medium = medium.SmartSimplePipesClientMedium(
 
386
            input, None, 'base')
385
387
        self.assertEqual(b'abc', client_medium.read_bytes(3))
386
388
        client_medium.disconnect()
387
389
        self.assertEqual(b'def', client_medium.read_bytes(3))
390
392
        # invoking _flush on a SimplePipesClient should flush the output
391
393
        # pipe. We test this by creating an output pipe that records
392
394
        # flush calls made to it.
393
 
        from io import BytesIO # get regular BytesIO
 
395
        from io import BytesIO  # get regular BytesIO
394
396
        input = BytesIO()
395
397
        output = BytesIO()
396
398
        flush_calls = []
 
399
 
397
400
        def logging_flush(): flush_calls.append('flush')
398
401
        output.flush = logging_flush
399
402
        client_medium = medium.SmartSimplePipesClientMedium(
432
435
        client_medium._accept_bytes(b'abc')
433
436
        self.assertEqual(b'abc', output.getvalue())
434
437
        self.assertEqual([('connect_ssh', 'a username', 'a password',
435
 
            'a hostname', 'a port',
436
 
            ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes'])],
437
 
            vendor.calls)
 
438
                           'a hostname', 'a port',
 
439
                           ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes'])],
 
440
                         vendor.calls)
438
441
 
439
442
    def test_ssh_client_changes_command_when_bzr_remote_path_passed(self):
440
443
        # The only thing that initiates a connection from the medium is giving
448
451
        client_medium._accept_bytes(b'abc')
449
452
        self.assertEqual(b'abc', output.getvalue())
450
453
        self.assertEqual([('connect_ssh', 'a username', 'a password',
451
 
            'a hostname', 'a port',
452
 
            ['fugly', 'serve', '--inet', '--directory=/', '--allow-writes'])],
453
 
            vendor.calls)
 
454
                           'a hostname', 'a port',
 
455
                           ['fugly', 'serve', '--inet', '--directory=/', '--allow-writes'])],
 
456
                         vendor.calls)
454
457
 
455
458
    def test_ssh_client_disconnect_does_so(self):
456
459
        # calling disconnect should disconnect both the read_from and write_to
466
469
        self.assertTrue(output.closed)
467
470
        self.assertEqual([
468
471
            ('connect_ssh', None, None, 'a hostname', None,
469
 
            ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
 
472
             ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
470
473
            ('close', ),
471
474
            ],
472
475
            vendor.calls)
497
500
        self.assertTrue(output2.closed)
498
501
        self.assertEqual([
499
502
            ('connect_ssh', None, None, 'a hostname', None,
500
 
            ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
 
503
             ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
501
504
            ('close', ),
502
505
            ('connect_ssh', None, None, 'a hostname', None,
503
 
            ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
 
506
             ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
504
507
            ('close', ),
505
508
            ],
506
509
            vendor.calls)
547
550
        # invoking _flush on a SSHClientMedium should flush the output
548
551
        # pipe. We test this by creating an output pipe that records
549
552
        # flush calls made to it.
550
 
        from io import BytesIO # get regular BytesIO
 
553
        from io import BytesIO  # get regular BytesIO
551
554
        input = BytesIO()
552
555
        output = BytesIO()
553
556
        flush_calls = []
 
557
 
554
558
        def logging_flush(): flush_calls.append('flush')
555
559
        output.flush = logging_flush
556
560
        vendor = BytesIOSSHVendor(input, output)
600
604
        # really did disconnect.
601
605
        medium.disconnect()
602
606
 
603
 
 
604
607
    def test_tcp_client_ignores_disconnect_when_not_connected(self):
605
608
        # Doing a disconnect on a new (and thus unconnected) TCP medium
606
609
        # does not fail.  It's ok to disconnect an unconnected medium.
611
614
        # Doing a read on a new (and thus unconnected) TCP medium raises
612
615
        # MediumNotConnected.
613
616
        client_medium = medium.SmartTCPClientMedium(None, None, None)
614
 
        self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes, 0)
615
 
        self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes, 1)
 
617
        self.assertRaises(errors.MediumNotConnected,
 
618
                          client_medium.read_bytes, 0)
 
619
        self.assertRaises(errors.MediumNotConnected,
 
620
                          client_medium.read_bytes, 1)
616
621
 
617
622
    def test_tcp_client_supports__flush(self):
618
623
        # invoking _flush on a TCPClientMedium should do something useful.
692
697
            None, output, 'base')
693
698
        client_medium._current_request = "a"
694
699
        self.assertRaises(errors.TooManyConcurrentRequests,
695
 
            medium.SmartClientStreamMediumRequest, client_medium)
 
700
                          medium.SmartClientStreamMediumRequest, client_medium)
696
701
 
697
702
    def test_finished_read_clears_current_request(self):
698
703
        # calling finished_reading clears the current request from the requests
761
766
        client_medium._connected = True
762
767
        req = client_medium.get_request()
763
768
        self.assertRaises(errors.TooManyConcurrentRequests,
764
 
            client_medium.get_request)
 
769
                          client_medium.get_request)
765
770
        client_medium.reset()
766
771
        # The stream should be reset, marked as disconnected, though ready for
767
772
        # us to make a new request
831
836
                           timeout=4.0):
832
837
        """Create a new SmartServerPipeStreamMedium."""
833
838
        return medium.SmartServerPipeStreamMedium(to_server, from_server,
834
 
            transport, timeout=timeout)
 
839
                                                  transport, timeout=timeout)
835
840
 
836
841
    def create_pipe_context(self, to_server_bytes, transport):
837
842
        """Create a SmartServerSocketStreamMedium.
848
853
    def create_socket_medium(self, server_sock, transport, timeout=4.0):
849
854
        """Initialize a new medium.SmartServerSocketStreamMedium."""
850
855
        return medium.SmartServerSocketStreamMedium(server_sock, transport,
851
 
            timeout=timeout)
 
856
                                                    timeout=timeout)
852
857
 
853
858
    def create_socket_context(self, transport, timeout=4.0):
854
859
        """Create a new SmartServerSocketStreamMedium with default context.
868
873
        transport = local.LocalTransport(urlutils.local_path_to_url('/'))
869
874
        server, from_server = self.create_pipe_context(b'hello\n', transport)
870
875
        smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
871
 
                from_server.write)
 
876
                                                                from_server.write)
872
877
        server._serve_one_request(smart_protocol)
873
878
        self.assertEqual(b'ok\0012\n',
874
879
                         from_server.getvalue())
877
882
        transport = memory.MemoryTransport('memory:///')
878
883
        transport.put_bytes('testfile', b'contents\nof\nfile\n')
879
884
        server, from_server = self.create_pipe_context(b'get\001./testfile\n',
880
 
            transport)
 
885
                                                       transport)
881
886
        smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
882
 
                from_server.write)
 
887
                                                                from_server.write)
883
888
        server._serve_one_request(smart_protocol)
884
889
        self.assertEqual(b'ok\n'
885
890
                         b'17\n'
895
900
        hpss_path = urlutils.quote_from_bytes(utf8_filename)
896
901
        transport.put_bytes(hpss_path, b'contents\nof\nfile\n')
897
902
        server, from_server = self.create_pipe_context(
898
 
                b'get\001' + hpss_path.encode('ascii') + b'\n', transport)
 
903
            b'get\001' + hpss_path.encode('ascii') + b'\n', transport)
899
904
        smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
900
 
                from_server.write)
 
905
                                                                from_server.write)
901
906
        server._serve_one_request(smart_protocol)
902
907
        self.assertEqual(b'ok\n'
903
908
                         b'17\n'
1044
1049
        to_server = BytesIO(b'')
1045
1050
        from_server = BytesIO()
1046
1051
        self.closed = False
 
1052
 
1047
1053
        def close():
1048
1054
            self.closed = True
1049
1055
        from_server.close = close
1143
1149
 
1144
1150
    def test_pipe_set_timeout(self):
1145
1151
        server = self.create_pipe_medium(None, None, None,
1146
 
            timeout=1.23)
 
1152
                                         timeout=1.23)
1147
1153
        self.assertEqual(1.23, server._client_timeout)
1148
1154
 
1149
1155
    def test_socket_wait_for_bytes_with_timeout_with_data(self):
1339
1345
    def test_get_error_unexpected(self):
1340
1346
        """Error reported by server with no specific representation"""
1341
1347
        self.overrideEnv('BRZ_NO_SMART_VFS', None)
 
1348
 
1342
1349
        class FlakyTransport(object):
1343
1350
            base = 'a_url'
 
1351
 
1344
1352
            def external_url(self):
1345
1353
                return self.base
 
1354
 
1346
1355
            def get(self, path):
1347
1356
                raise Exception("some random exception from inside server")
1348
1357
 
1441
1450
        # the server thread gets blocked writing content to the client until we
1442
1451
        # finish reading on the client.
1443
1452
        server.backing_transport.put_bytes('bigfile',
1444
 
            b'a'*1024*1024)
 
1453
                                           b'a' * 1024 * 1024)
1445
1454
        client_sock = self.connect_to_server(server)
1446
1455
        self.say_hello(client_sock)
1447
1456
        _, server_side_thread = server._active_connections[0]
1450
1459
            'base', client_sock)
1451
1460
        client_client = client._SmartClient(client_medium)
1452
1461
        resp, response_handler = client_client.call_expecting_body(b'get',
1453
 
            b'bigfile')
 
1462
                                                                   b'bigfile')
1454
1463
        self.assertEqual((b'ok',), resp)
1455
1464
        # Ask the server to stop gracefully, and wait for it.
1456
1465
        server._stop_gracefully()
1467
1476
        self.assertThat(log, DocTestMatches("""\
1468
1477
    INFO  Requested to stop gracefully
1469
1478
... Stopping SmartServerSocketStreamMedium(client=('127.0.0.1', ...
1470
 
""", flags=doctest.ELLIPSIS|doctest.REPORT_UDIFF))
 
1479
""", flags=doctest.ELLIPSIS | doctest.REPORT_UDIFF))
1471
1480
 
1472
1481
    def test_stop_gracefully_tells_handlers_to_stop(self):
1473
1482
        server, server_thread = self.make_server()
1600
1609
 
1601
1610
    def test__remote_path(self):
1602
1611
        self.assertEqual(b'/foo/bar',
1603
 
                          self.transport._remote_path('foo/bar'))
 
1612
                         self.transport._remote_path('foo/bar'))
1604
1613
 
1605
1614
    def test_clone_changes_base(self):
1606
1615
        """Cloning transport produces one with a new base location"""
1607
1616
        conn2 = self.transport.clone('subdir')
1608
1617
        self.assertEqual(self.transport.base + 'subdir/',
1609
 
                          conn2.base)
 
1618
                         conn2.base)
1610
1619
 
1611
1620
    def test_open_dir(self):
1612
1621
        """Test changing directory"""
1642
1651
        self.overrideEnv('BRZ_NO_SMART_VFS', None)
1643
1652
        self.start_server(readonly=True)
1644
1653
        self.assertRaises(errors.TransportNotPossible, self.transport.mkdir,
1645
 
            'foo')
 
1654
                          'foo')
1646
1655
 
1647
1656
    def test_rename_error_readonly(self):
1648
1657
        """TransportNotPossible should be preserved from the backing transport."""
1670
1679
        """The server_started hook fires when the server is started."""
1671
1680
        self.hook_calls = []
1672
1681
        _mod_server.SmartTCPServer.hooks.install_named_hook('server_started',
1673
 
            self.capture_server_call, None)
 
1682
                                                            self.capture_server_call, None)
1674
1683
        self.start_server()
1675
1684
        # at this point, the server will be starting a thread up.
1676
1685
        # there is no indicator at the moment, so bodge it by doing a request.
1678
1687
        # The default test server uses MemoryTransport and that has no external
1679
1688
        # url:
1680
1689
        self.assertEqual([([self.backing_transport.base], self.transport.base)],
1681
 
            self.hook_calls)
 
1690
                         self.hook_calls)
1682
1691
 
1683
1692
    def test_server_started_hook_file(self):
1684
1693
        """The server_started hook fires when the server is started."""
1685
1694
        self.hook_calls = []
1686
1695
        _mod_server.SmartTCPServer.hooks.install_named_hook('server_started',
1687
 
            self.capture_server_call, None)
 
1696
                                                            self.capture_server_call, None)
1688
1697
        self.start_server(
1689
1698
            backing_transport=_mod_transport.get_transport_from_path("."))
1690
1699
        # at this point, the server will be starting a thread up.
1694
1703
        # url:
1695
1704
        self.assertEqual([([
1696
1705
            self.backing_transport.base, self.backing_transport.external_url()],
1697
 
             self.transport.base)],
 
1706
            self.transport.base)],
1698
1707
            self.hook_calls)
1699
1708
 
1700
1709
    def test_server_stopped_hook_simple_memory(self):
1701
1710
        """The server_stopped hook fires when the server is stopped."""
1702
1711
        self.hook_calls = []
1703
1712
        _mod_server.SmartTCPServer.hooks.install_named_hook('server_stopped',
1704
 
            self.capture_server_call, None)
 
1713
                                                            self.capture_server_call, None)
1705
1714
        self.start_server()
1706
1715
        result = [([self.backing_transport.base], self.transport.base)]
1707
1716
        # check the stopping message isn't emitted up front.
1718
1727
        """The server_stopped hook fires when the server is stopped."""
1719
1728
        self.hook_calls = []
1720
1729
        _mod_server.SmartTCPServer.hooks.install_named_hook('server_stopped',
1721
 
            self.capture_server_call, None)
 
1730
                                                            self.capture_server_call, None)
1722
1731
        self.start_server(
1723
1732
            backing_transport=_mod_transport.get_transport_from_path("."))
1724
1733
        result = [(
1725
 
            [self.backing_transport.base, self.backing_transport.external_url()]
1726
 
            , self.transport.base)]
 
1734
            [self.backing_transport.base, self.backing_transport.external_url()], self.transport.base)]
1727
1735
        # check the stopping message isn't emitted up front.
1728
1736
        self.assertEqual([], self.hook_calls)
1729
1737
        # nor after a single message
1780
1788
    def test_construct_request_handler(self):
1781
1789
        """Constructing a request handler should be easy and set defaults."""
1782
1790
        handler = _mod_request.SmartServerRequestHandler(None, commands=None,
1783
 
                root_client_path='/')
 
1791
                                                         root_client_path='/')
1784
1792
        self.assertFalse(handler.finished_reading)
1785
1793
 
1786
1794
    def test_hello(self):
1855
1863
        handler.end_of_body()
1856
1864
        self.assertTrue(handler.finished_reading)
1857
1865
        self.assertEqual((b'ShortReadvError', b'./a-file', b'100', b'1', b'0'),
1858
 
            handler.response.args)
 
1866
                         handler.response.args)
1859
1867
        self.assertEqual(None, handler.response.body)
1860
1868
 
1861
1869
 
1868
1876
 
1869
1877
    def test_bzr_https(self):
1870
1878
        # https://bugs.launchpad.net/bzr/+bug/128456
1871
 
        t = _mod_transport.get_transport_from_url('bzr+https://example.com/path')
 
1879
        t = _mod_transport.get_transport_from_url(
 
1880
            'bzr+https://example.com/path')
1872
1881
        self.assertIsInstance(t, remote.RemoteHTTPTransport)
1873
1882
        self.assertStartsWith(
1874
1883
            t._http_transport.base,
1909
1918
            'bzr://localhost/', medium=client_medium)
1910
1919
        err = errors.ErrorFromSmartServer((b"ReadOnlyError", ))
1911
1920
        self.assertRaises(errors.TransportNotPossible,
1912
 
            transport._translate_error, err)
 
1921
                          transport._translate_error, err)
1913
1922
 
1914
1923
 
1915
1924
class TestSmartProtocol(tests.TestCase):
1979
1988
            self.client_protocol_class, 'request_marker', None)
1980
1989
 
1981
1990
    def assertOffsetSerialisation(self, expected_offsets, expected_serialised,
1982
 
        requester):
 
1991
                                  requester):
1983
1992
        """Check that smart (de)serialises offsets as expected.
1984
1993
 
1985
1994
        We check both serialisation and deserialisation at the same time
2003
2012
        smart_protocol.request = _mod_request.SmartServerRequestHandler(
2004
2013
            None, _mod_request.request_handlers, '/')
2005
2014
        # GZ 2010-08-10: Cycle with closure affects 4 tests
 
2015
 
2006
2016
        class FakeCommand(_mod_request.SmartServerRequest):
2007
2017
            def do_body(self_cmd, body_bytes):
2008
2018
                self.end_received = True
2017
2027
        return smart_protocol
2018
2028
 
2019
2029
    def assertServerToClientEncoding(self, expected_bytes, expected_tuple,
2020
 
            input_tuples):
 
2030
                                     input_tuples):
2021
2031
        """Assert that each input_tuple serialises as expected_bytes, and the
2022
2032
        bytes deserialise as expected_tuple.
2023
2033
        """
2031
2041
        # check the decoding of the client smart_protocol from expected_bytes:
2032
2042
        requester, response_handler = self.make_client_protocol(expected_bytes)
2033
2043
        requester.call(b'foo')
2034
 
        self.assertEqual(expected_tuple, response_handler.read_response_tuple())
 
2044
        self.assertEqual(
 
2045
            expected_tuple, response_handler.read_response_tuple())
2035
2046
 
2036
2047
 
2037
2048
class CommonSmartProtocolTestMixin(object):
2040
2051
        requester, response_handler = self.make_client_protocol()
2041
2052
        requester.call(b'hello')
2042
2053
        ex = self.assertRaises(errors.ConnectionReset,
2043
 
            response_handler.read_response_tuple)
 
2054
                               response_handler.read_response_tuple)
2044
2055
        self.assertEqual("Connection closed: "
2045
 
            "Unexpected end of message. Please check connectivity "
2046
 
            "and permissions, and report a bug if problems persist. ",
2047
 
            str(ex))
 
2056
                         "Unexpected end of message. Please check connectivity "
 
2057
                         "and permissions, and report a bug if problems persist. ",
 
2058
                         str(ex))
2048
2059
 
2049
2060
    def test_server_offset_serialisation(self):
2050
2061
        """The Smart protocol serialises offsets as a comma and \n string.
2057
2068
        self.assertOffsetSerialisation([], b'', requester)
2058
2069
        self.assertOffsetSerialisation([(1, 2)], b'1,2', requester)
2059
2070
        self.assertOffsetSerialisation([(10, 40), (0, 5)], b'10,40\n0,5',
2060
 
            requester)
 
2071
                                       requester)
2061
2072
        self.assertOffsetSerialisation([(1, 2), (3, 4), (100, 200)],
2062
 
            b'1,2\n3,4\n100,200', requester)
 
2073
                                       b'1,2\n3,4\n100,200', requester)
2063
2074
 
2064
2075
 
2065
2076
class TestVersionOneFeaturesInProtocolOne(
2066
 
    TestSmartProtocol, CommonSmartProtocolTestMixin):
 
2077
        TestSmartProtocol, CommonSmartProtocolTestMixin):
2067
2078
    """Tests for version one smart protocol features as implemeted by version
2068
2079
    one."""
2069
2080
 
2114
2125
        mem_transport.put_bytes('foo', b'abcdefghij')
2115
2126
        out_stream = BytesIO()
2116
2127
        smart_protocol = protocol.SmartServerRequestProtocolOne(mem_transport,
2117
 
                out_stream.write)
 
2128
                                                                out_stream.write)
2118
2129
        smart_protocol.accept_bytes(b'readv\x01foo\n3\n3,3done\n')
2119
2130
        self.assertEqual(0, smart_protocol.next_read_size())
2120
2131
        self.assertEqual(b'readv\n3\ndefdone\n', out_stream.getvalue())
2165
2176
        smart_protocol = protocol.SmartServerRequestProtocolOne(
2166
2177
            None, lambda x: None)
2167
2178
        self.assertRaises(AttributeError, smart_protocol._send_response,
2168
 
            _mod_request.SmartServerResponse((b'x',)))
 
2179
                          _mod_request.SmartServerResponse((b'x',)))
2169
2180
 
2170
2181
    def test_query_version(self):
2171
2182
        """query_version on a SmartClientProtocolOne should return a number.
2195
2206
        # protocol.call() can get back tuples of other lengths. A three element
2196
2207
        # tuple should be unpacked as three strings.
2197
2208
        self.assertServerToClientEncoding(b'a\x01b\x0134\n', (b'a', b'b', b'34'),
2198
 
            [(b'a', b'b', b'34')])
 
2209
                                          [(b'a', b'b', b'34')])
2199
2210
 
2200
2211
    def test_client_call_with_body_bytes_uploads(self):
2201
2212
        # protocol.call_with_body_bytes should length-prefix the bytes onto the
2224
2235
        self.assertEqual(expected_bytes, output.getvalue())
2225
2236
 
2226
2237
    def _test_client_read_response_tuple_raises_UnknownSmartMethod(self,
2227
 
            server_bytes):
 
2238
                                                                   server_bytes):
2228
2239
        input = BytesIO(server_bytes)
2229
2240
        output = BytesIO()
2230
2241
        client_medium = medium.SmartSimplePipesClientMedium(
2290
2301
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
2291
2302
        smart_protocol.call(b'foo')
2292
2303
        smart_protocol.read_response_tuple(True)
2293
 
        self.assertEqual(expected_bytes[0:2], smart_protocol.read_body_bytes(2))
2294
 
        self.assertEqual(expected_bytes[2:4], smart_protocol.read_body_bytes(2))
2295
 
        self.assertEqual(expected_bytes[4:6], smart_protocol.read_body_bytes(2))
 
2304
        self.assertEqual(expected_bytes[0:2],
 
2305
                         smart_protocol.read_body_bytes(2))
 
2306
        self.assertEqual(expected_bytes[2:4],
 
2307
                         smart_protocol.read_body_bytes(2))
 
2308
        self.assertEqual(expected_bytes[4:6],
 
2309
                         smart_protocol.read_body_bytes(2))
2296
2310
        self.assertEqual(expected_bytes[6:7], smart_protocol.read_body_bytes())
2297
2311
 
2298
2312
    def test_client_cancel_read_body_does_not_eat_body_bytes(self):
2328
2342
 
2329
2343
 
2330
2344
class TestVersionOneFeaturesInProtocolTwo(
2331
 
    TestSmartProtocol, CommonSmartProtocolTestMixin):
 
2345
        TestSmartProtocol, CommonSmartProtocolTestMixin):
2332
2346
    """Tests for version one smart protocol features as implemeted by version
2333
2347
    two.
2334
2348
    """
2358
2372
        self.assertEqual(b'abc', smart_protocol.in_buffer)
2359
2373
        smart_protocol.accept_bytes(b'\n')
2360
2374
        self.assertEqual(
2361
 
            self.response_marker +
2362
 
            b"failed\nerror\x01Generic bzr smart protocol error: bad request 'abc'\n",
 
2375
            self.response_marker
 
2376
            + b"failed\nerror\x01Generic bzr smart protocol error: bad request 'abc'\n",
2363
2377
            out_stream.getvalue())
2364
2378
        self.assertTrue(smart_protocol._has_dispatched)
2365
2379
        self.assertEqual(0, smart_protocol.next_read_size())
2383
2397
            mem_transport, out_stream.write)
2384
2398
        smart_protocol.accept_bytes(b'readv\x01foo\n3\n3,3done\n')
2385
2399
        self.assertEqual(0, smart_protocol.next_read_size())
2386
 
        self.assertEqual(self.response_marker +
2387
 
                         b'success\nreadv\n3\ndefdone\n',
 
2400
        self.assertEqual(self.response_marker
 
2401
                         + b'success\nreadv\n3\ndefdone\n',
2388
2402
                         out_stream.getvalue())
2389
2403
        self.assertEqual(b'', smart_protocol.unused_data)
2390
2404
        self.assertEqual(b'', smart_protocol.in_buffer)
2401
2415
    def test_accept_excess_bytes_after_body(self):
2402
2416
        # The excess bytes look like the start of another request.
2403
2417
        server_protocol = self.build_protocol_waiting_for_body()
2404
 
        server_protocol.accept_bytes(b'7\nabcdefgdone\n' + self.response_marker)
 
2418
        server_protocol.accept_bytes(
 
2419
            b'7\nabcdefgdone\n' + self.response_marker)
2405
2420
        self.assertTrue(self.end_received)
2406
2421
        self.assertEqual(self.response_marker,
2407
2422
                         server_protocol.unused_data)
2436
2451
        """Ensure that only the Successful/Failed subclasses are used."""
2437
2452
        smart_protocol = self.server_protocol_class(None, lambda x: None)
2438
2453
        self.assertRaises(AttributeError, smart_protocol._send_response,
2439
 
            _mod_request.SmartServerResponse((b'x',)))
 
2454
                          _mod_request.SmartServerResponse((b'x',)))
2440
2455
 
2441
2456
    def test_query_version(self):
2442
2457
        """query_version on a SmartClientProtocolTwo should return a number.
2501
2516
        # read_body_bytes should decode the body bytes from the wire into
2502
2517
        # a response.
2503
2518
        expected_bytes = b"1234567"
2504
 
        server_bytes = (self.response_marker +
2505
 
                        b"success\nok\n7\n1234567done\n")
 
2519
        server_bytes = (self.response_marker
 
2520
                        + b"success\nok\n7\n1234567done\n")
2506
2521
        input = BytesIO(server_bytes)
2507
2522
        output = BytesIO()
2508
2523
        client_medium = medium.SmartSimplePipesClientMedium(
2529
2544
        smart_protocol = self.client_protocol_class(request)
2530
2545
        smart_protocol.call(b'foo')
2531
2546
        smart_protocol.read_response_tuple(True)
2532
 
        self.assertEqual(expected_bytes[0:2], smart_protocol.read_body_bytes(2))
2533
 
        self.assertEqual(expected_bytes[2:4], smart_protocol.read_body_bytes(2))
2534
 
        self.assertEqual(expected_bytes[4:6], smart_protocol.read_body_bytes(2))
 
2547
        self.assertEqual(expected_bytes[0:2],
 
2548
                         smart_protocol.read_body_bytes(2))
 
2549
        self.assertEqual(expected_bytes[2:4],
 
2550
                         smart_protocol.read_body_bytes(2))
 
2551
        self.assertEqual(expected_bytes[4:6],
 
2552
                         smart_protocol.read_body_bytes(2))
2535
2553
        self.assertEqual(expected_bytes[6:7], smart_protocol.read_body_bytes())
2536
2554
 
2537
2555
    def test_client_cancel_read_body_does_not_eat_body_bytes(self):
2553
2571
            errors.ReadingCompleted, smart_protocol.read_body_bytes)
2554
2572
 
2555
2573
    def test_client_read_body_bytes_interrupted_connection(self):
2556
 
        server_bytes = (self.response_marker +
2557
 
                        b"success\nok\n999\nincomplete body")
 
2574
        server_bytes = (self.response_marker
 
2575
                        + b"success\nok\n999\nincomplete body")
2558
2576
        input = BytesIO(server_bytes)
2559
2577
        output = BytesIO()
2560
2578
        client_medium = medium.SmartSimplePipesClientMedium(
2595
2613
    def test_body_stream_serialisation(self):
2596
2614
        stream = [b'chunk one', b'chunk two', b'chunk three']
2597
2615
        self.assertBodyStreamSerialisation(
2598
 
            b'chunked\n' + b'9\nchunk one' + b'9\nchunk two' + b'b\nchunk three' +
2599
 
            b'END\n',
 
2616
            b'chunked\n' + b'9\nchunk one' + b'9\nchunk two' + b'b\nchunk three'
 
2617
            + b'END\n',
2600
2618
            stream)
2601
2619
        self.assertBodyStreamRoundTrips(stream)
2602
2620
 
2615
2633
                  _mod_request.FailedSmartServerResponse(
2616
2634
                      (b'FailureName', b'failure arg'))]
2617
2635
        expected_bytes = (
2618
 
            b'chunked\n' + b'b\nfirst chunk' +
2619
 
            b'ERR\n' + b'b\nFailureName' + b'b\nfailure arg' +
2620
 
            b'END\n')
 
2636
            b'chunked\n' + b'b\nfirst chunk'
 
2637
            + b'ERR\n' + b'b\nFailureName' + b'b\nfailure arg'
 
2638
            + b'END\n')
2621
2639
        self.assertBodyStreamSerialisation(expected_bytes, stream)
2622
2640
        self.assertBodyStreamRoundTrips(stream)
2623
2641
 
2653
2671
        body_header = b'chunked\n'
2654
2672
        two_body_chunks = b"4\n1234" + b"3\n567"
2655
2673
        body_terminator = b"END\n"
2656
 
        server_bytes = (protocol.RESPONSE_VERSION_TWO +
2657
 
                        b"success\nok\n" + body_header + two_body_chunks +
2658
 
                        body_terminator)
 
2674
        server_bytes = (protocol.RESPONSE_VERSION_TWO
 
2675
                        + b"success\nok\n" + body_header + two_body_chunks
 
2676
                        + body_terminator)
2659
2677
        input = BytesIO(server_bytes)
2660
2678
        output = BytesIO()
2661
2679
        client_medium = medium.SmartSimplePipesClientMedium(
2675
2693
        err_chunks = b'a\nerror arg1' + b'4\narg2'
2676
2694
        finish = b'END\n'
2677
2695
        body = body_header + a_body_chunk + err_signal + err_chunks + finish
2678
 
        server_bytes = (protocol.RESPONSE_VERSION_TWO +
2679
 
                        b"success\nok\n" + body)
 
2696
        server_bytes = (protocol.RESPONSE_VERSION_TWO
 
2697
                        + b"success\nok\n" + body)
2680
2698
        input = BytesIO(server_bytes)
2681
2699
        output = BytesIO()
2682
2700
        client_medium = medium.SmartSimplePipesClientMedium(
2694
2712
    def test_streamed_body_bytes_interrupted_connection(self):
2695
2713
        body_header = b'chunked\n'
2696
2714
        incomplete_body_chunk = b"9999\nincomplete chunk"
2697
 
        server_bytes = (protocol.RESPONSE_VERSION_TWO +
2698
 
                        b"success\nok\n" + body_header + incomplete_body_chunk)
 
2715
        server_bytes = (protocol.RESPONSE_VERSION_TWO
 
2716
                        + b"success\nok\n" + body_header + incomplete_body_chunk)
2699
2717
        input = BytesIO(server_bytes)
2700
2718
        output = BytesIO()
2701
2719
        client_medium = medium.SmartSimplePipesClientMedium(
2724
2742
        the server did not recognise the request.
2725
2743
        """
2726
2744
        server_bytes = (
2727
 
            protocol.RESPONSE_VERSION_TWO +
2728
 
            b"failed\n" +
2729
 
            b"error\x01Generic bzr smart protocol error: bad request 'foo'\n")
 
2745
            protocol.RESPONSE_VERSION_TWO
 
2746
            + b"failed\n"
 
2747
            + b"error\x01Generic bzr smart protocol error: bad request 'foo'\n")
2730
2748
        input = BytesIO(server_bytes)
2731
2749
        output = BytesIO()
2732
2750
        client_medium = medium.SmartSimplePipesClientMedium(
2755
2773
 
2756
2774
 
2757
2775
class TestVersionOneFeaturesInProtocolThree(
2758
 
    TestSmartProtocol, CommonSmartProtocolTestMixin):
 
2776
        TestSmartProtocol, CommonSmartProtocolTestMixin):
2759
2777
    """Tests for version one smart protocol features as implemented by version
2760
2778
    three.
2761
2779
    """
2868
2886
        """The protocol can decode a 'bytes' message part."""
2869
2887
        smart_protocol, event_log = self.make_protocol_expecting_message_part()
2870
2888
        smart_protocol.accept_bytes(
2871
 
            b'b' # message part kind
2872
 
            b'\0\0\0\x07' # length prefix
2873
 
            b'payload' # payload
 
2889
            b'b'  # message part kind
 
2890
            b'\0\0\0\x07'  # length prefix
 
2891
            b'payload'  # payload
2874
2892
            )
2875
2893
        self.assertEqual([('bytes', b'payload')], event_log)
2876
2894
 
2878
2896
        """The protocol can decode a 'structure' message part."""
2879
2897
        smart_protocol, event_log = self.make_protocol_expecting_message_part()
2880
2898
        smart_protocol.accept_bytes(
2881
 
            b's' # message part kind
2882
 
            b'\0\0\0\x07' # length prefix
2883
 
            b'l3:ARGe' # ['ARG']
 
2899
            b's'  # message part kind
 
2900
            b'\0\0\0\x07'  # length prefix
 
2901
            b'l3:ARGe'  # ['ARG']
2884
2902
            )
2885
2903
        self.assertEqual([('structure', (b'ARG',))], event_log)
2886
2904
 
2888
2906
        """The protocol can decode a multiple 'bytes' message parts."""
2889
2907
        smart_protocol, event_log = self.make_protocol_expecting_message_part()
2890
2908
        smart_protocol.accept_bytes(
2891
 
            b'b' # message part kind
2892
 
            b'\0\0\0\x05' # length prefix
2893
 
            b'first' # payload
2894
 
            b'b' # message part kind
 
2909
            b'b'  # message part kind
 
2910
            b'\0\0\0\x05'  # length prefix
 
2911
            b'first'  # payload
 
2912
            b'b'  # message part kind
2895
2913
            b'\0\0\0\x06'
2896
2914
            b'second'
2897
2915
            )
2926
2944
 
2927
2945
    def test_interrupted_by_connection_lost(self):
2928
2946
        interrupted_body_stream = (
2929
 
            b'oS' # successful response
2930
 
            b's\0\0\0\x02le' # empty args
 
2947
            b'oS'  # successful response
 
2948
            b's\0\0\0\x02le'  # empty args
2931
2949
            b'b\0\0\xff\xffincomplete chunk')
2932
2950
        response_handler = self.make_response_handler(interrupted_body_stream)
2933
2951
        stream = response_handler.read_streamed_body()
2935
2953
 
2936
2954
    def test_read_body_bytes_interrupted_by_connection_lost(self):
2937
2955
        interrupted_body_stream = (
2938
 
            b'oS' # successful response
2939
 
            b's\0\0\0\x02le' # empty args
 
2956
            b'oS'  # successful response
 
2957
            b's\0\0\0\x02le'  # empty args
2940
2958
            b'b\0\0\xff\xffincomplete chunk')
2941
2959
        response_handler = self.make_response_handler(interrupted_body_stream)
2942
2960
        self.assertRaises(
2944
2962
 
2945
2963
    def test_multiple_bytes_parts(self):
2946
2964
        multiple_bytes_parts = (
2947
 
            b'oS' # successful response
2948
 
            b's\0\0\0\x02le' # empty args
2949
 
            b'b\0\0\0\x0bSome bytes\n' # some bytes
2950
 
            b'b\0\0\0\x0aMore bytes' # more bytes
2951
 
            b'e' # message end
 
2965
            b'oS'  # successful response
 
2966
            b's\0\0\0\x02le'  # empty args
 
2967
            b'b\0\0\0\x0bSome bytes\n'  # some bytes
 
2968
            b'b\0\0\0\x0aMore bytes'  # more bytes
 
2969
            b'e'  # message end
2952
2970
            )
2953
2971
        response_handler = self.make_response_handler(multiple_bytes_parts)
2954
2972
        self.assertEqual(
2979
2997
        """
2980
2998
        from breezy.bzr.smart.message import ConventionalRequestHandler
2981
2999
        request_handler = InstrumentedRequestHandler()
2982
 
        request_handler.response = _mod_request.SuccessfulSmartServerResponse((b'arg', b'arg'))
 
3000
        request_handler.response = _mod_request.SuccessfulSmartServerResponse(
 
3001
            (b'arg', b'arg'))
2983
3002
        responder = FakeResponder()
2984
 
        message_handler = ConventionalRequestHandler(request_handler, responder)
 
3003
        message_handler = ConventionalRequestHandler(
 
3004
            request_handler, responder)
2985
3005
        protocol_decoder = protocol.ProtocolThreeDecoder(message_handler)
2986
3006
        # put decoder in desired state (waiting for message parts)
2987
3007
        protocol_decoder.state_accept = protocol_decoder._state_accept_expecting_message_part
2993
3013
        accept_body method.
2994
3014
        """
2995
3015
        multiple_bytes_parts = (
2996
 
            b's\0\0\0\x07l3:fooe' # args
2997
 
            b'b\0\0\0\x0bSome bytes\n' # some bytes
2998
 
            b'b\0\0\0\x0aMore bytes' # more bytes
2999
 
            b'e' # message end
 
3016
            b's\0\0\0\x07l3:fooe'  # args
 
3017
            b'b\0\0\0\x0bSome bytes\n'  # some bytes
 
3018
            b'b\0\0\0\x0aMore bytes'  # more bytes
 
3019
            b'e'  # message end
3000
3020
            )
3001
3021
        request_handler = self.make_request_handler(multiple_bytes_parts)
3002
3022
        accept_body_calls = [
3007
3027
 
3008
3028
    def test_error_flag_after_body(self):
3009
3029
        body_then_error = (
3010
 
            b's\0\0\0\x07l3:fooe' # request args
3011
 
            b'b\0\0\0\x0bSome bytes\n' # some bytes
3012
 
            b'b\0\0\0\x0aMore bytes' # more bytes
3013
 
            b'oE' # error flag
3014
 
            b's\0\0\0\x07l3:bare' # error args
3015
 
            b'e' # message end
 
3030
            b's\0\0\0\x07l3:fooe'  # request args
 
3031
            b'b\0\0\0\x0bSome bytes\n'  # some bytes
 
3032
            b'b\0\0\0\x0aMore bytes'  # more bytes
 
3033
            b'oE'  # error flag
 
3034
            b's\0\0\0\x07l3:bare'  # error args
 
3035
            b'e'  # message end
3016
3036
            )
3017
3037
        request_handler = self.make_request_handler(body_then_error)
3018
3038
        self.assertEqual(
3039
3059
        # verb+args tuple, it has a single-byte part, which is forbidden.  In
3040
3060
        # fact it has that part twice, to trigger multiple errors.
3041
3061
        invalid_request = (
3042
 
            protocol.MESSAGE_VERSION_THREE +  # protocol version marker
3043
 
            b'\0\0\0\x02de' + # empty headers
3044
 
            b'oX' + # a single byte part: 'X'.  ConventionalRequestHandler will
3045
 
                   # error at this part.
3046
 
            b'oX' + # and again.
3047
 
            b'e' # end of message
 
3062
            protocol.MESSAGE_VERSION_THREE  # protocol version marker
 
3063
            + b'\0\0\0\x02de'  # empty headers
 
3064
            + b'oX' +  # a single byte part: 'X'.  ConventionalRequestHandler will
 
3065
            # error at this part.
 
3066
            b'oX' +  # and again.
 
3067
            b'e'  # end of message
3048
3068
            )
3049
3069
 
3050
3070
        to_server = BytesIO(invalid_request)
3118
3138
        status byte, empty args, no body.
3119
3139
        """
3120
3140
        headers = b'\0\0\0\x02de'  # length-prefixed, bencoded empty dict
3121
 
        response_status = b'oS' # success
3122
 
        args = b's\0\0\0\x02le' # length-prefixed, bencoded empty list
3123
 
        end = b'e' # end marker
 
3141
        response_status = b'oS'  # success
 
3142
        args = b's\0\0\0\x02le'  # length-prefixed, bencoded empty list
 
3143
        end = b'e'  # end marker
3124
3144
        message_bytes = headers + response_status + args + end
3125
3145
        decoder, response_handler = self.make_logging_response_decoder()
3126
3146
        decoder.accept_bytes(message_bytes)
3140
3160
        """
3141
3161
        # Define a simple response that uses all possible message parts.
3142
3162
        headers = b'\0\0\0\x02de'  # length-prefixed, bencoded empty dict
3143
 
        response_status = b'oS' # success
3144
 
        args = b's\0\0\0\x02le' # length-prefixed, bencoded empty list
3145
 
        body = b'b\0\0\0\x04BODY' # a body: 'BODY'
3146
 
        end = b'e' # end marker
 
3163
        response_status = b'oS'  # success
 
3164
        args = b's\0\0\0\x02le'  # length-prefixed, bencoded empty list
 
3165
        body = b'b\0\0\0\x04BODY'  # a body: 'BODY'
 
3166
        end = b'e'  # end marker
3147
3167
        simple_response = headers + response_status + args + body + end
3148
3168
        # Feed the request to the decoder one byte at a time.
3149
3169
        decoder, response_handler = self.make_logging_response_decoder()
3158
3178
        with 'UnknownMethod'.
3159
3179
        """
3160
3180
        headers = b'\0\0\0\x02de'  # length-prefixed, bencoded empty dict
3161
 
        response_status = b'oE' # error flag
 
3181
        response_status = b'oE'  # error flag
3162
3182
        # args: (b'UnknownMethod', 'method-name')
3163
3183
        args = b's\0\0\0\x20l13:UnknownMethod11:method-namee'
3164
 
        end = b'e' # end marker
 
3184
        end = b'e'  # end marker
3165
3185
        message_bytes = headers + response_status + args + end
3166
3186
        decoder, response_handler = self.make_conventional_response_decoder()
3167
3187
        decoder.accept_bytes(message_bytes)
3172
3192
    def test_read_response_tuple_error(self):
3173
3193
        """If the response has an error, it is raised as an exception."""
3174
3194
        headers = b'\0\0\0\x02de'  # length-prefixed, bencoded empty dict
3175
 
        response_status = b'oE' # error
3176
 
        args = b's\0\0\0\x1al9:first arg10:second arge' # two args
3177
 
        end = b'e' # end marker
 
3195
        response_status = b'oE'  # error
 
3196
        args = b's\0\0\0\x1al9:first arg10:second arge'  # two args
 
3197
        end = b'e'  # end marker
3178
3198
        message_bytes = headers + response_status + args + end
3179
3199
        decoder, response_handler = self.make_conventional_response_decoder()
3180
3200
        decoder.accept_bytes(message_bytes)
3204
3224
        requester.set_headers({b'header name': b'header value'})
3205
3225
        requester.call(b'one arg')
3206
3226
        self.assertEqual(
3207
 
            b'bzr message 3 (bzr 1.6)\n' # protocol version
3208
 
            b'\x00\x00\x00\x1fd11:header name12:header valuee' # headers
3209
 
            b's\x00\x00\x00\x0bl7:one arge' # args
3210
 
            b'e', # end
 
3227
            b'bzr message 3 (bzr 1.6)\n'  # protocol version
 
3228
            b'\x00\x00\x00\x1fd11:header name12:header valuee'  # headers
 
3229
            b's\x00\x00\x00\x0bl7:one arge'  # args
 
3230
            b'e',  # end
3211
3231
            output.getvalue())
3212
3232
 
3213
3233
    def test_call_with_body_bytes_smoke_test(self):
3220
3240
        requester.set_headers({b'header name': b'header value'})
3221
3241
        requester.call_with_body_bytes((b'one arg',), b'body bytes')
3222
3242
        self.assertEqual(
3223
 
            b'bzr message 3 (bzr 1.6)\n' # protocol version
3224
 
            b'\x00\x00\x00\x1fd11:header name12:header valuee' # headers
3225
 
            b's\x00\x00\x00\x0bl7:one arge' # args
3226
 
            b'b' # there is a prefixed body
3227
 
            b'\x00\x00\x00\nbody bytes' # the prefixed body
3228
 
            b'e', # end
 
3243
            b'bzr message 3 (bzr 1.6)\n'  # protocol version
 
3244
            b'\x00\x00\x00\x1fd11:header name12:header valuee'  # headers
 
3245
            b's\x00\x00\x00\x0bl7:one arge'  # args
 
3246
            b'b'  # there is a prefixed body
 
3247
            b'\x00\x00\x00\nbody bytes'  # the prefixed body
 
3248
            b'e',  # end
3229
3249
            output.getvalue())
3230
3250
 
3231
3251
    def test_call_writes_just_once(self):
3255
3275
        stream = [b'chunk 1', b'chunk two']
3256
3276
        requester.call_with_body_stream((b'one arg',), stream)
3257
3277
        self.assertEqual(
3258
 
            b'bzr message 3 (bzr 1.6)\n' # protocol version
3259
 
            b'\x00\x00\x00\x1fd11:header name12:header valuee' # headers
3260
 
            b's\x00\x00\x00\x0bl7:one arge' # args
3261
 
            b'b\x00\x00\x00\x07chunk 1' # a prefixed body chunk
3262
 
            b'b\x00\x00\x00\x09chunk two' # a prefixed body chunk
3263
 
            b'e', # end
 
3278
            b'bzr message 3 (bzr 1.6)\n'  # protocol version
 
3279
            b'\x00\x00\x00\x1fd11:header name12:header valuee'  # headers
 
3280
            b's\x00\x00\x00\x0bl7:one arge'  # args
 
3281
            b'b\x00\x00\x00\x07chunk 1'  # a prefixed body chunk
 
3282
            b'b\x00\x00\x00\x09chunk two'  # a prefixed body chunk
 
3283
            b'e',  # end
3264
3284
            output.getvalue())
3265
3285
 
3266
3286
    def test_call_with_body_stream_empty_stream(self):
3270
3290
        stream = []
3271
3291
        requester.call_with_body_stream((b'one arg',), stream)
3272
3292
        self.assertEqual(
3273
 
            b'bzr message 3 (bzr 1.6)\n' # protocol version
3274
 
            b'\x00\x00\x00\x02de' # headers
3275
 
            b's\x00\x00\x00\x0bl7:one arge' # args
 
3293
            b'bzr message 3 (bzr 1.6)\n'  # protocol version
 
3294
            b'\x00\x00\x00\x02de'  # headers
 
3295
            b's\x00\x00\x00\x0bl7:one arge'  # args
3276
3296
            # no body chunks
3277
 
            b'e', # end
 
3297
            b'e',  # end
3278
3298
            output.getvalue())
3279
3299
 
3280
3300
    def test_call_with_body_stream_error(self):
3285
3305
        """
3286
3306
        requester, output = self.make_client_encoder_and_output()
3287
3307
        requester.set_headers({})
 
3308
 
3288
3309
        def stream_that_fails():
3289
3310
            yield b'aaa'
3290
3311
            yield b'bbb'
3291
3312
            raise Exception('Boom!')
3292
3313
        self.assertRaises(Exception, requester.call_with_body_stream,
3293
 
            (b'one arg',), stream_that_fails())
 
3314
                          (b'one arg',), stream_that_fails())
3294
3315
        self.assertEqual(
3295
 
            b'bzr message 3 (bzr 1.6)\n' # protocol version
3296
 
            b'\x00\x00\x00\x02de' # headers
3297
 
            b's\x00\x00\x00\x0bl7:one arge' # args
3298
 
            b'b\x00\x00\x00\x03aaa' # body
3299
 
            b'b\x00\x00\x00\x03bbb' # more body
3300
 
            b'oE' # error flag
3301
 
            b's\x00\x00\x00\x09l5:errore' # error args: ('error',)
3302
 
            b'e', # end
 
3316
            b'bzr message 3 (bzr 1.6)\n'  # protocol version
 
3317
            b'\x00\x00\x00\x02de'  # headers
 
3318
            b's\x00\x00\x00\x0bl7:one arge'  # args
 
3319
            b'b\x00\x00\x00\x03aaa'  # body
 
3320
            b'b\x00\x00\x00\x03bbb'  # more body
 
3321
            b'oE'  # error flag
 
3322
            b's\x00\x00\x00\x09l5:errore'  # error args: ('error',)
 
3323
            b'e',  # end
3303
3324
            output.getvalue())
3304
3325
 
3305
3326
    def test_records_start_of_body_stream(self):
3306
3327
        requester, output = self.make_client_encoder_and_output()
3307
3328
        requester.set_headers({})
3308
3329
        in_stream = [False]
 
3330
 
3309
3331
        def stream_checker():
3310
3332
            self.assertTrue(requester.body_stream_started)
3311
3333
            in_stream[0] = True
3312
3334
            yield b'content'
3313
3335
        flush_called = []
3314
3336
        orig_flush = requester.flush
 
3337
 
3315
3338
        def tracked_flush():
3316
3339
            flush_called.append(in_stream[0])
3317
3340
            if in_stream[0]:
3322
3345
        requester.flush = tracked_flush
3323
3346
        requester.call_with_body_stream((b'one arg',), stream_checker())
3324
3347
        self.assertEqual(
3325
 
            b'bzr message 3 (bzr 1.6)\n' # protocol version
3326
 
            b'\x00\x00\x00\x02de' # headers
3327
 
            b's\x00\x00\x00\x0bl7:one arge' # args
3328
 
            b'b\x00\x00\x00\x07content' # body
 
3348
            b'bzr message 3 (bzr 1.6)\n'  # protocol version
 
3349
            b'\x00\x00\x00\x02de'  # headers
 
3350
            b's\x00\x00\x00\x0bl7:one arge'  # args
 
3351
            b'b\x00\x00\x00\x07content'  # body
3329
3352
            b'e', output.getvalue())
3330
3353
        self.assertEqual([False, True, True], flush_called)
3331
3354
 
3347
3370
 
3348
3371
 
3349
3372
interrupted_body_stream = (
3350
 
    b'oS' # status flag (success)
3351
 
    b's\x00\x00\x00\x08l4:argse' # args struct ('args,')
3352
 
    b'b\x00\x00\x00\x03aaa' # body part ('aaa')
3353
 
    b'b\x00\x00\x00\x03bbb' # body part ('bbb')
3354
 
    b'oE' # status flag (error)
 
3373
    b'oS'  # status flag (success)
 
3374
    b's\x00\x00\x00\x08l4:argse'  # args struct ('args,')
 
3375
    b'b\x00\x00\x00\x03aaa'  # body part ('aaa')
 
3376
    b'b\x00\x00\x00\x03bbb'  # body part ('bbb')
 
3377
    b'oE'  # status flag (error)
3355
3378
    # err struct ('error', 'Exception', 'Boom!')
3356
3379
    b's\x00\x00\x00\x1bl5:error9:Exception5:Boom!e'
3357
 
    b'e' # EOM
 
3380
    b'e'  # EOM
3358
3381
    )
3359
3382
 
3360
3383
 
3382
3405
    def test_send_broken_body_stream(self):
3383
3406
        encoder, out_stream = self.make_response_encoder()
3384
3407
        encoder._headers = {}
 
3408
 
3385
3409
        def stream_that_fails():
3386
3410
            yield b'aaa'
3387
3411
            yield b'bbb'
3391
3415
        encoder.send_response(response)
3392
3416
        expected_response = (
3393
3417
            b'bzr message 3 (bzr 1.6)\n'  # protocol marker
3394
 
            b'\x00\x00\x00\x02de' # headers dict (empty)
 
3418
            b'\x00\x00\x00\x02de'  # headers dict (empty)
3395
3419
            + interrupted_body_stream)
3396
3420
        self.assertEqual(expected_response, out_stream.getvalue())
3397
3421
 
3469
3493
            input, output, 'ignored base')
3470
3494
        smart_client = client._SmartClient(client_medium)
3471
3495
        self.assertRaises(TypeError,
3472
 
            smart_client.call_with_body_bytes, method, args, body)
 
3496
                          smart_client.call_with_body_bytes, method, args, body)
3473
3497
        self.assertEqual(b"", output.getvalue())
3474
3498
        self.assertEqual(None, client_medium._current_request)
3475
3499
 
3478
3502
 
3479
3503
    def test_call_with_body_bytes_unicode_args(self):
3480
3504
        self.assertCallDoesNotBreakMedium(b'method', (u'args',), b'body')
3481
 
        self.assertCallDoesNotBreakMedium(b'method', (b'arg1', u'arg2'), b'body')
 
3505
        self.assertCallDoesNotBreakMedium(
 
3506
            b'method', (b'arg1', u'arg2'), b'body')
3482
3507
 
3483
3508
    def test_call_with_body_bytes_unicode_body(self):
3484
3509
        self.assertCallDoesNotBreakMedium(b'method', (b'args',), u'body')
3575
3600
 
3576
3601
    def disconnect(self):
3577
3602
        if self._mock_request._read_bytes:
3578
 
            self._assertEvent(('read response', self._mock_request._read_bytes))
 
3603
            self._assertEvent(
 
3604
                ('read response', self._mock_request._read_bytes))
3579
3605
            self._mock_request._read_bytes = b''
3580
3606
        self._assertEvent('disconnect')
3581
3607
 
3632
3658
        smart_client = client._SmartClient(medium, headers={})
3633
3659
        message_start = protocol.MESSAGE_VERSION_THREE + b'\x00\x00\x00\x02de'
3634
3660
        medium.expect_request(
3635
 
            message_start +
3636
 
            b's\x00\x00\x00\x1el11:method-name5:arg 15:arg 2ee',
 
3661
            message_start
 
3662
            + b's\x00\x00\x00\x1el11:method-name5:arg 15:arg 2ee',
3637
3663
            message_start + b's\0\0\0\x13l14:response valueee')
3638
3664
        result = smart_client.call(b'method-name', b'arg 1', b'arg 2')
3639
3665
        # The call succeeded without raising any exceptions from the mock
3656
3682
        # First the client should send a v3 request, but the server will reply
3657
3683
        # with a v2 error.
3658
3684
        medium.expect_request(
3659
 
            b'bzr message 3 (bzr 1.6)\n\x00\x00\x00\x02de' +
3660
 
            b's\x00\x00\x00\x1el11:method-name5:arg 15:arg 2ee',
 
3685
            b'bzr message 3 (bzr 1.6)\n\x00\x00\x00\x02de'
 
3686
            + b's\x00\x00\x00\x1el11:method-name5:arg 15:arg 2ee',
3661
3687
            b'bzr response 2\nfailed\n\n')
3662
3688
        # So then the client should disconnect to reset the connection, because
3663
3689
        # the client needs to assume the server cannot read any further
3694
3720
        unknown_protocol_bytes = b'Unknown protocol!'
3695
3721
        # The client will try v3 and v2 before eventually giving up.
3696
3722
        medium.expect_request(
3697
 
            b'bzr message 3 (bzr 1.6)\n\x00\x00\x00\x02de' +
3698
 
            b's\x00\x00\x00\x1el11:method-name5:arg 15:arg 2ee',
 
3723
            b'bzr message 3 (bzr 1.6)\n\x00\x00\x00\x02de'
 
3724
            + b's\x00\x00\x00\x1el11:method-name5:arg 15:arg 2ee',
3699
3725
            unknown_protocol_bytes)
3700
3726
        medium.expect_disconnect()
3701
3727
        medium.expect_request(
3720
3746
        # Issue a request that gets an error reply in a non-default protocol
3721
3747
        # version.
3722
3748
        medium.expect_request(
3723
 
            message_start +
3724
 
            b's\x00\x00\x00\x10l11:method-nameee',
 
3749
            message_start
 
3750
            + b's\x00\x00\x00\x10l11:method-nameee',
3725
3751
            b'bzr response 2\nfailed\n\n')
3726
3752
        medium.expect_disconnect()
3727
3753
        medium.expect_request(
3749
3775
        """
3750
3776
        smart_client = client._SmartClient('dummy medium')
3751
3777
        self.assertEqual(
3752
 
                breezy.__version__.encode('utf-8'),
3753
 
                smart_client._headers[b'Software version'])
 
3778
            breezy.__version__.encode('utf-8'),
 
3779
            smart_client._headers[b'Software version'])
3754
3780
        # XXX: need a test that smart_client._headers is passed to the request
3755
3781
        # encoder.
3756
3782
 
3761
3787
        response_io = BytesIO(response)
3762
3788
        output = BytesIO()
3763
3789
        vendor = FirstRejectedBytesIOSSHVendor(response_io, output,
3764
 
                    fail_at_write=fail_at_write)
 
3790
                                               fail_at_write=fail_at_write)
3765
3791
        ssh_params = medium.SSHParams('a host', 'a port', 'a user', 'a pass')
3766
3792
        client_medium = medium.SmartSSHClientMedium('base', ssh_params, vendor)
3767
3793
        smart_client = client._SmartClient(client_medium, headers={})
3770
3796
    def make_response(self, args, body=None, body_stream=None):
3771
3797
        response_io = BytesIO()
3772
3798
        response = _mod_request.SuccessfulSmartServerResponse(args, body=body,
3773
 
            body_stream=body_stream)
 
3799
                                                              body_stream=body_stream)
3774
3800
        responder = protocol.ProtocolThreeResponder(response_io.write)
3775
3801
        responder.send_response(response)
3776
3802
        return response_io.getvalue()
3780
3806
        output, vendor, smart_client = self.make_client_with_failing_medium(
3781
3807
            fail_at_write=False, response=response)
3782
3808
        smart_request = client._SmartClientRequest(smart_client, b'append',
3783
 
            (b'foo', b''), body=b'content\n')
 
3809
                                                   (b'foo', b''), body=b'content\n')
3784
3810
        self.assertRaises(errors.ConnectionReset, smart_request._call, 3)
3785
3811
 
3786
3812
    def test__call_retries_get_bytes(self):
3788
3814
        output, vendor, smart_client = self.make_client_with_failing_medium(
3789
3815
            fail_at_write=False, response=response)
3790
3816
        smart_request = client._SmartClientRequest(smart_client, b'get',
3791
 
            (b'foo',))
 
3817
                                                   (b'foo',))
3792
3818
        response, response_handler = smart_request._call(3)
3793
3819
        self.assertEqual((b'ok',), response)
3794
3820
        self.assertEqual(b'content\n', response_handler.read_body_bytes())
3799
3825
        output, vendor, smart_client = self.make_client_with_failing_medium(
3800
3826
            fail_at_write=False, response=response)
3801
3827
        smart_request = client._SmartClientRequest(smart_client, b'get',
3802
 
            (b'foo',))
 
3828
                                                   (b'foo',))
3803
3829
        self.assertRaises(errors.ConnectionReset, smart_request._call, 3)
3804
3830
 
3805
3831
    def test__send_no_retry_pipes(self):
3806
3832
        client_read, server_write = create_file_pipes()
3807
3833
        server_read, client_write = create_file_pipes()
3808
3834
        client_medium = medium.SmartSimplePipesClientMedium(client_read,
3809
 
            client_write, base='/')
 
3835
                                                            client_write, base='/')
3810
3836
        smart_client = client._SmartClient(client_medium)
3811
3837
        smart_request = client._SmartClientRequest(smart_client,
3812
 
            b'hello', ())
 
3838
                                                   b'hello', ())
3813
3839
        # Close the server side
3814
3840
        server_read.close()
3815
3841
        encoder, response_handler = smart_request._construct_protocol(3)
3816
3842
        self.assertRaises(errors.ConnectionReset,
3817
 
            smart_request._send_no_retry, encoder)
 
3843
                          smart_request._send_no_retry, encoder)
3818
3844
 
3819
3845
    def test__send_read_response_sockets(self):
3820
3846
        listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
3832
3858
        # connection until we try to read again.
3833
3859
        handler = smart_request._send(3)
3834
3860
        self.assertRaises(errors.ConnectionReset,
3835
 
            handler.read_response_tuple, expect_body=False)
 
3861
                          handler.read_response_tuple, expect_body=False)
3836
3862
 
3837
3863
    def test__send_retries_on_write(self):
3838
3864
        output, vendor, smart_client = self.make_client_with_failing_medium()
3839
3865
        smart_request = client._SmartClientRequest(smart_client, b'hello', ())
3840
3866
        handler = smart_request._send(3)
3841
 
        self.assertEqual(b'bzr message 3 (bzr 1.6)\n' # protocol
 
3867
        self.assertEqual(b'bzr message 3 (bzr 1.6)\n'  # protocol
3842
3868
                         b'\x00\x00\x00\x02de'   # empty headers
3843
3869
                         b's\x00\x00\x00\tl5:helloee',
3844
3870
                         output.getvalue())
3848
3874
             ('close',),
3849
3875
             ('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
3850
3876
              ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
3851
 
            ],
 
3877
             ],
3852
3878
            vendor.calls)
3853
3879
 
3854
3880
    def test__send_doesnt_retry_read_failure(self):
3856
3882
            fail_at_write=False)
3857
3883
        smart_request = client._SmartClientRequest(smart_client, b'hello', ())
3858
3884
        handler = smart_request._send(3)
3859
 
        self.assertEqual(b'bzr message 3 (bzr 1.6)\n' # protocol
 
3885
        self.assertEqual(b'bzr message 3 (bzr 1.6)\n'  # protocol
3860
3886
                         b'\x00\x00\x00\x02de'   # empty headers
3861
3887
                         b's\x00\x00\x00\tl5:helloee',
3862
3888
                         output.getvalue())
3863
3889
        self.assertEqual(
3864
3890
            [('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
3865
3891
              ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
3866
 
            ],
 
3892
             ],
3867
3893
            vendor.calls)
3868
3894
        self.assertRaises(errors.ConnectionReset, handler.read_response_tuple)
3869
3895
 
3870
3896
    def test__send_request_retries_body_stream_if_not_started(self):
3871
3897
        output, vendor, smart_client = self.make_client_with_failing_medium()
3872
3898
        smart_request = client._SmartClientRequest(smart_client, b'hello', (),
3873
 
            body_stream=[b'a', b'b'])
 
3899
                                                   body_stream=[b'a', b'b'])
3874
3900
        response_handler = smart_request._send(3)
3875
3901
        # We connect, get disconnected, and notice before consuming the stream,
3876
3902
        # so we try again one time and succeed.
3880
3906
             ('close',),
3881
3907
             ('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
3882
3908
              ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
3883
 
            ],
 
3909
             ],
3884
3910
            vendor.calls)
3885
 
        self.assertEqual(b'bzr message 3 (bzr 1.6)\n' # protocol
 
3911
        self.assertEqual(b'bzr message 3 (bzr 1.6)\n'  # protocol
3886
3912
                         b'\x00\x00\x00\x02de'   # empty headers
3887
3913
                         b's\x00\x00\x00\tl5:helloe'
3888
3914
                         b'b\x00\x00\x00\x01a'
3897
3923
 
3898
3924
        class FailAfterFirstWrite(BytesIO):
3899
3925
            """Allow one 'write' call to pass, fail the rest"""
 
3926
 
3900
3927
            def __init__(self):
3901
3928
                BytesIO.__init__(self)
3902
3929
                self._first = True
3909
3936
        output = FailAfterFirstWrite()
3910
3937
 
3911
3938
        vendor = FirstRejectedBytesIOSSHVendor(response, output,
3912
 
            fail_at_write=False)
 
3939
                                               fail_at_write=False)
3913
3940
        ssh_params = medium.SSHParams('a host', 'a port', 'a user', 'a pass')
3914
3941
        client_medium = medium.SmartSSHClientMedium('base', ssh_params, vendor)
3915
3942
        smart_client = client._SmartClient(client_medium, headers={})
3916
3943
        smart_request = client._SmartClientRequest(smart_client, b'hello', (),
3917
 
            body_stream=[b'a', b'b'])
 
3944
                                                   body_stream=[b'a', b'b'])
3918
3945
        self.assertRaises(errors.ConnectionReset, smart_request._send, 3)
3919
3946
        # We connect, and manage to get to the point that we start consuming
3920
3947
        # the body stream. The next write fails, so we just stop.
3922
3949
            [('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
3923
3950
              ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
3924
3951
             ('close',),
3925
 
            ],
 
3952
             ],
3926
3953
            vendor.calls)
3927
 
        self.assertEqual(b'bzr message 3 (bzr 1.6)\n' # protocol
 
3954
        self.assertEqual(b'bzr message 3 (bzr 1.6)\n'  # protocol
3928
3955
                         b'\x00\x00\x00\x02de'   # empty headers
3929
3956
                         b's\x00\x00\x00\tl5:helloe',
3930
3957
                         output.getvalue())
3938
3965
            [('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
3939
3966
              ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
3940
3967
             ('close',),
3941
 
            ],
 
3968
             ],
3942
3969
            vendor.calls)
3943
3970
 
3944
3971
 
4134
4161
        finish = b'END\n'
4135
4162
        combined = chunk_length + chunk_content + finish
4136
4163
        for i in range(len(combined)):
4137
 
            decoder.accept_bytes(combined[i:i+1])
 
4164
            decoder.accept_bytes(combined[i:i + 1])
4138
4165
        self.assertTrue(decoder.finished_reading)
4139
4166
        self.assertEqual(chunk_content, decoder.read_next_chunk())
4140
4167
        self.assertEqual(b'', decoder.unused_data)
4215
4242
        response = _mod_request.FailedSmartServerResponse((b'foo', b'bar'))
4216
4243
        self.assertEqual((b'foo', b'bar'), response.args)
4217
4244
        self.assertEqual(None, response.body)
4218
 
        response = _mod_request.FailedSmartServerResponse((b'foo', b'bar'), b'bytes')
 
4245
        response = _mod_request.FailedSmartServerResponse(
 
4246
            (b'foo', b'bar'), b'bytes')
4219
4247
        self.assertEqual((b'foo', b'bar'), response.args)
4220
4248
        self.assertEqual(b'bytes', response.body)
4221
4249
        # repr(response) doesn't trigger exceptions.
4231
4259
    def __init__(self):
4232
4260
        self.written_request = None
4233
4261
        self._current_request = None
 
4262
 
4234
4263
    def send_http_smart_request(self, bytes):
4235
4264
        self.written_request = bytes
4236
4265
        return None
4313
4342
        r = t._redirected_to('http://www.example.com/foo',
4314
4343
                             'bzr://www.example.com/foo')
4315
4344
        self.assertNotEqual(type(r), type(t))
4316
 
 
4317