/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_smart_transport.py

  • Committer: Andrew Bennetts
  • Date: 2008-02-11 06:56:18 UTC
  • mto: (3245.4.1 Version 3 implementation.)
  • mto: This revision was merged to the branch mainline in revision 3428.
  • Revision ID: andrew.bennetts@canonical.com-20080211065618-5z8nwo6oik8dgdkt
Checkpoint first rough cut of SmartServerRequestProtocolThree, this implementation reuses the _StatefulDecoder class.  Plus some attempts to start tidying the smart protocol tests.

Show diffs side-by-side

added added

removed removed

Lines of Context:
21
21
import os
22
22
import socket
23
23
import threading
 
24
import unittest
24
25
import urllib2
25
26
 
26
27
from bzrlib import (
1346
1347
        self.smart_server = InstrumentedServerProtocol(self.server_to_client)
1347
1348
        self.smart_server_request = request.SmartServerRequestHandler(
1348
1349
            None, request.request_handlers)
 
1350
        self.response_marker = getattr(
 
1351
            self.client_protocol_class, 'response_marker', None)
 
1352
        self.request_marker = getattr(
 
1353
            self.client_protocol_class, 'request_marker', None)
1349
1354
 
1350
1355
    def assertOffsetSerialisation(self, expected_offsets, expected_serialised,
1351
1356
        client):
1410
1415
 
1411
1416
class CommonSmartProtocolTestMixin(object):
1412
1417
 
 
1418
    def test_server_offset_serialisation(self):
 
1419
        """The Smart protocol serialises offsets as a comma and \n string.
 
1420
 
 
1421
        We check a number of boundary cases are as expected: empty, one offset,
 
1422
        one with the order of reads not increasing (an out of order read), and
 
1423
        one that should coalesce.
 
1424
        """
 
1425
        self.assertOffsetSerialisation([], '', self.client_protocol)
 
1426
        self.assertOffsetSerialisation([(1,2)], '1,2', self.client_protocol)
 
1427
        self.assertOffsetSerialisation([(10,40), (0,5)], '10,40\n0,5',
 
1428
            self.client_protocol)
 
1429
        self.assertOffsetSerialisation([(1,2), (3,4), (100, 200)],
 
1430
            '1,2\n3,4\n100,200', self.client_protocol)
 
1431
 
1413
1432
    def test_errors_are_logged(self):
1414
1433
        """If an error occurs during testing, it is logged to the test log."""
 
1434
        # XXX: hmm, only errors from the request handler get logged, other
 
1435
        # protocol errors don't.   I guess this is trying to test for internal
 
1436
        # logic errors (unexpected "internal server errors", ala HTTP 500)
 
1437
        # rather than bad requests from the client?  The behaviour here needs
 
1438
        # clarification.
1415
1439
        out_stream = StringIO()
1416
1440
        smart_protocol = self.server_protocol_class(None, out_stream.write)
1417
1441
        # This triggers a "bad request" error.
1434
1458
            "(and try -Dhpss if further diagnosis is required)", str(ex))
1435
1459
 
1436
1460
 
1437
 
class TestSmartProtocolOne(TestSmartProtocol, CommonSmartProtocolTestMixin):
1438
 
    """Tests for the smart protocol version one."""
 
1461
class TestVersionOneFeaturesInProtocolOne(
 
1462
    TestSmartProtocol, CommonSmartProtocolTestMixin):
 
1463
    """Tests for version one smart protocol features as implemeted by version
 
1464
    one."""
1439
1465
 
1440
1466
    client_protocol_class = protocol.SmartClientRequestProtocolOne
1441
1467
    server_protocol_class = protocol.SmartServerRequestProtocolOne
1454
1480
        request = client_medium.get_request()
1455
1481
        client_protocol = protocol.SmartClientRequestProtocolOne(request)
1456
1482
 
1457
 
    def test_server_offset_serialisation(self):
1458
 
        """The Smart protocol serialises offsets as a comma and \n string.
1459
 
 
1460
 
        We check a number of boundary cases are as expected: empty, one offset,
1461
 
        one with the order of reads not increasing (an out of order read), and
1462
 
        one that should coalesce.
1463
 
        """
