1411
1416
class CommonSmartProtocolTestMixin(object):
1418
def test_server_offset_serialisation(self):
1419
"""The Smart protocol serialises offsets as a comma and \n string.
1421
We check a number of boundary cases are as expected: empty, one offset,
1422
one with the order of reads not increasing (an out of order read), and
1423
one that should coalesce.
1425
self.assertOffsetSerialisation([], '', self.client_protocol)
1426
self.assertOffsetSerialisation([(1,2)], '1,2', self.client_protocol)
1427
self.assertOffsetSerialisation([(10,40), (0,5)], '10,40\n0,5',
1428
self.client_protocol)
1429
self.assertOffsetSerialisation([(1,2), (3,4), (100, 200)],
1430
'1,2\n3,4\n100,200', self.client_protocol)
1413
1432
def test_errors_are_logged(self):
1414
1433
"""If an error occurs during testing, it is logged to the test log."""
1434
# XXX: hmm, only errors from the request handler get logged, other
1435
# protocol errors don't. I guess this is trying to test for internal
1436
# logic errors (unexpected "internal server errors", ala HTTP 500)
1437
# rather than bad requests from the client? The behaviour here needs
1415
1439
out_stream = StringIO()
1416
1440
smart_protocol = self.server_protocol_class(None, out_stream.write)
1417
1441
# This triggers a "bad request" error.
1678
1690
request = client_medium.get_request()
1679
1691
client_protocol = protocol.SmartClientRequestProtocolTwo(request)
1681
def test_server_offset_serialisation(self):
1682
"""The Smart protocol serialises offsets as a comma and \n string.
1684
We check a number of boundary cases are as expected: empty, one offset,
1685
one with the order of reads not increasing (an out of order read), and
1686
one that should coalesce.
1688
self.assertOffsetSerialisation([], '', self.client_protocol)
1689
self.assertOffsetSerialisation([(1,2)], '1,2', self.client_protocol)
1690
self.assertOffsetSerialisation([(10,40), (0,5)], '10,40\n0,5',
1691
self.client_protocol)
1692
self.assertOffsetSerialisation([(1,2), (3,4), (100, 200)],
1693
'1,2\n3,4\n100,200', self.client_protocol)
1695
1693
def test_accept_bytes_of_bad_request_to_protocol(self):
1696
1694
out_stream = StringIO()
1697
smart_protocol = protocol.SmartServerRequestProtocolTwo(
1698
None, out_stream.write)
1695
smart_protocol = self.server_protocol_class(None, out_stream.write)
1699
1696
smart_protocol.accept_bytes('abc')
1700
1697
self.assertEqual('abc', smart_protocol.in_buffer)
1701
1698
smart_protocol.accept_bytes('\n')
1702
1699
self.assertEqual(
1703
protocol.RESPONSE_VERSION_TWO +
1700
self.response_marker +
1704
1701
"failed\nerror\x01Generic bzr smart protocol error: bad request 'abc'\n",
1705
1702
out_stream.getvalue())
1706
1703
self.assertTrue(smart_protocol.has_dispatched)
1744
1740
def test_accept_excess_bytes_after_body(self):
1745
1741
# The excess bytes look like the start of another request.
1746
1742
server_protocol = self.build_protocol_waiting_for_body()
1747
server_protocol.accept_bytes(
1748
'7\nabcdefgdone\n' + protocol.RESPONSE_VERSION_TWO)
1743
server_protocol.accept_bytes('7\nabcdefgdone\n' + self.response_marker)
1749
1744
self.assertTrue(self.end_received)
1750
self.assertEqual(protocol.RESPONSE_VERSION_TWO,
1745
self.assertEqual(self.response_marker,
1751
1746
server_protocol.excess_buffer)
1752
1747
self.assertEqual("", server_protocol.in_buffer)
1753
1748
server_protocol.accept_bytes('Y')
1754
self.assertEqual(protocol.RESPONSE_VERSION_TWO + "Y",
1749
self.assertEqual(self.response_marker + "Y",
1755
1750
server_protocol.excess_buffer)
1756
1751
self.assertEqual("", server_protocol.in_buffer)
1758
1753
def test_accept_excess_bytes_after_dispatch(self):
1759
1754
out_stream = StringIO()
1760
smart_protocol = protocol.SmartServerRequestProtocolTwo(
1761
None, out_stream.write)
1755
smart_protocol = self.server_protocol_class(None, out_stream.write)
1762
1756
smart_protocol.accept_bytes('hello\n')
1763
self.assertEqual(protocol.RESPONSE_VERSION_TWO + "success\nok\x012\n",
1757
self.assertEqual(self.response_marker + "success\nok\x012\n",
1764
1758
out_stream.getvalue())
1765
smart_protocol.accept_bytes(protocol.REQUEST_VERSION_TWO + 'hel')
1766
self.assertEqual(protocol.REQUEST_VERSION_TWO + "hel",
1759
smart_protocol.accept_bytes(self.request_marker + 'hel')
1760
self.assertEqual(self.request_marker + "hel",
1767
1761
smart_protocol.excess_buffer)
1768
1762
smart_protocol.accept_bytes('lo\n')
1769
self.assertEqual(protocol.REQUEST_VERSION_TWO + "hello\n",
1763
self.assertEqual(self.request_marker + "hello\n",
1770
1764
smart_protocol.excess_buffer)
1771
1765
self.assertEqual("", smart_protocol.in_buffer)
1773
1767
def test__send_response_sets_finished_reading(self):
1774
smart_protocol = protocol.SmartServerRequestProtocolTwo(
1775
None, lambda x: None)
1768
smart_protocol = self.server_protocol_class(None, lambda x: None)
1776
1769
self.assertEqual(1, smart_protocol.next_read_size())
1777
1770
smart_protocol._send_response(
1778
1771
request.SuccessfulSmartServerResponse(('x',)))
1807
1799
# when the parsed line is an empty line, and results in a tuple with
1808
1800
# one element - an empty string.
1809
1801
self.assertServerToClientEncoding(
1810
protocol.RESPONSE_VERSION_TWO + 'success\n\n', ('', ), [(), ('', )])
1802
self.response_marker + 'success\n\n', ('', ), [(), ('', )])
1812
1804
def test_client_call_three_element_response(self):
1813
1805
# protocol.call() can get back tuples of other lengths. A three element
1814
1806
# tuple should be unpacked as three strings.
1815
1807
self.assertServerToClientEncoding(
1816
protocol.RESPONSE_VERSION_TWO + 'success\na\x01b\x0134\n',
1808
self.response_marker + 'success\na\x01b\x0134\n',
1817
1809
('a', 'b', '34'),
1818
1810
[('a', 'b', '34')])
1820
1812
def test_client_call_with_body_bytes_uploads(self):
1821
1813
# protocol.call_with_body_bytes should length-prefix the bytes onto the
1823
expected_bytes = protocol.REQUEST_VERSION_TWO + "foo\n7\nabcdefgdone\n"
1815
expected_bytes = self.request_marker + "foo\n7\nabcdefgdone\n"
1824
1816
input = StringIO("\n")
1825
1817
output = StringIO()
1826
1818
client_medium = medium.SmartSimplePipesClientMedium(input, output)
1827
1819
request = client_medium.get_request()
1828
smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
1820
smart_protocol = self.client_protocol_class(request)
1829
1821
smart_protocol.call_with_body_bytes(('foo', ), "abcdefg")
1830
1822
self.assertEqual(expected_bytes, output.getvalue())
1832
1824
def test_client_call_with_body_readv_array(self):
1833
1825
# protocol.call_with_upload should encode the readv array and then
1834
1826
# length-prefix the bytes onto the wire.
1835
expected_bytes = protocol.REQUEST_VERSION_TWO+"foo\n7\n1,2\n5,6done\n"
1827
expected_bytes = self.request_marker + "foo\n7\n1,2\n5,6done\n"
1836
1828
input = StringIO("\n")
1837
1829
output = StringIO()
1838
1830
client_medium = medium.SmartSimplePipesClientMedium(input, output)
1839
1831
request = client_medium.get_request()
1840
smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
1832
smart_protocol = self.client_protocol_class(request)
1841
1833
smart_protocol.call_with_body_readv_array(('foo', ), [(1,2),(5,6)])
1842
1834
self.assertEqual(expected_bytes, output.getvalue())
2038
2020
self.assertEqual(True, smart_protocol.response_status)
2041
class TestProtocolTestCoverage(tests.TestCase):
2043
def assertSetEqual(self, set_a, set_b):
2045
missing_from_a = sorted(set_b - set_a)
2046
missing_from_b = sorted(set_a - set_b)
2047
raise self.failureException(
2048
'Sets not equal.\na is missing: %r\nb is missing: %r'
2049
% (missing_from_a, missing_from_b))
2051
def remove_version_specific_tests(self, test_names):
2052
return [name for name in test_names
2053
if not name.startswith('test_construct_version_')]
2055
def test_ensure_consistent_coverage(self):
2056
"""We should be testing the same set of conditions for all protocol
2059
The implementations of those tests may differ (so we can't use simple
2060
test parameterisation to keep the tests synchronised), so this test is
2061
to ensure that the set of test methods names executed on the
2062
TestSmartProtocol{One,Two} classes are the same.
2023
class TestSmartProtocolTwoSpecifics(
2024
TestSmartProtocol, TestSmartProtocolTwoSpecificsMixin):
2025
"""Tests for aspects of smart protocol version two that are unique to
2028
Thus tests involving body streams and success/failure markers belong here.
2031
client_protocol_class = protocol.SmartClientRequestProtocolTwo
2032
server_protocol_class = protocol.SmartServerRequestProtocolTwo
2035
class TestVersionOneFeaturesInProtocolThree(
2036
TestSmartProtocol, CommonSmartProtocolTestMixin):
2037
"""Tests for version one smart protocol features as implemented by version
2041
client_protocol_class = protocol.SmartClientRequestProtocolThree
2042
server_protocol_class = protocol.SmartServerRequestProtocolThree
2044
def test_construct_version_three_server_protocol(self):
2045
smart_protocol = protocol.SmartServerRequestProtocolThree(None, None)
2046
self.assertEqual('', smart_protocol.excess_buffer)
2047
self.assertEqual('', smart_protocol._in_buffer)
2048
self.assertFalse(smart_protocol.has_dispatched)
2049
# The protocol starts by scanning one byte a time, because it needs to
2050
# see a newline before it can determine the protocol version.
2051
self.assertEqual(1, smart_protocol.next_read_size())
2053
def test_construct_version_three_client_protocol(self):
2054
# we can construct a client protocol from a client medium request
2056
client_medium = medium.SmartSimplePipesClientMedium(None, output)
2057
request = client_medium.get_request()
2058
client_protocol = protocol.SmartClientRequestProtocolThree(request)
2061
class TestServerProtocolThree(TestSmartProtocol):
2062
"""Tests for v3 of the server-side protocol."""
2064
client_protocol_class = protocol.SmartClientRequestProtocolThree
2065
server_protocol_class = protocol.SmartServerRequestProtocolThree
2067
def test_trivial_request(self):
2068
"""Smoke test for the simplest possible v3 request: no headers, single
2065
loader = unittest.TestLoader()
2066
names = loader.getTestCaseNames(TestSmartProtocolOne)
2067
protocol1_tests = set(self.remove_version_specific_tests(names))
2068
names = loader.getTestCaseNames(TestSmartProtocolTwo)
2069
protocol2_tests = set(self.remove_version_specific_tests(names))
2070
self.assertSetEqual(protocol1_tests, protocol2_tests)
2072
protocol_version = "bzr request 3\n"
2073
headers = '\0\0\0\x02de' # length-prefixed, bencoded empty dict
2074
args = '\0\0\0\x07l3:ARGe' # length-prefixed, bencoded list: ['ARG']
2076
request_bytes = headers + args + body
2077
smart_protocol = self.server_protocol_class(None, output.write)
2078
smart_protocol.accept_bytes(request_bytes)
2079
self.assertEqual(0, smart_protocol.next_read_size())
2080
self.assertEqual('', smart_protocol.excess_buffer)
2082
def make_protocol_expecting_body(self):
2084
protocol_version = "bzr request 3\n"
2085
headers = '\0\0\0\x02de' # length-prefixed, bencoded empty dict
2086
args = '\0\0\0\x07l3:ARGe' # length-prefixed, bencoded list: ['ARG']
2087
request_bytes = headers + args
2088
smart_protocol = self.server_protocol_class(None, output.write)
2089
smart_protocol.accept_bytes(request_bytes)
2090
return smart_protocol
2092
def test_no_body(self):
2093
smart_protocol = self.make_protocol_expecting_body()
2094
smart_protocol.request_handler = InstrumentedRequestHandler()
2098
smart_protocol.accept_bytes(body)
2100
[('no_body_received',)],
2101
smart_protocol.request_handler.calls)
2103
def test_prefixed_body(self):
2104
smart_protocol = self.make_protocol_expecting_body()
2105
smart_protocol.request_handler = InstrumentedRequestHandler()
2108
'\0\0\0\x07' # length prefix
2109
'content' # the payload
2111
smart_protocol.accept_bytes(body)
2113
[('prefixed_body_received', 'content')],
2114
smart_protocol.request_handler.calls)
2116
def test_chunked_body_zero_chunks(self):
2117
smart_protocol = self.make_protocol_expecting_body()
2118
smart_protocol.request_handler = InstrumentedRequestHandler()
2121
't' # stream terminator
2123
smart_protocol.accept_bytes(body)
2124
# XXX: No calls to the request handler in this case? That's slightly
2128
smart_protocol.request_handler.calls)
2130
def test_chunked_body_one_chunks(self):
2131
smart_protocol = self.make_protocol_expecting_body()
2132
smart_protocol.request_handler = InstrumentedRequestHandler()
2135
'c' # chunk indicator
2136
'\0\0\0\x03' # chunk length
2137
'one' # chunk content
2139
't' # stream terminator
2141
smart_protocol.accept_bytes(body)
2143
[('body_chunk_received', 'one')],
2144
smart_protocol.request_handler.calls)
2146
def test_chunked_body_two_chunks(self):
2147
smart_protocol = self.make_protocol_expecting_body()
2148
smart_protocol.request_handler = InstrumentedRequestHandler()
2152
'c' # chunk indicator
2153
'\0\0\0\x03' # chunk length
2154
'one' # chunk content
2160
't' # stream terminator
2162
smart_protocol.accept_bytes(body)
2164
[('body_chunk_received', 'one'), ('body_chunk_received', 'two')],
2165
smart_protocol.request_handler.calls)
2168
class InstrumentedRequestHandler(object):
2169
"""Test Double of SmartServerRequestHandler."""
2174
def body_chunk_received(self, chunk_bytes):
2175
self.calls.append(('body_chunk_received', chunk_bytes))
2177
def no_body_received(self):
2178
self.calls.append(('no_body_received',))
2180
def prefixed_body_received(self, body_bytes):
2181
self.calls.append(('prefixed_body_received', body_bytes))
2185
#class TestProtocolTestCoverage(tests.TestCase):
2187
# def assertSetEqual(self, set_a, set_b):
2188
# if set_a != set_b:
2189
# missing_from_a = sorted(set_b - set_a)
2190
# missing_from_b = sorted(set_a - set_b)
2191
# raise self.failureException(
2192
# 'Sets not equal.\na is missing: %r\nb is missing: %r'
2193
# % (missing_from_a, missing_from_b))
2195
# def get_tests_from_classes(self, test_case_classes):
2196
# loader = unittest.TestLoader()
2198
# for test_case_class in test_case_classes:
2199
# names = loader.getTestCaseNames(test_case_class)
2200
# test_names.extend(names)
2201
# return set(self.remove_version_specific_tests(test_names))
2203
# def remove_version_specific_tests(self, test_names):
2204
# return [name for name in test_names
2205
# if not name.startswith('test_construct_version_')]
2207
# def test_ensure_consistent_coverage(self):
2208
# """We should be testing the same set of conditions for all protocol
2211
# The implementations of those tests may differ (so we can't use simple
2212
# test parameterisation to keep the tests synchronised), so this test is
2213
# to ensure that all tests for v1 are done for v2 and v3, and that all v2
2214
# tests are done for v3.
2216
# v1_classes = [TestVersionOneFeaturesInProtocolOne]
2218
# TestVersionOneFeaturesInProtocolTwo,
2219
# TestVersionTwoFeaturesInProtocolTwo]
2221
## TestVersionOneFeaturesInProtocolThree,
2222
## TestVersionTwoFeaturesInProtocolThree,
2223
## TestVersionThreeFeaturesInProtocolThree]
2225
# # v2 implements all of v1
2226
# protocol1_tests = self.get_tests_from_classes(v1_classes)
2227
# protocol2_basic_tests = self.get_tests_from_class(
2228
# TestVersionOneFeaturesInProtocolTwo)
2229
# self.assertSetEqual(protocol1_tests, protocol2_basic_tests)
2231
# # v3 implements all of v1 and v2.
2232
# protocol2_tests = self.get_tests_from_classes(v2_classes)
2233
# protocol3_basic_tests = self.get_tests_from_class(
2234
# TestVersionOneFeaturesInProtocolThree)
2235
# self.assertSetEqual(protocol2_tests, protocol3_basic_tests)
2073
2238
class TestSmartClientUnicode(tests.TestCase):