1464
 
        self.assertOffsetSerialisation([], '', self.client_protocol)
1465
 
        self.assertOffsetSerialisation([(1,2)], '1,2', self.client_protocol)
1466
 
        self.assertOffsetSerialisation([(10,40), (0,5)], '10,40\n0,5',
1467
 
            self.client_protocol)
1468
 
        self.assertOffsetSerialisation([(1,2), (3,4), (100, 200)],
1469
 
            '1,2\n3,4\n100,200', self.client_protocol)
1470
 
 
1471
1483
    def test_accept_bytes_of_bad_request_to_protocol(self):
1472
1484
        out_stream = StringIO()
1473
1485
        smart_protocol = protocol.SmartServerRequestProtocolOne(
1655
1667
            errors.ReadingCompleted, smart_protocol.read_body_bytes)
1656
1668
 
1657
1669
 
1658
 
class TestSmartProtocolTwo(TestSmartProtocol, CommonSmartProtocolTestMixin):
1659
 
    """Tests for the smart protocol version two.
1660
 
 
1661
 
    This test case is mostly the same as TestSmartProtocolOne.
 
1670
class TestVersionOneFeaturesInProtocolTwo(
 
1671
    TestSmartProtocol, CommonSmartProtocolTestMixin):
 
1672
    """Tests for version one smart protocol features as implemeted by version
 
1673
    two.
1662
1674
    """
1663
1675
 
1664
1676
    client_protocol_class = protocol.SmartClientRequestProtocolTwo
1678
1690
        request = client_medium.get_request()
1679
1691
        client_protocol = protocol.SmartClientRequestProtocolTwo(request)
1680
1692
 
1681
 
    def test_server_offset_serialisation(self):
1682
 
        """The Smart protocol serialises offsets as a comma and \n string.
1683
 
 
1684
 
        We check a number of boundary cases are as expected: empty, one offset,
1685
 
        one with the order of reads not increasing (an out of order read), and
1686
 
        one that should coalesce.
1687
 
        """
1688
 
        self.assertOffsetSerialisation([], '', self.client_protocol)
1689
 
        self.assertOffsetSerialisation([(1,2)], '1,2', self.client_protocol)
1690
 
        self.assertOffsetSerialisation([(10,40), (0,5)], '10,40\n0,5',
1691
 
            self.client_protocol)
1692
 
        self.assertOffsetSerialisation([(1,2), (3,4), (100, 200)],
1693
 
            '1,2\n3,4\n100,200', self.client_protocol)
1694
 
 
1695
1693
    def test_accept_bytes_of_bad_request_to_protocol(self):
1696
1694
        out_stream = StringIO()
1697
 
        smart_protocol = protocol.SmartServerRequestProtocolTwo(
1698
 
            None, out_stream.write)
 
1695
        smart_protocol = self.server_protocol_class(None, out_stream.write)
1699
1696
        smart_protocol.accept_bytes('abc')
1700
1697
        self.assertEqual('abc', smart_protocol.in_buffer)
1701
1698
        smart_protocol.accept_bytes('\n')
1702
1699
        self.assertEqual(
1703
 
            protocol.RESPONSE_VERSION_TWO +
 
1700
            self.response_marker +
1704
1701
            "failed\nerror\x01Generic bzr smart protocol error: bad request 'abc'\n",
1705
1702
            out_stream.getvalue())
1706
1703
        self.assertTrue(smart_protocol.has_dispatched)
1721
1718
        mem_transport = memory.MemoryTransport()
1722
1719
        mem_transport.put_bytes('foo', 'abcdefghij')
1723
1720
        out_stream = StringIO()
1724
 
        smart_protocol = protocol.SmartServerRequestProtocolTwo(mem_transport,
1725
 
                out_stream.write)
 
1721
        smart_protocol = self.server_protocol_class(
 
1722
            mem_transport, out_stream.write)
1726
1723
        smart_protocol.accept_bytes('readv\x01foo\n3\n3,3done\n')
1727
1724
        self.assertEqual(0, smart_protocol.next_read_size())
1728
 
        self.assertEqual(protocol.RESPONSE_VERSION_TWO +
 
1725
        self.assertEqual(self.response_marker +
1729
1726
                         'success\nreadv\n3\ndefdone\n',
1730
1727
                         out_stream.getvalue())
1731
1728
        self.assertEqual('', smart_protocol.excess_buffer)
1733
1730
 
1734
1731
    def test_accept_excess_bytes_are_preserved(self):
1735
1732
        out_stream = StringIO()
1736
 
        smart_protocol = protocol.SmartServerRequestProtocolTwo(
1737
 
            None, out_stream.write)
 
1733
        smart_protocol = self.server_protocol_class(None, out_stream.write)
1738
1734
        smart_protocol.accept_bytes('hello\nhello\n')
1739
 
        self.assertEqual(protocol.RESPONSE_VERSION_TWO + "success\nok\x012\n",
 
1735
        self.assertEqual(self.response_marker + "success\nok\x012\n",
1740
1736
                         out_stream.getvalue())
1741
1737
        self.assertEqual("hello\n", smart_protocol.excess_buffer)
1742
1738
        self.assertEqual("", smart_protocol.in_buffer)
1744
1740
    def test_accept_excess_bytes_after_body(self):
1745
1741
        # The excess bytes look like the start of another request.
1746
1742
        server_protocol = self.build_protocol_waiting_for_body()
1747
 
        server_protocol.accept_bytes(
1748
 
            '7\nabcdefgdone\n' + protocol.RESPONSE_VERSION_TWO)
 
1743
        server_protocol.accept_bytes('7\nabcdefgdone\n' + self.response_marker)
1749
1744
        self.assertTrue(self.end_received)
1750
 
        self.assertEqual(protocol.RESPONSE_VERSION_TWO,
 
1745
        self.assertEqual(self.response_marker,
1751
1746
                         server_protocol.excess_buffer)
1752
1747
        self.assertEqual("", server_protocol.in_buffer)
1753
1748
        server_protocol.accept_bytes('Y')
1754
 
        self.assertEqual(protocol.RESPONSE_VERSION_TWO + "Y",
 
1749
        self.assertEqual(self.response_marker + "Y",
1755
1750
                         server_protocol.excess_buffer)
1756
1751
        self.assertEqual("", server_protocol.in_buffer)
1757
1752
 
1758
1753
    def test_accept_excess_bytes_after_dispatch(self):
1759
1754
        out_stream = StringIO()
1760
 
        smart_protocol = protocol.SmartServerRequestProtocolTwo(
1761
 
            None, out_stream.write)
 
1755
        smart_protocol = self.server_protocol_class(None, out_stream.write)
1762
1756
        smart_protocol.accept_bytes('hello\n')
1763
 
        self.assertEqual(protocol.RESPONSE_VERSION_TWO + "success\nok\x012\n",
 
1757
        self.assertEqual(self.response_marker + "success\nok\x012\n",
1764
1758
                         out_stream.getvalue())
1765
 
        smart_protocol.accept_bytes(protocol.REQUEST_VERSION_TWO + 'hel')
1766
 
        self.assertEqual(protocol.REQUEST_VERSION_TWO + "hel",
 
1759
        smart_protocol.accept_bytes(self.request_marker + 'hel')
 
1760
        self.assertEqual(self.request_marker + "hel",
1767
1761
                         smart_protocol.excess_buffer)
1768
1762
        smart_protocol.accept_bytes('lo\n')
1769
 
        self.assertEqual(protocol.REQUEST_VERSION_TWO + "hello\n",
 
1763
        self.assertEqual(self.request_marker + "hello\n",
1770
1764
                         smart_protocol.excess_buffer)
1771
1765
        self.assertEqual("", smart_protocol.in_buffer)
1772
1766
 
1773
1767
    def test__send_response_sets_finished_reading(self):
1774
 
        smart_protocol = protocol.SmartServerRequestProtocolTwo(
1775
 
            None, lambda x: None)
 
1768
        smart_protocol = self.server_protocol_class(None, lambda x: None)
1776
1769
        self.assertEqual(1, smart_protocol.next_read_size())
1777
1770
        smart_protocol._send_response(
1778
1771
            request.SuccessfulSmartServerResponse(('x',)))
1780
1773
 
1781
1774
    def test__send_response_errors_with_base_response(self):
1782
1775
        """Ensure that only the Successful/Failed subclasses are used."""
1783
 
        smart_protocol = protocol.SmartServerRequestProtocolTwo(
1784
 
            None, lambda x: None)
 
1776
        smart_protocol = self.server_protocol_class(None, lambda x: None)
1785
1777
        self.assertRaises(AttributeError, smart_protocol._send_response,
1786
1778
            request.SmartServerResponse(('x',)))
1787
1779
 
1788
1780
    def test_query_version(self):
1789
 
        """query_version on a SmartClientProtocolTwo should return a number.
 
1781
        """query_version on a SmartClientProtocolThree should return a number.
1790
1782
        
1791
1783
        The protocol provides the query_version because the domain level clients
1792
1784
        may all need to be able to probe for capabilities.
1795
1787
        # accept_bytes(tuple_based_encoding_of_hello) and reads and parses the
1796
1788
        # response of tuple-encoded (ok, 1).  Also, seperately we should test
1797
1789
        # the error if the response is a non-understood version.
1798
 
        input = StringIO(protocol.RESPONSE_VERSION_TWO + 'success\nok\x012\n')
 
1790
        input = StringIO(self.response_marker + 'success\nok\x012\n')
1799
1791
        output = StringIO()
1800
1792
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
1801
1793
        request = client_medium.get_request()
1802
 
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
 
1794
        smart_protocol = self.client_protocol_class(request)
1803
1795
        self.assertEqual(2, smart_protocol.query_version())
1804
1796
 
1805
1797
    def test_client_call_empty_response(self):
1807
1799
        # when the parsed line is an empty line, and results in a tuple with
1808
1800
        # one element - an empty string.
1809
1801
        self.assertServerToClientEncoding(
1810
 
            protocol.RESPONSE_VERSION_TWO + 'success\n\n', ('', ), [(), ('', )])
 
1802
            self.response_marker + 'success\n\n', ('', ), [(), ('', )])
1811
1803
 
1812
1804
    def test_client_call_three_element_response(self):
1813
1805
        # protocol.call() can get back tuples of other lengths. A three element
1814
1806
        # tuple should be unpacked as three strings.
1815
1807
        self.assertServerToClientEncoding(
1816
 
            protocol.RESPONSE_VERSION_TWO + 'success\na\x01b\x0134\n',
 
1808
            self.response_marker + 'success\na\x01b\x0134\n',
1817
1809
            ('a', 'b', '34'),
1818
1810
            [('a', 'b', '34')])
1819
1811
 
1820
1812
    def test_client_call_with_body_bytes_uploads(self):
1821
1813
        # protocol.call_with_body_bytes should length-prefix the bytes onto the
1822
1814
        # wire.
1823
 
        expected_bytes = protocol.REQUEST_VERSION_TWO + "foo\n7\nabcdefgdone\n"
 
1815
        expected_bytes = self.request_marker + "foo\n7\nabcdefgdone\n"
1824
1816
        input = StringIO("\n")
1825
1817
        output = StringIO()
1826
1818
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
1827
1819
        request = client_medium.get_request()
1828
 
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
 
1820
        smart_protocol = self.client_protocol_class(request)
1829
1821
        smart_protocol.call_with_body_bytes(('foo', ), "abcdefg")
1830
1822
        self.assertEqual(expected_bytes, output.getvalue())
1831
1823
 
1832
1824
    def test_client_call_with_body_readv_array(self):
1833
1825
        # protocol.call_with_upload should encode the readv array and then
1834
1826
        # length-prefix the bytes onto the wire.
1835
 
        expected_bytes = protocol.REQUEST_VERSION_TWO+"foo\n7\n1,2\n5,6done\n"
 
1827
        expected_bytes = self.request_marker + "foo\n7\n1,2\n5,6done\n"
1836
1828
        input = StringIO("\n")
1837
1829
        output = StringIO()
1838
1830
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
1839
1831
        request = client_medium.get_request()
1840
 
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
 
1832
        smart_protocol = self.client_protocol_class(request)
1841
1833
        smart_protocol.call_with_body_readv_array(('foo', ), [(1,2),(5,6)])
1842
1834
        self.assertEqual(expected_bytes, output.getvalue())
1843
1835
 
1845
1837
        # read_body_bytes should decode the body bytes from the wire into
1846
1838
        # a response.
1847
1839
        expected_bytes = "1234567"
1848
 
        server_bytes = (protocol.RESPONSE_VERSION_TWO +
 
1840
        server_bytes = (self.response_marker +
1849
1841
                        "success\nok\n7\n1234567done\n")
1850
1842
        input = StringIO(server_bytes)
1851
1843
        output = StringIO()
1852
1844
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
1853
1845
        request = client_medium.get_request()
1854
 
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
 
1846
        smart_protocol = self.client_protocol_class(request)
1855
1847
        smart_protocol.call('foo')
1856
1848
        smart_protocol.read_response_tuple(True)
1857
1849
        self.assertEqual(expected_bytes, smart_protocol.read_body_bytes())
1863
1855
        # LengthPrefixedBodyDecoder that is already well tested - we can skip
1864
1856
        # that.
1865
1857
        expected_bytes = "1234567"
1866
 
        server_bytes = (protocol.RESPONSE_VERSION_TWO +
1867
 
                        "success\nok\n7\n1234567done\n")
 
1858
        server_bytes = self.response_marker + "success\nok\n7\n1234567done\n"
1868
1859
        input = StringIO(server_bytes)
1869
1860
        output = StringIO()
1870
1861
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
1871
1862
        request = client_medium.get_request()
1872
 
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
 
1863
        smart_protocol = self.client_protocol_class(request)
1873
1864
        smart_protocol.call('foo')
1874
1865
        smart_protocol.read_response_tuple(True)
1875
1866
        self.assertEqual(expected_bytes[0:2], smart_protocol.read_body_bytes(2))
1880
1871
    def test_client_cancel_read_body_does_not_eat_body_bytes(self):
1881
1872
        # cancelling the expected body needs to finish the request, but not
1882
1873
        # read any more bytes.
1883
 
        server_bytes = (protocol.RESPONSE_VERSION_TWO +
1884
 
                        "success\nok\n7\n1234567done\n")
 
1874
        server_bytes = self.response_marker + "success\nok\n7\n1234567done\n"
1885
1875
        input = StringIO(server_bytes)
1886
1876
        output = StringIO()
1887
1877
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
1888
1878
        request = client_medium.get_request()
1889
 
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
 
1879
        smart_protocol = self.client_protocol_class(request)
1890
1880
        smart_protocol.call('foo')
1891
1881
        smart_protocol.read_response_tuple(True)
1892
1882
        smart_protocol.cancel_read_body()
1893
 
        self.assertEqual(len(protocol.RESPONSE_VERSION_TWO + 'success\nok\n'),
 
1883
        self.assertEqual(len(self.response_marker + 'success\nok\n'),
1894
1884
                         input.tell())
1895
1885
        self.assertRaises(
1896
1886
            errors.ReadingCompleted, smart_protocol.read_body_bytes)
1897
1887
 
1898
1888
 
1899
 
class TestSmartProtocolTwoSpecifics(TestSmartProtocol):
1900
 
    """Tests for aspects of smart protocol version two that are unique to
1901
 
    version two.
1902
 
 
1903
 
    Thus tests involving body streams and success/failure markers belong here.
1904
 
    """
1905
 
 
1906
 
    client_protocol_class = protocol.SmartClientRequestProtocolTwo
1907
 
    server_protocol_class = protocol.SmartServerRequestProtocolTwo
 
1889
class TestSmartProtocolTwoSpecificsMixin(object):
1908
1890
 
1909
1891
    def assertBodyStreamSerialisation(self, expected_serialisation,
1910
1892
                                      body_stream):
2038
2020
        self.assertEqual(True, smart_protocol.response_status)
2039
2021
 
2040
2022
 
2041
 
class TestProtocolTestCoverage(tests.TestCase):
2042
 
 
2043
 
    def assertSetEqual(self, set_a, set_b):
2044
 
        if set_a != set_b:
2045
 
            missing_from_a = sorted(set_b - set_a)
2046
 
            missing_from_b = sorted(set_a - set_b)
2047
 
            raise self.failureException(
2048
 
                'Sets not equal.\na is missing: %r\nb is missing: %r'
2049
 
                % (missing_from_a, missing_from_b))
2050
 
 
2051
 
    def remove_version_specific_tests(self, test_names):
2052
 
        return [name for name in test_names
2053
 
                if not name.startswith('test_construct_version_')]
2054
 
    
2055
 
    def test_ensure_consistent_coverage(self):
2056
 
        """We should be testing the same set of conditions for all protocol
2057
 
        implementations.
2058
 
 
2059
 
        The implementations of those tests may differ (so we can't use simple
2060
 
        test parameterisation to keep the tests synchronised), so this test is
2061
 
        to ensure that the set of test methods names executed on the
2062
 
        TestSmartProtocol{One,Two} classes are the same.
 
2023
class TestSmartProtocolTwoSpecifics(
 
2024
        TestSmartProtocol, TestSmartProtocolTwoSpecificsMixin):
 
2025
    """Tests for aspects of smart protocol version two that are unique to
 
2026
    version two.
 
2027
 
 
2028
    Thus tests involving body streams and success/failure markers belong here.
 
2029
    """
 
2030
 
 
2031
    client_protocol_class = protocol.SmartClientRequestProtocolTwo
 
2032
    server_protocol_class = protocol.SmartServerRequestProtocolTwo
 
2033
 
 
2034
 
 
2035
class TestVersionOneFeaturesInProtocolThree(
 
2036
    TestSmartProtocol, CommonSmartProtocolTestMixin):
 
2037
    """Tests for version one smart protocol features as implemented by version
 
2038
    three.
 
2039
    """
 
2040
 
 
2041
    client_protocol_class = protocol.SmartClientRequestProtocolThree
 
2042
    server_protocol_class = protocol.SmartServerRequestProtocolThree
 
2043
 
 
2044
    def test_construct_version_three_server_protocol(self):
 
2045
        smart_protocol = protocol.SmartServerRequestProtocolThree(None, None)
 
2046
        self.assertEqual('', smart_protocol.excess_buffer)
 
2047
        self.assertEqual('', smart_protocol._in_buffer)
 
2048
        self.assertFalse(smart_protocol.has_dispatched)
 
2049
        # The protocol starts by scanning one byte a time, because it needs to
 
2050
        # see a newline before it can determine the protocol version.
 
2051
        self.assertEqual(1, smart_protocol.next_read_size())
 
2052
 
 
2053
    def test_construct_version_three_client_protocol(self):
 
2054
        # we can construct a client protocol from a client medium request
 
2055
        output = StringIO()
 
2056
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
2057
        request = client_medium.get_request()
 
2058
        client_protocol = protocol.SmartClientRequestProtocolThree(request)
 
2059
 
 
2060
 
 
2061
class TestServerProtocolThree(TestSmartProtocol):
 
2062
    """Tests for v3 of the server-side protocol."""
 
2063
 
 
2064
    client_protocol_class = protocol.SmartClientRequestProtocolThree
 
2065
    server_protocol_class = protocol.SmartServerRequestProtocolThree
 
2066
 
 
2067
    def test_trivial_request(self):
 
2068
        """Smoke test for the simplest possible v3 request: no headers, single
 
2069
        argument, no body.
2063
2070
        """
2064
 
        import unittest
2065
 
        loader = unittest.TestLoader()
2066
 
        names = loader.getTestCaseNames(TestSmartProtocolOne)
2067
 
        protocol1_tests = set(self.remove_version_specific_tests(names))
2068
 
        names = loader.getTestCaseNames(TestSmartProtocolTwo)
2069
 
        protocol2_tests = set(self.remove_version_specific_tests(names))
2070
 
        self.assertSetEqual(protocol1_tests, protocol2_tests)
 
2071
        output = StringIO()
 
2072
        protocol_version = "bzr request 3\n"
 
2073
        headers = '\0\0\0\x02de'  # length-prefixed, bencoded empty dict
 
2074
        args = '\0\0\0\x07l3:ARGe' # length-prefixed, bencoded list: ['ARG']
 
2075
        body = 'n'
 
2076
        request_bytes = headers + args + body
 
2077
        smart_protocol = self.server_protocol_class(None, output.write)
 
2078
        smart_protocol.accept_bytes(request_bytes)
 
2079
        self.assertEqual(0, smart_protocol.next_read_size())
 
2080
        self.assertEqual('', smart_protocol.excess_buffer)
 
2081
 
 
2082
    def make_protocol_expecting_body(self):
 
2083
        output = StringIO()
 
2084
        protocol_version = "bzr request 3\n"
 
2085
        headers = '\0\0\0\x02de'  # length-prefixed, bencoded empty dict
 
2086
        args = '\0\0\0\x07l3:ARGe' # length-prefixed, bencoded list: ['ARG']
 
2087
        request_bytes = headers + args
 
2088
        smart_protocol = self.server_protocol_class(None, output.write)
 
2089
        smart_protocol.accept_bytes(request_bytes)
 
2090
        return smart_protocol
 
2091
 
 
2092
    def test_no_body(self):
 
2093
        smart_protocol = self.make_protocol_expecting_body()
 
2094
        smart_protocol.request_handler = InstrumentedRequestHandler()
 
2095
        body = (
 
2096
            'n' # body kind
 
2097
            )
 
2098
        smart_protocol.accept_bytes(body)
 
2099
        self.assertEqual(
 
2100
            [('no_body_received',)],
 
2101
            smart_protocol.request_handler.calls)
 
2102
 
 
2103
    def test_prefixed_body(self):
 
2104
        smart_protocol = self.make_protocol_expecting_body()
 
2105
        smart_protocol.request_handler = InstrumentedRequestHandler()
 
2106
        body = (
 
2107
            'p' # body kind
 
2108
            '\0\0\0\x07' # length prefix
 
2109
            'content' # the payload
 
2110
            )
 
2111
        smart_protocol.accept_bytes(body)
 
2112
        self.assertEqual(
 
2113
            [('prefixed_body_received', 'content')],
 
2114
            smart_protocol.request_handler.calls)
 
2115
 
 
2116
    def test_chunked_body_zero_chunks(self):
 
2117
        smart_protocol = self.make_protocol_expecting_body()
 
2118
        smart_protocol.request_handler = InstrumentedRequestHandler()
 
2119
        body = (
 
2120
            's' # body kind
 
2121
            't' # stream terminator
 
2122
            )
 
2123
        smart_protocol.accept_bytes(body)
 
2124
        # XXX: No calls to the request handler in this case?  That's slightly
 
2125
        # odd.
 
2126
        self.assertEqual(
 
2127
            [],
 
2128
            smart_protocol.request_handler.calls)
 
2129
 
 
2130
    def test_chunked_body_one_chunks(self):
 
2131
        smart_protocol = self.make_protocol_expecting_body()
 
2132
        smart_protocol.request_handler = InstrumentedRequestHandler()
 
2133
        body = (
 
2134
            's' # body kind
 
2135
            'c' # chunk indicator
 
2136
            '\0\0\0\x03' # chunk length
 
2137
            'one' # chunk content
 
2138
            # Done
 
2139
            't' # stream terminator
 
2140
            )
 
2141
        smart_protocol.accept_bytes(body)
 
2142
        self.assertEqual(
 
2143
            [('body_chunk_received', 'one')],
 
2144
            smart_protocol.request_handler.calls)
 
2145
 
 
2146
    def test_chunked_body_two_chunks(self):
 
2147
        smart_protocol = self.make_protocol_expecting_body()
 
2148
        smart_protocol.request_handler = InstrumentedRequestHandler()
 
2149
        body = (
 
2150
            's' # body kind
 
2151
            # First chunk
 
2152
            'c' # chunk indicator
 
2153
            '\0\0\0\x03' # chunk length
 
2154
            'one' # chunk content
 
2155
            # Second chunk
 
2156
            'c'
 
2157
            '\0\0\0\x03'
 
2158
            'two'
 
2159
            # Done
 
2160
            't' # stream terminator
 
2161
            )
 
2162
        smart_protocol.accept_bytes(body)
 
2163
        self.assertEqual(
 
2164
            [('body_chunk_received', 'one'), ('body_chunk_received', 'two')],
 
2165
            smart_protocol.request_handler.calls)
 
2166
 
 
2167
 
 
2168
class InstrumentedRequestHandler(object):
 
2169
    """Test Double of SmartServerRequestHandler."""
 
2170
 
 
2171
    def __init__(self):
 
2172
        self.calls = []
 
2173
 
 
2174
    def body_chunk_received(self, chunk_bytes):
 
2175
        self.calls.append(('body_chunk_received', chunk_bytes))
 
2176
 
 
2177
    def no_body_received(self):
 
2178
        self.calls.append(('no_body_received',))
 
2179
 
 
2180
    def prefixed_body_received(self, body_bytes):
 
2181
        self.calls.append(('prefixed_body_received', body_bytes))
 
2182
 
 
2183
 
 
2184
 
 
2185
#class TestProtocolTestCoverage(tests.TestCase):
 
2186
#
 
2187
#    def assertSetEqual(self, set_a, set_b):
 
2188
#        if set_a != set_b:
 
2189
#            missing_from_a = sorted(set_b - set_a)
 
2190
#            missing_from_b = sorted(set_a - set_b)
 
2191
#            raise self.failureException(
 
2192
#                'Sets not equal.\na is missing: %r\nb is missing: %r'
 
2193
#                % (missing_from_a, missing_from_b))
 
2194
#
 
2195
#    def get_tests_from_classes(self, test_case_classes):
 
2196
#        loader = unittest.TestLoader()
 
2197
#        test_names = []
 
2198
#        for test_case_class in test_case_classes:
 
2199
#            names = loader.getTestCaseNames(test_case_class)
 
2200
#            test_names.extend(names)
 
2201
#        return set(self.remove_version_specific_tests(test_names))
 
2202
#
 
2203
#    def remove_version_specific_tests(self, test_names):
 
2204
#        return [name for name in test_names
 
2205
#                if not name.startswith('test_construct_version_')]
 
2206
#    
 
2207
#    def test_ensure_consistent_coverage(self):
 
2208
#        """We should be testing the same set of conditions for all protocol
 
2209
#        implementations.
 
2210
#
 
2211
#        The implementations of those tests may differ (so we can't use simple
 
2212
#        test parameterisation to keep the tests synchronised), so this test is
 
2213
#        to ensure that all tests for v1 are done for v2 and v3, and that all v2
 
2214
#        tests are done for v3.
 
2215
#        """
 
2216
#        v1_classes = [TestVersionOneFeaturesInProtocolOne]
 
2217
#        v2_classes = [
 
2218
#            TestVersionOneFeaturesInProtocolTwo,
 
2219
#            TestVersionTwoFeaturesInProtocolTwo]
 
2220
##        v3_classes = [
 
2221
##            TestVersionOneFeaturesInProtocolThree,
 
2222
##            TestVersionTwoFeaturesInProtocolThree,
 
2223
##            TestVersionThreeFeaturesInProtocolThree]
 
2224
#
 
2225
#        # v2 implements all of v1
 
2226
#        protocol1_tests = self.get_tests_from_classes(v1_classes)
 
2227
#        protocol2_basic_tests = self.get_tests_from_class(
 
2228
#            TestVersionOneFeaturesInProtocolTwo)
 
2229
#        self.assertSetEqual(protocol1_tests, protocol2_basic_tests)
 
2230
#
 
2231
#        # v3 implements all of v1 and v2.
 
2232
#        protocol2_tests = self.get_tests_from_classes(v2_classes)
 
2233
#        protocol3_basic_tests = self.get_tests_from_class(
 
2234
#            TestVersionOneFeaturesInProtocolThree)
 
2235
#        self.assertSetEqual(protocol2_tests, protocol3_basic_tests)
2071
2236
 
2072
2237
 
2073
2238
class TestSmartClientUnicode(tests.TestCase):