1
# Copyright (C) 2006, 2007 Canonical Ltd
3
# This program is free software; you can redistribute it and/or modify
4
# it under the terms of the GNU General Public License as published by
5
# the Free Software Foundation; either version 2 of the License, or
6
# (at your option) any later version.
8
# This program is distributed in the hope that it will be useful,
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11
# GNU General Public License for more details.
13
# You should have received a copy of the GNU General Public License
14
# along with this program; if not, write to the Free Software
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
17
"""Tests for smart transport"""
19
# all of this deals with byte strings so this is safe
20
from cStringIO import StringIO
33
from bzrlib.smart import (
38
request as _mod_request,
42
from bzrlib.tests.test_smart import TestCaseWithSmartMedium
43
from bzrlib.transport import (
49
from bzrlib.transport.http import SmartClientHTTPMediumRequest
52
class StringIOSSHVendor(object):
53
"""A SSH vendor that uses StringIO to buffer writes and answer reads."""
55
def __init__(self, read_from, write_to):
56
self.read_from = read_from
57
self.write_to = write_to
60
def connect_ssh(self, username, password, host, port, command):
61
self.calls.append(('connect_ssh', username, password, host, port,
63
return StringIOSSHConnection(self)
66
class StringIOSSHConnection(object):
67
"""A SSH connection that uses StringIO to buffer writes and answer reads."""
69
def __init__(self, vendor):
73
self.vendor.calls.append(('close', ))
75
def get_filelike_channels(self):
76
return self.vendor.read_from, self.vendor.write_to
79
class _InvalidHostnameFeature(tests.Feature):
80
"""Does 'non_existent.invalid' fail to resolve?
82
RFC 2606 states that .invalid is reserved for invalid domain names, and
83
also underscores are not a valid character in domain names. Despite this,
84
it's possible a badly misconfigured name server might decide to always
85
return an address for any name, so this feature allows us to distinguish a
86
broken system from a broken test.
91
socket.gethostbyname('non_existent.invalid')
92
except socket.gaierror:
93
# The host name failed to resolve. Good.
98
def feature_name(self):
99
return 'invalid hostname'
101
InvalidHostnameFeature = _InvalidHostnameFeature()
104
class SmartClientMediumTests(tests.TestCase):
105
"""Tests for SmartClientMedium.
107
We should create a test scenario for this: we need a server module that
108
construct the test-servers (like make_loopsocket_and_medium), and the list
109
of SmartClientMedium classes to test.
112
def make_loopsocket_and_medium(self):
113
"""Create a loopback socket for testing, and a medium aimed at it."""
114
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
115
sock.bind(('127.0.0.1', 0))
117
port = sock.getsockname()[1]
118
client_medium = medium.SmartTCPClientMedium('127.0.0.1', port, 'base')
119
return sock, client_medium
121
def receive_bytes_on_server(self, sock, bytes):
122
"""Accept a connection on sock and read 3 bytes.
124
The bytes are appended to the list bytes.
126
:return: a Thread which is running to do the accept and recv.
128
def _receive_bytes_on_server():
129
connection, address = sock.accept()
130
bytes.append(osutils.recv_all(connection, 3))
132
t = threading.Thread(target=_receive_bytes_on_server)
136
def test_construct_smart_simple_pipes_client_medium(self):
137
# the SimplePipes client medium takes two pipes:
138
# readable pipe, writeable pipe.
139
# Constructing one should just save these and do nothing.
140
# We test this by passing in None.
141
client_medium = medium.SmartSimplePipesClientMedium(None, None, None)
143
def test_simple_pipes_client_request_type(self):
144
# SimplePipesClient should use SmartClientStreamMediumRequest's.
145
client_medium = medium.SmartSimplePipesClientMedium(None, None, None)
146
request = client_medium.get_request()
147
self.assertIsInstance(request, medium.SmartClientStreamMediumRequest)
149
def test_simple_pipes_client_get_concurrent_requests(self):
150
# the simple_pipes client does not support pipelined requests:
151
# but it does support serial requests: we construct one after
152
# another is finished. This is a smoke test testing the integration
153
# of the SmartClientStreamMediumRequest and the SmartClientStreamMedium
154
# classes - as the sibling classes share this logic, they do not have
155
# explicit tests for this.
157
client_medium = medium.SmartSimplePipesClientMedium(
158
None, output, 'base')
159
request = client_medium.get_request()
160
request.finished_writing()
161
request.finished_reading()
162
request2 = client_medium.get_request()
163
request2.finished_writing()
164
request2.finished_reading()
166
def test_simple_pipes_client__accept_bytes_writes_to_writable(self):
167
# accept_bytes writes to the writeable pipe.
169
client_medium = medium.SmartSimplePipesClientMedium(
170
None, output, 'base')
171
client_medium._accept_bytes('abc')
172
self.assertEqual('abc', output.getvalue())
174
def test_simple_pipes_client_disconnect_does_nothing(self):
175
# calling disconnect does nothing.
178
client_medium = medium.SmartSimplePipesClientMedium(
179
input, output, 'base')
180
# send some bytes to ensure disconnecting after activity still does not
182
client_medium._accept_bytes('abc')
183
client_medium.disconnect()
184
self.assertFalse(input.closed)
185
self.assertFalse(output.closed)
187
def test_simple_pipes_client_accept_bytes_after_disconnect(self):
188
# calling disconnect on the client does not alter the pipe that
189
# accept_bytes writes to.
192
client_medium = medium.SmartSimplePipesClientMedium(
193
input, output, 'base')
194
client_medium._accept_bytes('abc')
195
client_medium.disconnect()
196
client_medium._accept_bytes('abc')
197
self.assertFalse(input.closed)
198
self.assertFalse(output.closed)
199
self.assertEqual('abcabc', output.getvalue())
201
def test_simple_pipes_client_ignores_disconnect_when_not_connected(self):
202
# Doing a disconnect on a new (and thus unconnected) SimplePipes medium
204
client_medium = medium.SmartSimplePipesClientMedium(None, None, 'base')
205
client_medium.disconnect()
207
def test_simple_pipes_client_can_always_read(self):
208
# SmartSimplePipesClientMedium is never disconnected, so read_bytes
209
# always tries to read from the underlying pipe.
210
input = StringIO('abcdef')
211
client_medium = medium.SmartSimplePipesClientMedium(input, None, 'base')
212
self.assertEqual('abc', client_medium.read_bytes(3))
213
client_medium.disconnect()
214
self.assertEqual('def', client_medium.read_bytes(3))
216
def test_simple_pipes_client_supports__flush(self):
217
# invoking _flush on a SimplePipesClient should flush the output
218
# pipe. We test this by creating an output pipe that records
219
# flush calls made to it.
220
from StringIO import StringIO # get regular StringIO
224
def logging_flush(): flush_calls.append('flush')
225
output.flush = logging_flush
226
client_medium = medium.SmartSimplePipesClientMedium(
227
input, output, 'base')
228
# this call is here to ensure we only flush once, not on every
229
# _accept_bytes call.
230
client_medium._accept_bytes('abc')
231
client_medium._flush()
232
client_medium.disconnect()
233
self.assertEqual(['flush'], flush_calls)
235
def test_construct_smart_ssh_client_medium(self):
236
# the SSH client medium takes:
237
# host, port, username, password, vendor
238
# Constructing one should just save these and do nothing.
239
# we test this by creating a empty bound socket and constructing
241
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
242
sock.bind(('127.0.0.1', 0))
243
unopened_port = sock.getsockname()[1]
244
# having vendor be invalid means that if it tries to connect via the
245
# vendor it will blow up.
246
client_medium = medium.SmartSSHClientMedium('127.0.0.1', unopened_port,
247
username=None, password=None, base='base', vendor="not a vendor",
248
bzr_remote_path='bzr')
251
def test_ssh_client_connects_on_first_use(self):
252
# The only thing that initiates a connection from the medium is giving
255
vendor = StringIOSSHVendor(StringIO(), output)
256
client_medium = medium.SmartSSHClientMedium(
257
'a hostname', 'a port', 'a username', 'a password', 'base', vendor,
259
client_medium._accept_bytes('abc')
260
self.assertEqual('abc', output.getvalue())
261
self.assertEqual([('connect_ssh', 'a username', 'a password',
262
'a hostname', 'a port',
263
['bzr', 'serve', '--inet', '--directory=/', '--allow-writes'])],
266
def test_ssh_client_changes_command_when_BZR_REMOTE_PATH_is_set(self):
267
# The only thing that initiates a connection from the medium is giving
270
vendor = StringIOSSHVendor(StringIO(), output)
271
orig_bzr_remote_path = os.environ.get('BZR_REMOTE_PATH')
272
def cleanup_environ():
273
osutils.set_or_unset_env('BZR_REMOTE_PATH', orig_bzr_remote_path)
274
self.addCleanup(cleanup_environ)
275
os.environ['BZR_REMOTE_PATH'] = 'fugly'
276
client_medium = self.callDeprecated(
277
['bzr_remote_path is required as of bzr 0.92'],
278
medium.SmartSSHClientMedium, 'a hostname', 'a port', 'a username',
279
'a password', 'base', vendor)
280
client_medium._accept_bytes('abc')
281
self.assertEqual('abc', output.getvalue())
282
self.assertEqual([('connect_ssh', 'a username', 'a password',
283
'a hostname', 'a port',
284
['fugly', 'serve', '--inet', '--directory=/', '--allow-writes'])],
287
def test_ssh_client_changes_command_when_bzr_remote_path_passed(self):
288
# The only thing that initiates a connection from the medium is giving
291
vendor = StringIOSSHVendor(StringIO(), output)
292
client_medium = medium.SmartSSHClientMedium('a hostname', 'a port',
293
'a username', 'a password', 'base', vendor, bzr_remote_path='fugly')
294
client_medium._accept_bytes('abc')
295
self.assertEqual('abc', output.getvalue())
296
self.assertEqual([('connect_ssh', 'a username', 'a password',
297
'a hostname', 'a port',
298
['fugly', 'serve', '--inet', '--directory=/', '--allow-writes'])],
301
def test_ssh_client_disconnect_does_so(self):
302
# calling disconnect should disconnect both the read_from and write_to
303
# file-like object it from the ssh connection.
306
vendor = StringIOSSHVendor(input, output)
307
client_medium = medium.SmartSSHClientMedium(
308
'a hostname', base='base', vendor=vendor, bzr_remote_path='bzr')
309
client_medium._accept_bytes('abc')
310
client_medium.disconnect()
311
self.assertTrue(input.closed)
312
self.assertTrue(output.closed)
314
('connect_ssh', None, None, 'a hostname', None,
315
['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
320
def test_ssh_client_disconnect_allows_reconnection(self):
321
# calling disconnect on the client terminates the connection, but should
322
# not prevent additional connections occuring.
323
# we test this by initiating a second connection after doing a
327
vendor = StringIOSSHVendor(input, output)
328
client_medium = medium.SmartSSHClientMedium(
329
'a hostname', base='base', vendor=vendor, bzr_remote_path='bzr')
330
client_medium._accept_bytes('abc')
331
client_medium.disconnect()
332
# the disconnect has closed output, so we need a new output for the
333
# new connection to write to.
336
vendor.read_from = input2
337
vendor.write_to = output2
338
client_medium._accept_bytes('abc')
339
client_medium.disconnect()
340
self.assertTrue(input.closed)
341
self.assertTrue(output.closed)
342
self.assertTrue(input2.closed)
343
self.assertTrue(output2.closed)
345
('connect_ssh', None, None, 'a hostname', None,
346
['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
348
('connect_ssh', None, None, 'a hostname', None,
349
['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
354
def test_ssh_client_ignores_disconnect_when_not_connected(self):
355
# Doing a disconnect on a new (and thus unconnected) SSH medium
356
# does not fail. It's ok to disconnect an unconnected medium.
357
client_medium = medium.SmartSSHClientMedium(
358
None, base='base', bzr_remote_path='bzr')
359
client_medium.disconnect()
361
def test_ssh_client_raises_on_read_when_not_connected(self):
362
# Doing a read on a new (and thus unconnected) SSH medium raises
363
# MediumNotConnected.
364
client_medium = medium.SmartSSHClientMedium(
365
None, base='base', bzr_remote_path='bzr')
366
self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes,
368
self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes,
371
def test_ssh_client_supports__flush(self):
372
# invoking _flush on a SSHClientMedium should flush the output
373
# pipe. We test this by creating an output pipe that records
374
# flush calls made to it.
375
from StringIO import StringIO # get regular StringIO
379
def logging_flush(): flush_calls.append('flush')
380
output.flush = logging_flush
381
vendor = StringIOSSHVendor(input, output)
382
client_medium = medium.SmartSSHClientMedium(
383
'a hostname', base='base', vendor=vendor, bzr_remote_path='bzr')
384
# this call is here to ensure we only flush once, not on every
385
# _accept_bytes call.
386
client_medium._accept_bytes('abc')
387
client_medium._flush()
388
client_medium.disconnect()
389
self.assertEqual(['flush'], flush_calls)
391
def test_construct_smart_tcp_client_medium(self):
392
# the TCP client medium takes a host and a port. Constructing it won't
393
# connect to anything.
394
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
395
sock.bind(('127.0.0.1', 0))
396
unopened_port = sock.getsockname()[1]
397
client_medium = medium.SmartTCPClientMedium(
398
'127.0.0.1', unopened_port, 'base')
401
def test_tcp_client_connects_on_first_use(self):
402
# The only thing that initiates a connection from the medium is giving
404
sock, medium = self.make_loopsocket_and_medium()
406
t = self.receive_bytes_on_server(sock, bytes)
407
medium.accept_bytes('abc')
410
self.assertEqual(['abc'], bytes)
412
def test_tcp_client_disconnect_does_so(self):
413
# calling disconnect on the client terminates the connection.
414
# we test this by forcing a short read during a socket.MSG_WAITALL
415
# call: write 2 bytes, try to read 3, and then the client disconnects.
416
sock, medium = self.make_loopsocket_and_medium()
418
t = self.receive_bytes_on_server(sock, bytes)
419
medium.accept_bytes('ab')
423
self.assertEqual(['ab'], bytes)
424
# now disconnect again: this should not do anything, if disconnection
425
# really did disconnect.
429
def test_tcp_client_ignores_disconnect_when_not_connected(self):
430
# Doing a disconnect on a new (and thus unconnected) TCP medium
431
# does not fail. It's ok to disconnect an unconnected medium.
432
client_medium = medium.SmartTCPClientMedium(None, None, None)
433
client_medium.disconnect()
435
def test_tcp_client_raises_on_read_when_not_connected(self):
436
# Doing a read on a new (and thus unconnected) TCP medium raises
437
# MediumNotConnected.
438
client_medium = medium.SmartTCPClientMedium(None, None, None)
439
self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes, 0)
440
self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes, 1)
442
def test_tcp_client_supports__flush(self):
443
# invoking _flush on a TCPClientMedium should do something useful.
444
# RBC 20060922 not sure how to test/tell in this case.
445
sock, medium = self.make_loopsocket_and_medium()
447
t = self.receive_bytes_on_server(sock, bytes)
448
# try with nothing buffered
450
medium._accept_bytes('ab')
451
# and with something sent.
456
self.assertEqual(['ab'], bytes)
457
# now disconnect again : this should not do anything, if disconnection
458
# really did disconnect.
461
def test_tcp_client_host_unknown_connection_error(self):
462
self.requireFeature(InvalidHostnameFeature)
463
client_medium = medium.SmartTCPClientMedium(
464
'non_existent.invalid', 4155, 'base')
466
errors.ConnectionError, client_medium._ensure_connection)
469
class TestSmartClientStreamMediumRequest(tests.TestCase):
470
"""Tests the for SmartClientStreamMediumRequest.
472
SmartClientStreamMediumRequest is a helper for the three stream based
473
mediums: TCP, SSH, SimplePipes, so we only test it once, and then test that
474
those three mediums implement the interface it expects.
477
def test_accept_bytes_after_finished_writing_errors(self):
478
# calling accept_bytes after calling finished_writing raises
479
# WritingCompleted to prevent bad assumptions on stream environments
480
# breaking the needs of message-based environments.
482
client_medium = medium.SmartSimplePipesClientMedium(
483
None, output, 'base')
484
request = medium.SmartClientStreamMediumRequest(client_medium)
485
request.finished_writing()
486
self.assertRaises(errors.WritingCompleted, request.accept_bytes, None)
488
def test_accept_bytes(self):
489
# accept bytes should invoke _accept_bytes on the stream medium.
490
# we test this by using the SimplePipes medium - the most trivial one
491
# and checking that the pipes get the data.
494
client_medium = medium.SmartSimplePipesClientMedium(
495
input, output, 'base')
496
request = medium.SmartClientStreamMediumRequest(client_medium)
497
request.accept_bytes('123')
498
request.finished_writing()
499
request.finished_reading()
500
self.assertEqual('', input.getvalue())
501
self.assertEqual('123', output.getvalue())
503
def test_construct_sets_stream_request(self):
504
# constructing a SmartClientStreamMediumRequest on a StreamMedium sets
505
# the current request to the new SmartClientStreamMediumRequest
507
client_medium = medium.SmartSimplePipesClientMedium(
508
None, output, 'base')
509
request = medium.SmartClientStreamMediumRequest(client_medium)
510
self.assertIs(client_medium._current_request, request)
512
def test_construct_while_another_request_active_throws(self):
513
# constructing a SmartClientStreamMediumRequest on a StreamMedium with
514
# a non-None _current_request raises TooManyConcurrentRequests.
516
client_medium = medium.SmartSimplePipesClientMedium(
517
None, output, 'base')
518
client_medium._current_request = "a"
519
self.assertRaises(errors.TooManyConcurrentRequests,
520
medium.SmartClientStreamMediumRequest, client_medium)
522
def test_finished_read_clears_current_request(self):
523
# calling finished_reading clears the current request from the requests
526
client_medium = medium.SmartSimplePipesClientMedium(
527
None, output, 'base')
528
request = medium.SmartClientStreamMediumRequest(client_medium)
529
request.finished_writing()
530
request.finished_reading()
531
self.assertEqual(None, client_medium._current_request)
533
def test_finished_read_before_finished_write_errors(self):
534
# calling finished_reading before calling finished_writing triggers a
535
# WritingNotComplete error.
536
client_medium = medium.SmartSimplePipesClientMedium(
538
request = medium.SmartClientStreamMediumRequest(client_medium)
539
self.assertRaises(errors.WritingNotComplete, request.finished_reading)
541
def test_read_bytes(self):
542
# read bytes should invoke _read_bytes on the stream medium.
543
# we test this by using the SimplePipes medium - the most trivial one
544
# and checking that the data is supplied. Its possible that a
545
# faulty implementation could poke at the pipe variables them selves,
546
# but we trust that this will be caught as it will break the integration
548
input = StringIO('321')
550
client_medium = medium.SmartSimplePipesClientMedium(
551
input, output, 'base')
552
request = medium.SmartClientStreamMediumRequest(client_medium)
553
request.finished_writing()
554
self.assertEqual('321', request.read_bytes(3))
555
request.finished_reading()
556
self.assertEqual('', input.read())
557
self.assertEqual('', output.getvalue())
559
def test_read_bytes_before_finished_write_errors(self):
560
# calling read_bytes before calling finished_writing triggers a
561
# WritingNotComplete error because the Smart protocol is designed to be
562
# compatible with strict message based protocols like HTTP where the
563
# request cannot be submitted until the writing has completed.
564
client_medium = medium.SmartSimplePipesClientMedium(None, None, 'base')
565
request = medium.SmartClientStreamMediumRequest(client_medium)
566
self.assertRaises(errors.WritingNotComplete, request.read_bytes, None)
568
def test_read_bytes_after_finished_reading_errors(self):
569
# calling read_bytes after calling finished_reading raises
570
# ReadingCompleted to prevent bad assumptions on stream environments
571
# breaking the needs of message-based environments.
573
client_medium = medium.SmartSimplePipesClientMedium(
574
None, output, 'base')
575
request = medium.SmartClientStreamMediumRequest(client_medium)
576
request.finished_writing()
577
request.finished_reading()
578
self.assertRaises(errors.ReadingCompleted, request.read_bytes, None)
581
class RemoteTransportTests(TestCaseWithSmartMedium):
583
def test_plausible_url(self):
584
self.assert_(self.get_url().startswith('bzr://'))
586
def test_probe_transport(self):
587
t = self.get_transport()
588
self.assertIsInstance(t, remote.RemoteTransport)
590
def test_get_medium_from_transport(self):
591
"""Remote transport has a medium always, which it can return."""
592
t = self.get_transport()
593
client_medium = t.get_smart_medium()
594
self.assertIsInstance(client_medium, medium.SmartClientMedium)
597
class ErrorRaisingProtocol(object):
599
def __init__(self, exception):
600
self.exception = exception
602
def next_read_size(self):
606
class SampleRequest(object):
608
def __init__(self, expected_bytes):
609
self.accepted_bytes = ''
610
self._finished_reading = False
611
self.expected_bytes = expected_bytes
612
self.unused_data = ''
614
def accept_bytes(self, bytes):
615
self.accepted_bytes += bytes
616
if self.accepted_bytes.startswith(self.expected_bytes):
617
self._finished_reading = True
618
self.unused_data = self.accepted_bytes[len(self.expected_bytes):]
620
def next_read_size(self):
621
if self._finished_reading:
627
class TestSmartServerStreamMedium(tests.TestCase):
630
super(TestSmartServerStreamMedium, self).setUp()
631
self._captureVar('BZR_NO_SMART_VFS', None)
633
def portable_socket_pair(self):
634
"""Return a pair of TCP sockets connected to each other.
636
Unlike socket.socketpair, this should work on Windows.
638
listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
639
listen_sock.bind(('127.0.0.1', 0))
640
listen_sock.listen(1)
641
client_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
642
client_sock.connect(listen_sock.getsockname())
643
server_sock, addr = listen_sock.accept()
645
return server_sock, client_sock
647
def test_smart_query_version(self):
648
"""Feed a canned query version to a server"""
649
# wire-to-wire, using the whole stack
650
to_server = StringIO('hello\n')
651
from_server = StringIO()
652
transport = local.LocalTransport(urlutils.local_path_to_url('/'))
653
server = medium.SmartServerPipeStreamMedium(
654
to_server, from_server, transport)
655
smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
657
server._serve_one_request(smart_protocol)
658
self.assertEqual('ok\0012\n',
659
from_server.getvalue())
661
def test_response_to_canned_get(self):
662
transport = memory.MemoryTransport('memory:///')
663
transport.put_bytes('testfile', 'contents\nof\nfile\n')
664
to_server = StringIO('get\001./testfile\n')
665
from_server = StringIO()
666
server = medium.SmartServerPipeStreamMedium(
667
to_server, from_server, transport)
668
smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
670
server._serve_one_request(smart_protocol)
671
self.assertEqual('ok\n'
673
'contents\nof\nfile\n'
675
from_server.getvalue())
677
def test_response_to_canned_get_of_utf8(self):
678
# wire-to-wire, using the whole stack, with a UTF-8 filename.
679
transport = memory.MemoryTransport('memory:///')
680
utf8_filename = u'testfile\N{INTERROBANG}'.encode('utf-8')
681
transport.put_bytes(utf8_filename, 'contents\nof\nfile\n')
682
to_server = StringIO('get\001' + utf8_filename + '\n')
683
from_server = StringIO()
684
server = medium.SmartServerPipeStreamMedium(
685
to_server, from_server, transport)
686
smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
688
server._serve_one_request(smart_protocol)
689
self.assertEqual('ok\n'
691
'contents\nof\nfile\n'
693
from_server.getvalue())
695
def test_pipe_like_stream_with_bulk_data(self):
696
sample_request_bytes = 'command\n9\nbulk datadone\n'
697
to_server = StringIO(sample_request_bytes)
698
from_server = StringIO()
699
server = medium.SmartServerPipeStreamMedium(
700
to_server, from_server, None)
701
sample_protocol = SampleRequest(expected_bytes=sample_request_bytes)
702
server._serve_one_request(sample_protocol)
703
self.assertEqual('', from_server.getvalue())
704
self.assertEqual(sample_request_bytes, sample_protocol.accepted_bytes)
705
self.assertFalse(server.finished)
707
def test_socket_stream_with_bulk_data(self):
708
sample_request_bytes = 'command\n9\nbulk datadone\n'
709
server_sock, client_sock = self.portable_socket_pair()
710
server = medium.SmartServerSocketStreamMedium(
712
sample_protocol = SampleRequest(expected_bytes=sample_request_bytes)
713
client_sock.sendall(sample_request_bytes)
714
server._serve_one_request(sample_protocol)
716
self.assertEqual('', client_sock.recv(1))
717
self.assertEqual(sample_request_bytes, sample_protocol.accepted_bytes)
718
self.assertFalse(server.finished)
720
def test_pipe_like_stream_shutdown_detection(self):
721
to_server = StringIO('')
722
from_server = StringIO()
723
server = medium.SmartServerPipeStreamMedium(to_server, from_server, None)
724
server._serve_one_request(SampleRequest('x'))
725
self.assertTrue(server.finished)
727
def test_socket_stream_shutdown_detection(self):
728
server_sock, client_sock = self.portable_socket_pair()
730
server = medium.SmartServerSocketStreamMedium(
732
server._serve_one_request(SampleRequest('x'))
733
self.assertTrue(server.finished)
735
def test_socket_stream_incomplete_request(self):
736
"""The medium should still construct the right protocol version even if
737
the initial read only reads part of the request.
739
Specifically, it should correctly read the protocol version line even
740
if the partial read doesn't end in a newline. An older, naive
741
implementation of _get_line in the server used to have a bug in that
744
incomplete_request_bytes = protocol.REQUEST_VERSION_TWO + 'hel'
745
rest_of_request_bytes = 'lo\n'
746
expected_response = (
747
protocol.RESPONSE_VERSION_TWO + 'success\nok\x012\n')
748
server_sock, client_sock = self.portable_socket_pair()
749
server = medium.SmartServerSocketStreamMedium(
751
client_sock.sendall(incomplete_request_bytes)
752
server_protocol = server._build_protocol()
753
client_sock.sendall(rest_of_request_bytes)
754
server._serve_one_request(server_protocol)
756
self.assertEqual(expected_response, client_sock.recv(50),
757
"Not a version 2 response to 'hello' request.")
758
self.assertEqual('', client_sock.recv(1))
760
def test_pipe_stream_incomplete_request(self):
761
"""The medium should still construct the right protocol version even if
762
the initial read only reads part of the request.
764
Specifically, it should correctly read the protocol version line even
765
if the partial read doesn't end in a newline. An older, naive
766
implementation of _get_line in the server used to have a bug in that
769
incomplete_request_bytes = protocol.REQUEST_VERSION_TWO + 'hel'
770
rest_of_request_bytes = 'lo\n'
771
expected_response = (
772
protocol.RESPONSE_VERSION_TWO + 'success\nok\x012\n')
773
# Make a pair of pipes, to and from the server
774
to_server, to_server_w = os.pipe()
775
from_server_r, from_server = os.pipe()
776
to_server = os.fdopen(to_server, 'r', 0)
777
to_server_w = os.fdopen(to_server_w, 'w', 0)
778
from_server_r = os.fdopen(from_server_r, 'r', 0)
779
from_server = os.fdopen(from_server, 'w', 0)
780
server = medium.SmartServerPipeStreamMedium(
781
to_server, from_server, None)
782
# Like test_socket_stream_incomplete_request, write an incomplete
783
# request (that does not end in '\n') and build a protocol from it.
784
to_server_w.write(incomplete_request_bytes)
785
server_protocol = server._build_protocol()
786
# Send the rest of the request, and finish serving it.
787
to_server_w.write(rest_of_request_bytes)
788
server._serve_one_request(server_protocol)
791
self.assertEqual(expected_response, from_server_r.read(),
792
"Not a version 2 response to 'hello' request.")
793
self.assertEqual('', from_server_r.read(1))
794
from_server_r.close()
797
def test_pipe_like_stream_with_two_requests(self):
798
# If two requests are read in one go, then two calls to
799
# _serve_one_request should still process both of them as if they had
800
# been received seperately.
801
sample_request_bytes = 'command\n'
802
to_server = StringIO(sample_request_bytes * 2)
803
from_server = StringIO()
804
server = medium.SmartServerPipeStreamMedium(
805
to_server, from_server, None)
806
first_protocol = SampleRequest(expected_bytes=sample_request_bytes)
807
server._serve_one_request(first_protocol)
808
self.assertEqual(0, first_protocol.next_read_size())
809
self.assertEqual('', from_server.getvalue())
810
self.assertFalse(server.finished)
811
# Make a new protocol, call _serve_one_request with it to collect the
813
second_protocol = SampleRequest(expected_bytes=sample_request_bytes)
814
server._serve_one_request(second_protocol)
815
self.assertEqual('', from_server.getvalue())
816
self.assertEqual(sample_request_bytes, second_protocol.accepted_bytes)
817
self.assertFalse(server.finished)
819
def test_socket_stream_with_two_requests(self):
820
# If two requests are read in one go, then two calls to
821
# _serve_one_request should still process both of them as if they had
822
# been received seperately.
823
sample_request_bytes = 'command\n'
824
server_sock, client_sock = self.portable_socket_pair()
825
server = medium.SmartServerSocketStreamMedium(
827
first_protocol = SampleRequest(expected_bytes=sample_request_bytes)
828
# Put two whole requests on the wire.
829
client_sock.sendall(sample_request_bytes * 2)
830
server._serve_one_request(first_protocol)
831
self.assertEqual(0, first_protocol.next_read_size())
832
self.assertFalse(server.finished)
833
# Make a new protocol, call _serve_one_request with it to collect the
835
second_protocol = SampleRequest(expected_bytes=sample_request_bytes)
836
stream_still_open = server._serve_one_request(second_protocol)
837
self.assertEqual(sample_request_bytes, second_protocol.accepted_bytes)
838
self.assertFalse(server.finished)
840
self.assertEqual('', client_sock.recv(1))
842
def test_pipe_like_stream_error_handling(self):
843
# Use plain python StringIO so we can monkey-patch the close method to
844
# not discard the contents.
845
from StringIO import StringIO
846
to_server = StringIO('')
847
from_server = StringIO()
851
from_server.close = close
852
server = medium.SmartServerPipeStreamMedium(
853
to_server, from_server, None)
854
fake_protocol = ErrorRaisingProtocol(Exception('boom'))
855
server._serve_one_request(fake_protocol)
856
self.assertEqual('', from_server.getvalue())
857
self.assertTrue(self.closed)
858
self.assertTrue(server.finished)
860
def test_socket_stream_error_handling(self):
861
server_sock, client_sock = self.portable_socket_pair()
862
server = medium.SmartServerSocketStreamMedium(
864
fake_protocol = ErrorRaisingProtocol(Exception('boom'))
865
server._serve_one_request(fake_protocol)
866
# recv should not block, because the other end of the socket has been
868
self.assertEqual('', client_sock.recv(1))
869
self.assertTrue(server.finished)
871
def test_pipe_like_stream_keyboard_interrupt_handling(self):
872
to_server = StringIO('')
873
from_server = StringIO()
874
server = medium.SmartServerPipeStreamMedium(
875
to_server, from_server, None)
876
fake_protocol = ErrorRaisingProtocol(KeyboardInterrupt('boom'))
878
KeyboardInterrupt, server._serve_one_request, fake_protocol)
879
self.assertEqual('', from_server.getvalue())
881
def test_socket_stream_keyboard_interrupt_handling(self):
882
server_sock, client_sock = self.portable_socket_pair()
883
server = medium.SmartServerSocketStreamMedium(
885
fake_protocol = ErrorRaisingProtocol(KeyboardInterrupt('boom'))
887
KeyboardInterrupt, server._serve_one_request, fake_protocol)
889
self.assertEqual('', client_sock.recv(1))
891
def build_protocol_pipe_like(self, bytes):
    """Build the server protocol chosen for *bytes* on a pipe-like medium.

    *bytes* becomes the server's input stream; the server's output stream is
    a fresh StringIO that is not inspected.
    """
    server_medium = medium.SmartServerPipeStreamMedium(
        StringIO(bytes), StringIO(), None)
    return server_medium._build_protocol()
898
def build_protocol_socket(self, bytes):
899
server_sock, client_sock = self.portable_socket_pair()
900
server = medium.SmartServerSocketStreamMedium(
902
client_sock.sendall(bytes)
904
return server._build_protocol()
906
def assertProtocolOne(self, server_protocol):
907
# Use assertIs because assertIsInstance will wrongly pass
908
# SmartServerRequestProtocolTwo (because it subclasses
909
# SmartServerRequestProtocolOne).
911
type(server_protocol), protocol.SmartServerRequestProtocolOne)
913
def assertProtocolTwo(self, server_protocol):
    """Assert that *server_protocol* is a version-two server protocol."""
    expected_class = protocol.SmartServerRequestProtocolTwo
    self.assertIsInstance(server_protocol, expected_class)
917
def test_pipe_like_build_protocol_empty_bytes(self):
    # An empty request (no bytes at all) on a pipe is treated as protocol
    # version one.
    self.assertProtocolOne(self.build_protocol_pipe_like(''))
922
def test_socket_like_build_protocol_empty_bytes(self):
    # Receiving zero bytes over a socket also selects protocol version one.
    built = self.build_protocol_socket('')
    self.assertProtocolOne(built)
927
def test_pipe_like_build_protocol_non_two(self):
    # Input lacking the "bzr request 2\n" marker falls back to version one.
    built = self.build_protocol_pipe_like('abc\n')
    self.assertProtocolOne(built)
932
def test_socket_build_protocol_non_two(self):
    # Same as the pipe case: no "bzr request 2\n" marker means version one.
    self.assertProtocolOne(self.build_protocol_socket('abc\n'))
937
def test_pipe_like_build_protocol_two(self):
    # A leading "bzr request 2\n" marker selects protocol version two.
    built = self.build_protocol_pipe_like('bzr request 2\n')
    self.assertProtocolTwo(built)
942
def test_socket_build_protocol_two(self):
    # The version-two marker is recognised on a socket as well.
    self.assertProtocolTwo(self.build_protocol_socket('bzr request 2\n'))
948
class TestGetProtocolFactoryForBytes(tests.TestCase):
949
"""_get_protocol_factory_for_bytes identifies the protocol factory a server
950
should use to decode a given request. Any bytes not part of the version
951
marker string (and thus part of the actual request) are returned alongside
952
the protocol factory.
955
def test_version_three(self):
956
result = medium._get_protocol_factory_for_bytes(
957
'bzr message 3 (bzr 1.6)\nextra bytes')
958
protocol_factory, remainder = result
960
protocol.build_server_protocol_three, protocol_factory)
961
self.assertEqual('extra bytes', remainder)
963
def test_version_two(self):
964
result = medium._get_protocol_factory_for_bytes(
965
'bzr request 2\nextra bytes')
966
protocol_factory, remainder = result
968
protocol.SmartServerRequestProtocolTwo, protocol_factory)
969
self.assertEqual('extra bytes', remainder)
971
def test_version_one(self):
972
"""Version one requests have no version markers."""
973
result = medium._get_protocol_factory_for_bytes('anything\n')
974
protocol_factory, remainder = result
976
protocol.SmartServerRequestProtocolOne, protocol_factory)
977
self.assertEqual('anything\n', remainder)
980
class TestSmartTCPServer(tests.TestCase):
982
def test_get_error_unexpected(self):
983
"""Error reported by server with no specific representation"""
984
self._captureVar('BZR_NO_SMART_VFS', None)
985
class FlakyTransport(object):
987
def external_url(self):
989
def get_bytes(self, path):
990
raise Exception("some random exception from inside server")
991
smart_server = server.SmartTCPServer(backing_transport=FlakyTransport())
992
smart_server.start_background_thread('-' + self.id())
994
transport = remote.RemoteTCPTransport(smart_server.get_url())
996
transport.get('something')
997
except errors.TransportError, e:
998
self.assertContainsRe(str(e), 'some random exception')
1000
self.fail("get did not raise expected error")
1001
transport.disconnect()
1003
smart_server.stop_background_thread()
1006
class SmartTCPTests(tests.TestCase):
1007
"""Tests for connection/end to end behaviour using the TCP server.
1009
All of these tests are run with a server running on another thread serving
1010
a MemoryTransport, and a connection to it already open.
1012
the server is obtained by calling self.setUpServer(readonly=False).
1015
def setUpServer(self, readonly=False, backing_transport=None):
1016
"""Setup the server.
1018
:param readonly: Create a readonly server.
1020
if not backing_transport:
1021
self.backing_transport = memory.MemoryTransport()
1023
self.backing_transport = backing_transport
1025
self.real_backing_transport = self.backing_transport
1026
self.backing_transport = get_transport("readonly+" + self.backing_transport.abspath('.'))
1027
self.server = server.SmartTCPServer(self.backing_transport)
1028
self.server.start_background_thread('-' + self.id())
1029
self.transport = remote.RemoteTCPTransport(self.server.get_url())
1030
self.addCleanup(self.tearDownServer)
1032
def tearDownServer(self):
1033
if getattr(self, 'transport', None):
1034
self.transport.disconnect()
1036
if getattr(self, 'server', None):
1037
self.server.stop_background_thread()
1041
class TestServerSocketUsage(SmartTCPTests):
1043
def test_server_setup_teardown(self):
1044
"""It should be safe to teardown the server with no requests."""
1046
server = self.server
1047
transport = remote.RemoteTCPTransport(self.server.get_url())
1048
self.tearDownServer()
1049
self.assertRaises(errors.ConnectionError, transport.has, '.')
1051
def test_server_closes_listening_sock_on_shutdown_after_request(self):
1052
"""The server should close its listening socket when it's stopped."""
1054
server = self.server
1055
self.transport.has('.')
1056
self.tearDownServer()
1057
# if the listening socket has closed, we should get a BADFD error
1058
# when connecting, rather than a hang.
1059
transport = remote.RemoteTCPTransport(server.get_url())
1060
self.assertRaises(errors.ConnectionError, transport.has, '.')
1063
class WritableEndToEndTests(SmartTCPTests):
1064
"""Client to server tests that require a writable transport."""
1067
super(WritableEndToEndTests, self).setUp()
1070
def test_start_tcp_server(self):
    # The running server advertises a bzr:// URL on 127.0.0.1 whose port has
    # at least two digits.
    self.assertContainsRe(
        self.server.get_url(), r'^bzr://127\.0\.0\.1:[0-9]{2,}/')
1074
def test_smart_transport_has(self):
    """has() over the smart protocol reports which files exist."""
    # Clear BZR_NO_SMART_VFS so VFS requests such as 'has' are not disabled.
    self._captureVar('BZR_NO_SMART_VFS', None)
    self.backing_transport.put_bytes("foo", "contents of foo\n")
    smart_transport = self.transport
    self.assertTrue(smart_transport.has("foo"))
    self.assertFalse(smart_transport.has("non-foo"))
1081
def test_smart_transport_get(self):
    """A file stored on the backing transport reads back intact over smart."""
    # Clear BZR_NO_SMART_VFS so the VFS 'get' request is allowed.
    self._captureVar('BZR_NO_SMART_VFS', None)
    expected = "contents\nof\nfoo\n"
    self.backing_transport.put_bytes("foo", expected)
    self.assertEqual(expected, self.transport.get("foo").read())
1088
def test_get_error_enoent(self):
1089
"""Error reported from server getting nonexistent file."""
1090
# The path in a raised NoSuchFile exception should be the precise path
1091
# asked for by the client. This gives meaningful and unsurprising errors
1093
self._captureVar('BZR_NO_SMART_VFS', None)
1095
self.transport.get('not%20a%20file')
1096
except errors.NoSuchFile, e:
1097
self.assertEqual('not%20a%20file', e.path)
1099
self.fail("get did not raise expected error")
1101
def test_simple_clone_conn(self):
    """Test that cloning reuses the same connection."""
    # The clone talks to the same server over the same medium as the
    # original, so both must report the very same smart medium object.
    cloned = self.transport.clone('.')
    original_medium = self.transport.get_smart_medium()
    self.assertIs(original_medium, cloned.get_smart_medium())
1109
def test__remote_path(self):
    # A relative path is mapped to the absolute path sent to the server.
    remote_path = self.transport._remote_path('foo/bar')
    self.assertEquals('/foo/bar', remote_path)
1113
def test_clone_changes_base(self):
1114
"""Cloning transport produces one with a new base location"""
1115
conn2 = self.transport.clone('subdir')
1116
self.assertEquals(self.transport.base + 'subdir/',
1119
def test_open_dir(self):
    """Test changing directory"""
    self._captureVar('BZR_NO_SMART_VFS', None)
    root = self.transport
    for directory in ('toffee', 'toffee/apple'):
        self.backing_transport.mkdir(directory)
    self.assertEquals('/toffee', root._remote_path('toffee'))
    toffee = root.clone('toffee')
    # Each transport must see only the entries of its own directory. If
    # clone() shared state between the two objects, cloning could change
    # what the original transport sees.
    self.assertTrue(root.has('toffee'))
    self.assertFalse(toffee.has('toffee'))
    self.assertFalse(root.has('apple'))
    self.assertTrue(toffee.has('apple'))
1136
def test_open_bzrdir(self):
    """Open an existing bzrdir over smart transport"""
    # Create the bzrdir directly on the backing transport, then check that it
    # can be found through the smart transport without raising.
    default_format = bzrdir.BzrDirFormat.get_default_format()
    default_format.initialize_on_transport(self.backing_transport)
    bzrdir.BzrDir.open_containing_from_transport(self.transport)
1144
class ReadOnlyEndToEndTests(SmartTCPTests):
1145
"""Tests from the client to the server using a readonly backing transport."""
1147
def test_mkdir_error_readonly(self):
1148
"""TransportNotPossible should be preserved from the backing transport."""
1149
self._captureVar('BZR_NO_SMART_VFS', None)
1150
self.setUpServer(readonly=True)
1151
self.assertRaises(errors.TransportNotPossible, self.transport.mkdir,
1155
class TestServerHooks(SmartTCPTests):
1157
def capture_server_call(self, backing_urls, public_url):
    """Hook callback recording one server_started/server_stopped firing.

    Each call appends the tuple ``(backing_urls, public_url)`` to
    ``self.hook_calls``.
    """
    record = (backing_urls, public_url)
    self.hook_calls.append(record)
1161
def test_server_started_hook_memory(self):
1162
"""The server_started hook fires when the server is started."""
1163
self.hook_calls = []
1164
server.SmartTCPServer.hooks.install_named_hook('server_started',
1165
self.capture_server_call, None)
1167
# at this point, the server will be starting a thread up.
1168
# there is no indicator at the moment, so bodge it by doing a request.
1169
self.transport.has('.')
1170
# The default test server uses MemoryTransport and that has no external
1172
self.assertEqual([([self.backing_transport.base], self.transport.base)],
1175
def test_server_started_hook_file(self):
1176
"""The server_started hook fires when the server is started."""
1177
self.hook_calls = []
1178
server.SmartTCPServer.hooks.install_named_hook('server_started',
1179
self.capture_server_call, None)
1180
self.setUpServer(backing_transport=get_transport("."))
1181
# at this point, the server will be starting a thread up.
1182
# there is no indicator at the moment, so bodge it by doing a request.
1183
self.transport.has('.')
1184
# The default test server uses MemoryTransport and that has no external
1186
self.assertEqual([([
1187
self.backing_transport.base, self.backing_transport.external_url()],
1188
self.transport.base)],
1191
def test_server_stopped_hook_simple_memory(self):
1192
"""The server_stopped hook fires when the server is stopped."""
1193
self.hook_calls = []
1194
server.SmartTCPServer.hooks.install_named_hook('server_stopped',
1195
self.capture_server_call, None)
1197
result = [([self.backing_transport.base], self.transport.base)]
1198
# check the stopping message isn't emitted up front.
1199
self.assertEqual([], self.hook_calls)
1200
# nor after a single message
1201
self.transport.has('.')
1202
self.assertEqual([], self.hook_calls)
1203
# clean up the server
1204
self.tearDownServer()
1205
# now it should have fired.
1206
self.assertEqual(result, self.hook_calls)
1208
def test_server_stopped_hook_simple_file(self):
1209
"""The server_stopped hook fires when the server is stopped."""
1210
self.hook_calls = []
1211
server.SmartTCPServer.hooks.install_named_hook('server_stopped',
1212
self.capture_server_call, None)
1213
self.setUpServer(backing_transport=get_transport("."))
1215
[self.backing_transport.base, self.backing_transport.external_url()]
1216
, self.transport.base)]
1217
# check the stopping message isn't emitted up front.
1218
self.assertEqual([], self.hook_calls)
1219
# nor after a single message
1220
self.transport.has('.')
1221
self.assertEqual([], self.hook_calls)
1222
# clean up the server
1223
self.tearDownServer()
1224
# now it should have fired.
1225
self.assertEqual(result, self.hook_calls)
1227
# TODO: test that when the server suffers an exception that it calls the
1228
# server-stopped hook.
1231
class SmartServerCommandTests(tests.TestCaseWithTransport):
1232
"""Tests that call directly into the command objects, bypassing the network
1233
and the request dispatching.
1235
Note: these tests are rudimentary versions of the command object tests in
1239
def test_hello(self):
    # Running HelloRequest directly yields ('ok', '2') and no body.
    hello = _mod_request.HelloRequest(None, '/')
    result = hello.execute()
    self.assertEqual(('ok', '2'), result.args)
    self.assertEqual(None, result.body)
1245
def test_get_bundle(self):
1246
from bzrlib.bundle import serializer
1247
wt = self.make_branch_and_tree('.')
1248
self.build_tree_contents([('hello', 'hello world')])
1250
rev_id = wt.commit('add hello')
1252
cmd = _mod_request.GetBundleRequest(self.get_transport(), '/')
1253
response = cmd.execute('.', rev_id)
1254
bundle = serializer.read_bundle(StringIO(response.body))
1255
self.assertEqual((), response.args)
1258
class SmartServerRequestHandlerTests(tests.TestCaseWithTransport):
1259
"""Test that call directly into the handler logic, bypassing the network."""
1262
super(SmartServerRequestHandlerTests, self).setUp()
1263
self._captureVar('BZR_NO_SMART_VFS', None)
1265
def build_handler(self, transport):
    """Returns a handler for the commands in protocol version one."""
    # Every handler built here uses the standard request registry and '/' as
    # its root client path.
    handlers = _mod_request.request_handlers
    return _mod_request.SmartServerRequestHandler(transport, handlers, '/')
1270
def test_construct_request_handler(self):
    """Constructing a request handler should be easy and set defaults."""
    request_handler = _mod_request.SmartServerRequestHandler(
        None, commands=None, root_client_path='/')
    # A brand-new handler has not finished reading anything yet.
    self.assertFalse(request_handler.finished_reading)
1276
def test_hello(self):
    # Dispatching 'hello' answers ('ok', '2') without a response body.
    hello_handler = self.build_handler(None)
    hello_handler.dispatch_command('hello', ())
    self.assertEqual(('ok', '2'), hello_handler.response.args)
    self.assertEqual(None, hello_handler.response.body)
1282
def test_disable_vfs_handler_classes_via_environment(self):
1283
# VFS handler classes will raise an error from "execute" if
1284
# BZR_NO_SMART_VFS is set.
1285
handler = vfs.HasRequest(None, '/')
1286
# set environment variable after construction to make sure it's
1288
# Note that we can safely clobber BZR_NO_SMART_VFS here, because setUp
1289
# has called _captureVar, so it will be restored to the right state
1291
os.environ['BZR_NO_SMART_VFS'] = ''
1292
self.assertRaises(errors.DisabledMethod, handler.execute)
1294
def test_readonly_exception_becomes_transport_not_possible(self):
    """The response for a read-only error is ('ReadOnlyError')."""
    readonly = self.get_readonly_transport()
    handler = self.build_handler(readonly)
    # mkdir of 'foo' with an empty mode argument cannot succeed on a
    # read-only transport...
    handler.dispatch_command('mkdir', ('foo', ''))
    # ...and the failure must be reported explicitly as ReadOnlyError.
    self.assertEqual(("ReadOnlyError", ), handler.response.args)
    # XXX: TODO: test that other TransportNotPossible errors are
    # presented as TransportNotPossible - not possible to do that
    # until I figure out how to trigger that relatively cleanly via
    # the api. RBC 20060918
1306
def test_hello_has_finished_body_on_dispatch(self):
    """The 'hello' command should set finished_reading."""
    handler = self.build_handler(None)
    handler.dispatch_command('hello', ())
    # 'hello' takes no body, so dispatching alone completes the request and
    # leaves a response behind.
    self.assertTrue(handler.finished_reading)
    self.assertNotEqual(None, handler.response)
1313
def test_put_bytes_non_atomic(self):
    """'put_...' should set finished_reading after reading the bytes."""
    handler = self.build_handler(self.get_transport())
    handler.dispatch_command('put_non_atomic', ('a-file', '', 'F', ''))
    # The body arrives in two chunks; reading must stay unfinished until
    # end_of_body() is signalled.
    for chunk in ('1234', '5678'):
        self.assertFalse(handler.finished_reading)
        handler.accept_body(chunk)
    handler.end_of_body()
    self.assertTrue(handler.finished_reading)
    self.assertEqual(('ok', ), handler.response.args)
    self.assertEqual(None, handler.response.body)
1326
def test_readv_accept_body(self):
    """'readv' should set finished_reading after reading offsets."""
    self.build_tree(['a-file'])
    handler = self.build_handler(self.get_readonly_transport())
    handler.dispatch_command('readv', ('a-file', ))
    # The offset list "2,3" is delivered in two pieces; reading is only
    # finished once end_of_body() is called.
    for piece in ('2,', '3'):
        self.assertFalse(handler.finished_reading)
        handler.accept_body(piece)
    handler.end_of_body()
    self.assertTrue(handler.finished_reading)
    self.assertEqual(('readv', ), handler.response.args)
    # Three bytes starting at offset 2 of the file contents: co - nte - nt.
    self.assertEqual('nte', handler.response.body)
1341
def test_readv_short_read_response_contents(self):
    """'readv' when a short read occurs sets the response appropriately."""
    self.build_tree(['a-file'])
    handler = self.build_handler(self.get_readonly_transport())
    handler.dispatch_command('readv', ('a-file', ))
    # Request one byte at offset 100, beyond the end of the file.
    handler.accept_body('100,1')
    handler.end_of_body()
    self.assertTrue(handler.finished_reading)
    expected_args = ('ShortReadvError', './a-file', '100', '1', '0')
    self.assertEqual(expected_args, handler.response.args)
    self.assertEqual(None, handler.response.body)
1355
class RemoteTransportRegistration(tests.TestCase):
1357
def test_registration(self):
    # bzr+ssh:// URLs get a RemoteSSHTransport, which keeps the URL's host.
    ssh_transport = get_transport('bzr+ssh://example.com/path')
    self.assertIsInstance(ssh_transport, remote.RemoteSSHTransport)
    self.assertEqual('example.com', ssh_transport._host)
1362
def test_bzr_https(self):
1363
# https://bugs.launchpad.net/bzr/+bug/128456
1364
t = get_transport('bzr+https://example.com/path')
1365
self.assertIsInstance(t, remote.RemoteHTTPTransport)
1366
self.assertStartsWith(
1367
t._http_transport.base,
1371
class TestRemoteTransport(tests.TestCase):
1373
def test_use_connection_factory(self):
1374
# We want to be able to pass a client as a parameter to RemoteTransport.
1375
input = StringIO('ok\n3\nbardone\n')
1377
client_medium = medium.SmartSimplePipesClientMedium(
1378
input, output, 'base')
1379
transport = remote.RemoteTransport(
1380
'bzr://localhost/', medium=client_medium)
1381
# Disable version detection.
1382
client_medium._protocol_version = 1
1384
# We want to make sure the client is used when the first remote
1385
# method is called. No data should have been sent, or read.
1386
self.assertEqual(0, input.tell())
1387
self.assertEqual('', output.getvalue())
1389
# Now call a method that should result in one request: as the
1390
# transport makes its own protocol instances, we check on the wire.
1391
# XXX: TODO: give the transport a protocol factory, which can make
1392
# an instrumented protocol for us.
1393
self.assertEqual('bar', transport.get_bytes('foo'))
1394
# only the needed data should have been sent/received.
1395
self.assertEqual(13, input.tell())
1396
self.assertEqual('get\x01/foo\n', output.getvalue())
1398
def test__translate_error_readonly(self):
    """Sending a ReadOnlyError to _translate_error raises TransportNotPossible."""
    pipes_medium = medium.SmartSimplePipesClientMedium(None, None, 'base')
    transport = remote.RemoteTransport(
        'bzr://localhost/', medium=pipes_medium)
    error_tuple = ("ReadOnlyError", )
    self.assertRaises(
        errors.TransportNotPossible, transport._translate_error, error_tuple)
1407
class TestSmartProtocol(tests.TestCase):
1408
"""Base class for smart protocol tests.
1410
Each test case gets a smart_server and smart_client created during setUp().
1412
It is planned that the client can be called with self.call_client() giving
1413
it an expected server response, which will be fed into it when it tries to
1414
read. Likewise, self.call_server will call a servers method with a canned
1415
serialised client request. Output done by the client or server for these
1416
calls will be captured to self.to_server and self.to_client. Each element
1417
in the list is a write call from the client or server respectively.
1419
Subclasses can override client_protocol_class and server_protocol_class.
1422
request_encoder = None
1423
response_decoder = None
1424
server_protocol_class = None
1425
client_protocol_class = None
1427
def make_client_protocol_and_output(self, input_bytes=None):
1431
# This is very similar to
1432
# bzrlib.smart.client._SmartClient._build_client_protocol
1433
# XXX: make this use _SmartClient!
1434
if input_bytes is None:
1437
input = StringIO(input_bytes)
1439
client_medium = medium.SmartSimplePipesClientMedium(
1440
input, output, 'base')
1441
request = client_medium.get_request()
1442
if self.client_protocol_class is not None:
1443
client_protocol = self.client_protocol_class(request)
1444
return client_protocol, client_protocol, output
1446
self.assertNotEqual(None, self.request_encoder)
1447
self.assertNotEqual(None, self.response_decoder)
1448
requester = self.request_encoder(request)
1449
response_handler = message.ConventionalResponseHandler()
1450
response_protocol = self.response_decoder(
1451
response_handler, expect_version_marker=True)
1452
response_handler.setProtoAndMediumRequest(
1453
response_protocol, request)
1454
return requester, response_handler, output
1456
def make_client_protocol(self, input_bytes=None):
    """Return ``(requester, response_handler)`` for canned *input_bytes*.

    This is make_client_protocol_and_output with the output stream dropped.
    """
    requester, response_handler, _ = self.make_client_protocol_and_output(
        input_bytes=input_bytes)
    return requester, response_handler
1461
def make_server_protocol(self):
    """Return a new server protocol plus the stream that captures its output.

    The protocol is built from ``self.server_protocol_class`` with ``None``
    as its first argument and the StringIO's ``write`` as the second.
    """
    captured = StringIO()
    server_protocol = self.server_protocol_class(None, captured.write)
    return server_protocol, captured
1467
super(TestSmartProtocol, self).setUp()
1468
self.response_marker = getattr(
1469
self.client_protocol_class, 'response_marker', None)
1470
self.request_marker = getattr(
1471
self.client_protocol_class, 'request_marker', None)
1473
def assertOffsetSerialisation(self, expected_offsets, expected_serialised,
1475
"""Check that smart (de)serialises offsets as expected.
1477
We check both serialisation and deserialisation at the same time
1478
to ensure that the round tripping cannot skew: both directions should
1481
:param expected_offsets: a readv offset list.
1482
:param expected_seralised: an expected serial form of the offsets.
1484
# XXX: '_deserialise_offsets' should be a method of the
1485
# SmartServerRequestProtocol in future.
1486
readv_cmd = vfs.ReadvRequest(None, '/')
1487
offsets = readv_cmd._deserialise_offsets(expected_serialised)
1488
self.assertEqual(expected_offsets, offsets)
1489
serialised = requester._serialise_offsets(offsets)
1490
self.assertEqual(expected_serialised, serialised)
1492
def build_protocol_waiting_for_body(self):
    """Return a server protocol that has dispatched and now expects a body.

    The fake command installed on the request checks that the body is
    exactly 'abcdefg' and sets ``self.end_received`` when it is delivered.
    """
    server_protocol, _ = self.make_server_protocol()
    server_protocol._has_dispatched = True
    server_protocol.request = _mod_request.SmartServerRequestHandler(
        None, _mod_request.request_handlers, '/')
    test_case = self

    class FakeCommand(object):
        def do_body(cmd, body_bytes):
            test_case.end_received = True
            test_case.assertEqual('abcdefg', body_bytes)
            return _mod_request.SuccessfulSmartServerResponse(('ok', ))

    server_protocol.request._command = FakeCommand()
    # Feeding an empty string makes sure internal state such as
    # _body_decoder is initialised. This test should probably be given a
    # clearer interface to work with that will not cause this inconsistency.
    # -- Andrew Bennetts, 2006-09-28
    server_protocol.accept_bytes('')
    return server_protocol
1510
def assertServerToClientEncoding(self, expected_bytes, expected_tuple,
1512
"""Assert that each input_tuple serialises as expected_bytes, and the
1513
bytes deserialise as expected_tuple.
1515
# check the encoding of the server for all input_tuples matches
1517
for input_tuple in input_tuples:
1518
server_protocol, server_output = self.make_server_protocol()
1519
server_protocol._send_response(
1520
_mod_request.SuccessfulSmartServerResponse(input_tuple))
1521
self.assertEqual(expected_bytes, server_output.getvalue())
1522
# check the decoding of the client smart_protocol from expected_bytes:
1523
requester, response_handler = self.make_client_protocol(expected_bytes)
1524
requester.call('foo')
1525
self.assertEqual(expected_tuple, response_handler.read_response_tuple())
1528
class CommonSmartProtocolTestMixin(object):
1530
def test_connection_closed_reporting(self):
    # With no canned response bytes supplied, reading the response reports
    # a closed connection with a diagnostic hint.
    requester, response_handler = self.make_client_protocol()
    requester.call('hello')
    exc = self.assertRaises(
        errors.ConnectionReset, response_handler.read_response_tuple)
    expected_message = ("Connection closed: "
        "please check connectivity and permissions "
        "(and try -Dhpss if further diagnosis is required)")
    self.assertEqual(expected_message, str(exc))
1539
def test_server_offset_serialisation(self):
1540
"""The Smart protocol serialises offsets as a comma and \n string.
1542
We check a number of boundary cases are as expected: empty, one offset,
1543
one with the order of reads not increasing (an out of order read), and
1544
one that should coalesce.
1546
requester, response_handler = self.make_client_protocol()
1547
self.assertOffsetSerialisation([], '', requester)
1548
self.assertOffsetSerialisation([(1,2)], '1,2', requester)
1549
self.assertOffsetSerialisation([(10,40), (0,5)], '10,40\n0,5',
1551
self.assertOffsetSerialisation([(1,2), (3,4), (100, 200)],
1552
'1,2\n3,4\n100,200', requester)
1555
class TestVersionOneFeaturesInProtocolOne(
1556
TestSmartProtocol, CommonSmartProtocolTestMixin):
1557
"""Tests for version one smart protocol features as implemeted by version
1560
client_protocol_class = protocol.SmartClientRequestProtocolOne
1561
server_protocol_class = protocol.SmartServerRequestProtocolOne
1563
def test_construct_version_one_server_protocol(self):
    # A new version-one server protocol starts with empty buffers, has not
    # dispatched a request, and wants to read one byte next.
    fresh = protocol.SmartServerRequestProtocolOne(None, None)
    self.assertEqual('', fresh.unused_data)
    self.assertEqual('', fresh.in_buffer)
    self.assertFalse(fresh._has_dispatched)
    self.assertEqual(1, fresh.next_read_size())
1570
def test_construct_version_one_client_protocol(self):
1571
# we can construct a client protocol from a client medium request
1573
client_medium = medium.SmartSimplePipesClientMedium(
1574
None, output, 'base')
1575
request = client_medium.get_request()
1576
client_protocol = protocol.SmartClientRequestProtocolOne(request)
1578
def test_accept_bytes_of_bad_request_to_protocol(self):
1579
out_stream = StringIO()
1580
smart_protocol = protocol.SmartServerRequestProtocolOne(
1581
None, out_stream.write)
1582
smart_protocol.accept_bytes('abc')
1583
self.assertEqual('abc', smart_protocol.in_buffer)
1584
smart_protocol.accept_bytes('\n')
1586
"error\x01Generic bzr smart protocol error: bad request 'abc'\n",
1587
out_stream.getvalue())
1588
self.assertTrue(smart_protocol._has_dispatched)
1589
self.assertEqual(0, smart_protocol.next_read_size())
1591
def test_accept_body_bytes_to_protocol(self):
    # Feed the length-prefixed body '7\nabcdefg' plus the 'done\n' trailer
    # in pieces, checking the requested read size along the way; it drops
    # to 0 once the whole body has been consumed.
    waiting = self.build_protocol_waiting_for_body()
    self.assertEqual(6, waiting.next_read_size())
    waiting.accept_bytes('7\nabc')
    self.assertEqual(9, waiting.next_read_size())
    waiting.accept_bytes('defgd')
    waiting.accept_bytes('one\n')
    self.assertEqual(0, waiting.next_read_size())
    self.assertTrue(self.end_received)
1601
def test_accept_request_and_body_all_at_once(self):
1602
self._captureVar('BZR_NO_SMART_VFS', None)
1603
mem_transport = memory.MemoryTransport()
1604
mem_transport.put_bytes('foo', 'abcdefghij')
1605
out_stream = StringIO()
1606
smart_protocol = protocol.SmartServerRequestProtocolOne(mem_transport,
1608
smart_protocol.accept_bytes('readv\x01foo\n3\n3,3done\n')
1609
self.assertEqual(0, smart_protocol.next_read_size())
1610
self.assertEqual('readv\n3\ndefdone\n', out_stream.getvalue())
1611
self.assertEqual('', smart_protocol.unused_data)
1612
self.assertEqual('', smart_protocol.in_buffer)
1614
def test_accept_excess_bytes_are_preserved(self):
    # Two requests arriving in one chunk: only the first is answered, and
    # the second is kept untouched in unused_data.
    output = StringIO()
    server_protocol = protocol.SmartServerRequestProtocolOne(
        None, output.write)
    server_protocol.accept_bytes('hello\nhello\n')
    self.assertEqual("ok\x012\n", output.getvalue())
    self.assertEqual("hello\n", server_protocol.unused_data)
    self.assertEqual("", server_protocol.in_buffer)
1623
def test_accept_excess_bytes_after_body(self):
    # Bytes after the body terminator accumulate in unused_data, whether
    # they arrive with the body or in a later call.
    body_protocol = self.build_protocol_waiting_for_body()
    body_protocol.accept_bytes('7\nabcdefgdone\nX')
    self.assertTrue(self.end_received)
    self.assertEqual("X", body_protocol.unused_data)
    self.assertEqual("", body_protocol.in_buffer)
    body_protocol.accept_bytes('Y')
    self.assertEqual("XY", body_protocol.unused_data)
    self.assertEqual("", body_protocol.in_buffer)
1633
def test_accept_excess_bytes_after_dispatch(self):
    # Once a request has been answered, later bytes are appended to
    # unused_data chunk by chunk instead of being parsed.
    output = StringIO()
    server_protocol = protocol.SmartServerRequestProtocolOne(
        None, output.write)
    server_protocol.accept_bytes('hello\n')
    self.assertEqual("ok\x012\n", output.getvalue())
    for chunk, expected_unused in (('hel', "hel"), ('lo\n', "hello\n")):
        server_protocol.accept_bytes(chunk)
        self.assertEqual(expected_unused, server_protocol.unused_data)
    self.assertEqual("", server_protocol.in_buffer)
1645
def test__send_response_sets_finished_reading(self):
    # Sending a response ends reading: next_read_size goes from 1 to 0.
    server_protocol = protocol.SmartServerRequestProtocolOne(
        None, lambda written: None)
    self.assertEqual(1, server_protocol.next_read_size())
    success = _mod_request.SuccessfulSmartServerResponse(('x',))
    server_protocol._send_response(success)
    self.assertEqual(0, server_protocol.next_read_size())
1653
def test__send_response_errors_with_base_response(self):
    """Ensure that only the Successful/Failed subclasses are used."""
    server_protocol = protocol.SmartServerRequestProtocolOne(
        None, lambda written: None)
    # Passing the plain base response class is rejected with AttributeError.
    base_response = _mod_request.SmartServerResponse(('x',))
    self.assertRaises(
        AttributeError, server_protocol._send_response, base_response)
1660
def test_query_version(self):
1661
"""query_version on a SmartClientProtocolOne should return a number.
1663
The protocol provides the query_version because the domain level clients
1664
may all need to be able to probe for capabilities.
1666
# What we really want to test here is that SmartClientProtocolOne calls
1667
# accept_bytes(tuple_based_encoding_of_hello) and reads and parses the
1668
# response of tuple-encoded (ok, 1). Also, seperately we should test
1669
# the error if the response is a non-understood version.
1670
input = StringIO('ok\x012\n')
1672
client_medium = medium.SmartSimplePipesClientMedium(
1673
input, output, 'base')
1674
request = client_medium.get_request()
1675
smart_protocol = protocol.SmartClientRequestProtocolOne(request)
1676
self.assertEqual(2, smart_protocol.query_version())
1678
def test_client_call_empty_response(self):
    # The server encodes both () and ('', ) as a bare newline, and the client
    # parses that empty line back into a one-element tuple holding ''.
    self.assertServerToClientEncoding('\n', ('', ), [(), ('', )])
1684
def test_client_call_three_element_response(self):
1685
# protocol.call() can get back tuples of other lengths. A three element
1686
# tuple should be unpacked as three strings.
1687
self.assertServerToClientEncoding('a\x01b\x0134\n', ('a', 'b', '34'),
1690
def test_client_call_with_body_bytes_uploads(self):
1691
# protocol.call_with_body_bytes should length-prefix the bytes onto the
1693
expected_bytes = "foo\n7\nabcdefgdone\n"
1694
input = StringIO("\n")
1696
client_medium = medium.SmartSimplePipesClientMedium(
1697
input, output, 'base')
1698
request = client_medium.get_request()
1699
smart_protocol = protocol.SmartClientRequestProtocolOne(request)
1700
smart_protocol.call_with_body_bytes(('foo', ), "abcdefg")
1701
self.assertEqual(expected_bytes, output.getvalue())
1703
def test_client_call_with_body_readv_array(self):
1704
# protocol.call_with_upload should encode the readv array and then
1705
# length-prefix the bytes onto the wire.
1706
expected_bytes = "foo\n7\n1,2\n5,6done\n"
1707
input = StringIO("\n")
1709
client_medium = medium.SmartSimplePipesClientMedium(
1710
input, output, 'base')
1711
request = client_medium.get_request()
1712
smart_protocol = protocol.SmartClientRequestProtocolOne(request)
1713
smart_protocol.call_with_body_readv_array(('foo', ), [(1,2),(5,6)])
1714
self.assertEqual(expected_bytes, output.getvalue())
1716
def _test_client_read_response_tuple_raises_UnknownSmartMethod(self,
1718
input = StringIO(server_bytes)
1720
client_medium = medium.SmartSimplePipesClientMedium(
1721
input, output, 'base')
1722
request = client_medium.get_request()
1723
smart_protocol = protocol.SmartClientRequestProtocolOne(request)
1724
smart_protocol.call('foo')
1726
errors.UnknownSmartMethod, smart_protocol.read_response_tuple)
1727
# The request has been finished. There is no body to read, and
1728
# attempts to read one will fail.
1730
errors.ReadingCompleted, smart_protocol.read_body_bytes)
1732
def test_client_read_response_tuple_raises_UnknownSmartMethod(self):
1733
"""read_response_tuple raises UnknownSmartMethod if the response says
1734
the server did not recognise the request.
1737
"error\x01Generic bzr smart protocol error: bad request 'foo'\n")
1738
self._test_client_read_response_tuple_raises_UnknownSmartMethod(
1741
def test_client_read_response_tuple_raises_UnknownSmartMethod_0_11(self):
1742
"""read_response_tuple also raises UnknownSmartMethod if the response
1743
from a bzr 0.11 says the server did not recognise the request.
1745
(bzr 0.11 sends a slightly different error message to later versions.)
1748
"error\x01Generic bzr smart protocol error: bad request u'foo'\n")
1749
self._test_client_read_response_tuple_raises_UnknownSmartMethod(
1752
def test_client_read_body_bytes_all(self):
1753
# read_body_bytes should decode the body bytes from the wire into
1755
expected_bytes = "1234567"
1756
server_bytes = "ok\n7\n1234567done\n"
1757
input = StringIO(server_bytes)
1759
client_medium = medium.SmartSimplePipesClientMedium(
1760
input, output, 'base')
1761
request = client_medium.get_request()
1762
smart_protocol = protocol.SmartClientRequestProtocolOne(request)
1763
smart_protocol.call('foo')
1764
smart_protocol.read_response_tuple(True)
1765
self.assertEqual(expected_bytes, smart_protocol.read_body_bytes())
1767
def test_client_read_body_bytes_incremental(self):
1768
# test reading a few bytes at a time from the body
1769
# XXX: possibly we should test dribbling the bytes into the stringio
1770
# to make the state machine work harder: however, as we use the
1771
# LengthPrefixedBodyDecoder that is already well tested - we can skip
1773
expected_bytes = "1234567"
1774
server_bytes = "ok\n7\n1234567done\n"
1775
input = StringIO(server_bytes)
1777
client_medium = medium.SmartSimplePipesClientMedium(
1778
input, output, 'base')
1779
request = client_medium.get_request()
1780
smart_protocol = protocol.SmartClientRequestProtocolOne(request)
1781
smart_protocol.call('foo')
1782
smart_protocol.read_response_tuple(True)
1783
self.assertEqual(expected_bytes[0:2], smart_protocol.read_body_bytes(2))
1784
self.assertEqual(expected_bytes[2:4], smart_protocol.read_body_bytes(2))
1785
self.assertEqual(expected_bytes[4:6], smart_protocol.read_body_bytes(2))
1786
self.assertEqual(expected_bytes[6], smart_protocol.read_body_bytes())
1788
def test_client_cancel_read_body_does_not_eat_body_bytes(self):
1789
# cancelling the expected body needs to finish the request, but not
1790
# read any more bytes.
1791
expected_bytes = "1234567"
1792
server_bytes = "ok\n7\n1234567done\n"
1793
input = StringIO(server_bytes)
1795
client_medium = medium.SmartSimplePipesClientMedium(
1796
input, output, 'base')
1797
request = client_medium.get_request()
1798
smart_protocol = protocol.SmartClientRequestProtocolOne(request)
1799
smart_protocol.call('foo')
1800
smart_protocol.read_response_tuple(True)
1801
smart_protocol.cancel_read_body()
1802
self.assertEqual(3, input.tell())
1804
errors.ReadingCompleted, smart_protocol.read_body_bytes)
1806
def test_client_read_body_bytes_interrupted_connection(self):
1807
server_bytes = "ok\n999\nincomplete body"
1808
input = StringIO(server_bytes)
1810
client_medium = medium.SmartSimplePipesClientMedium(
1811
input, output, 'base')
1812
request = client_medium.get_request()
1813
smart_protocol = self.client_protocol_class(request)
1814
smart_protocol.call('foo')
1815
smart_protocol.read_response_tuple(True)
1817
errors.ConnectionReset, smart_protocol.read_body_bytes)
1820
class TestVersionOneFeaturesInProtocolTwo(
1821
TestSmartProtocol, CommonSmartProtocolTestMixin):
1822
"""Tests for version one smart protocol features as implemeted by version
1826
client_protocol_class = protocol.SmartClientRequestProtocolTwo
1827
server_protocol_class = protocol.SmartServerRequestProtocolTwo
1829
def test_construct_version_two_server_protocol(self):
1830
smart_protocol = protocol.SmartServerRequestProtocolTwo(None, None)
1831
self.assertEqual('', smart_protocol.unused_data)
1832
self.assertEqual('', smart_protocol.in_buffer)
1833
self.assertFalse(smart_protocol._has_dispatched)
1834
self.assertEqual(1, smart_protocol.next_read_size())
1836
def test_construct_version_two_client_protocol(self):
1837
# we can construct a client protocol from a client medium request
1839
client_medium = medium.SmartSimplePipesClientMedium(
1840
None, output, 'base')
1841
request = client_medium.get_request()
1842
client_protocol = protocol.SmartClientRequestProtocolTwo(request)
1844
def test_accept_bytes_of_bad_request_to_protocol(self):
1845
out_stream = StringIO()
1846
smart_protocol = self.server_protocol_class(None, out_stream.write)
1847
smart_protocol.accept_bytes('abc')
1848
self.assertEqual('abc', smart_protocol.in_buffer)
1849
smart_protocol.accept_bytes('\n')
1851
self.response_marker +
1852
"failed\nerror\x01Generic bzr smart protocol error: bad request 'abc'\n",
1853
out_stream.getvalue())
1854
self.assertTrue(smart_protocol._has_dispatched)
1855
self.assertEqual(0, smart_protocol.next_read_size())
1857
def test_accept_body_bytes_to_protocol(self):
1858
protocol = self.build_protocol_waiting_for_body()
1859
self.assertEqual(6, protocol.next_read_size())
1860
protocol.accept_bytes('7\nabc')
1861
self.assertEqual(9, protocol.next_read_size())
1862
protocol.accept_bytes('defgd')
1863
protocol.accept_bytes('one\n')
1864
self.assertEqual(0, protocol.next_read_size())
1865
self.assertTrue(self.end_received)
1867
def test_accept_request_and_body_all_at_once(self):
1868
self._captureVar('BZR_NO_SMART_VFS', None)
1869
mem_transport = memory.MemoryTransport()
1870
mem_transport.put_bytes('foo', 'abcdefghij')
1871
out_stream = StringIO()
1872
smart_protocol = self.server_protocol_class(
1873
mem_transport, out_stream.write)
1874
smart_protocol.accept_bytes('readv\x01foo\n3\n3,3done\n')
1875
self.assertEqual(0, smart_protocol.next_read_size())
1876
self.assertEqual(self.response_marker +
1877
'success\nreadv\n3\ndefdone\n',
1878
out_stream.getvalue())
1879
self.assertEqual('', smart_protocol.unused_data)
1880
self.assertEqual('', smart_protocol.in_buffer)
1882
def test_accept_excess_bytes_are_preserved(self):
1883
out_stream = StringIO()
1884
smart_protocol = self.server_protocol_class(None, out_stream.write)
1885
smart_protocol.accept_bytes('hello\nhello\n')
1886
self.assertEqual(self.response_marker + "success\nok\x012\n",
1887
out_stream.getvalue())
1888
self.assertEqual("hello\n", smart_protocol.unused_data)
1889
self.assertEqual("", smart_protocol.in_buffer)
1891
def test_accept_excess_bytes_after_body(self):
1892
# The excess bytes look like the start of another request.
1893
server_protocol = self.build_protocol_waiting_for_body()
1894
server_protocol.accept_bytes('7\nabcdefgdone\n' + self.response_marker)
1895
self.assertTrue(self.end_received)
1896
self.assertEqual(self.response_marker,
1897
server_protocol.unused_data)
1898
self.assertEqual("", server_protocol.in_buffer)
1899
server_protocol.accept_bytes('Y')
1900
self.assertEqual(self.response_marker + "Y",
1901
server_protocol.unused_data)
1902
self.assertEqual("", server_protocol.in_buffer)
1904
def test_accept_excess_bytes_after_dispatch(self):
1905
out_stream = StringIO()
1906
smart_protocol = self.server_protocol_class(None, out_stream.write)
1907
smart_protocol.accept_bytes('hello\n')
1908
self.assertEqual(self.response_marker + "success\nok\x012\n",
1909
out_stream.getvalue())
1910
smart_protocol.accept_bytes(self.request_marker + 'hel')
1911
self.assertEqual(self.request_marker + "hel",
1912
smart_protocol.unused_data)
1913
smart_protocol.accept_bytes('lo\n')
1914
self.assertEqual(self.request_marker + "hello\n",
1915
smart_protocol.unused_data)
1916
self.assertEqual("", smart_protocol.in_buffer)
1918
def test__send_response_sets_finished_reading(self):
1919
smart_protocol = self.server_protocol_class(None, lambda x: None)
1920
self.assertEqual(1, smart_protocol.next_read_size())
1921
smart_protocol._send_response(
1922
_mod_request.SuccessfulSmartServerResponse(('x',)))
1923
self.assertEqual(0, smart_protocol.next_read_size())
1925
def test__send_response_errors_with_base_response(self):
1926
"""Ensure that only the Successful/Failed subclasses are used."""
1927
smart_protocol = self.server_protocol_class(None, lambda x: None)
1928
self.assertRaises(AttributeError, smart_protocol._send_response,
1929
_mod_request.SmartServerResponse(('x',)))
1931
def test_query_version(self):
1932
"""query_version on a SmartClientProtocolTwo should return a number.
1934
The protocol provides the query_version because the domain level clients
1935
may all need to be able to probe for capabilities.
1937
# What we really want to test here is that SmartClientProtocolTwo calls
1938
# accept_bytes(tuple_based_encoding_of_hello) and reads and parses the
1939
# response of tuple-encoded (ok, 1). Also, seperately we should test
1940
# the error if the response is a non-understood version.
1941
input = StringIO(self.response_marker + 'success\nok\x012\n')
1943
client_medium = medium.SmartSimplePipesClientMedium(
1944
input, output, 'base')
1945
request = client_medium.get_request()
1946
smart_protocol = self.client_protocol_class(request)
1947
self.assertEqual(2, smart_protocol.query_version())
1949
def test_client_call_empty_response(self):
1950
# protocol.call() can get back an empty tuple as a response. This occurs
1951
# when the parsed line is an empty line, and results in a tuple with
1952
# one element - an empty string.
1953
self.assertServerToClientEncoding(
1954
self.response_marker + 'success\n\n', ('', ), [(), ('', )])
1956
def test_client_call_three_element_response(self):
1957
# protocol.call() can get back tuples of other lengths. A three element
1958
# tuple should be unpacked as three strings.
1959
self.assertServerToClientEncoding(
1960
self.response_marker + 'success\na\x01b\x0134\n',
1964
def test_client_call_with_body_bytes_uploads(self):
1965
# protocol.call_with_body_bytes should length-prefix the bytes onto the
1967
expected_bytes = self.request_marker + "foo\n7\nabcdefgdone\n"
1968
input = StringIO("\n")
1970
client_medium = medium.SmartSimplePipesClientMedium(
1971
input, output, 'base')
1972
request = client_medium.get_request()
1973
smart_protocol = self.client_protocol_class(request)
1974
smart_protocol.call_with_body_bytes(('foo', ), "abcdefg")
1975
self.assertEqual(expected_bytes, output.getvalue())
1977
def test_client_call_with_body_readv_array(self):
1978
# protocol.call_with_upload should encode the readv array and then
1979
# length-prefix the bytes onto the wire.
1980
expected_bytes = self.request_marker + "foo\n7\n1,2\n5,6done\n"
1981
input = StringIO("\n")
1983
client_medium = medium.SmartSimplePipesClientMedium(
1984
input, output, 'base')
1985
request = client_medium.get_request()
1986
smart_protocol = self.client_protocol_class(request)
1987
smart_protocol.call_with_body_readv_array(('foo', ), [(1,2),(5,6)])
1988
self.assertEqual(expected_bytes, output.getvalue())
1990
def test_client_read_body_bytes_all(self):
1991
# read_body_bytes should decode the body bytes from the wire into
1993
expected_bytes = "1234567"
1994
server_bytes = (self.response_marker +
1995
"success\nok\n7\n1234567done\n")
1996
input = StringIO(server_bytes)
1998
client_medium = medium.SmartSimplePipesClientMedium(
1999
input, output, 'base')
2000
request = client_medium.get_request()
2001
smart_protocol = self.client_protocol_class(request)
2002
smart_protocol.call('foo')
2003
smart_protocol.read_response_tuple(True)
2004
self.assertEqual(expected_bytes, smart_protocol.read_body_bytes())
2006
def test_client_read_body_bytes_incremental(self):
2007
# test reading a few bytes at a time from the body
2008
# XXX: possibly we should test dribbling the bytes into the stringio
2009
# to make the state machine work harder: however, as we use the
2010
# LengthPrefixedBodyDecoder that is already well tested - we can skip
2012
expected_bytes = "1234567"
2013
server_bytes = self.response_marker + "success\nok\n7\n1234567done\n"
2014
input = StringIO(server_bytes)
2016
client_medium = medium.SmartSimplePipesClientMedium(
2017
input, output, 'base')
2018
request = client_medium.get_request()
2019
smart_protocol = self.client_protocol_class(request)
2020
smart_protocol.call('foo')
2021
smart_protocol.read_response_tuple(True)
2022
self.assertEqual(expected_bytes[0:2], smart_protocol.read_body_bytes(2))
2023
self.assertEqual(expected_bytes[2:4], smart_protocol.read_body_bytes(2))
2024
self.assertEqual(expected_bytes[4:6], smart_protocol.read_body_bytes(2))
2025
self.assertEqual(expected_bytes[6], smart_protocol.read_body_bytes())
2027
def test_client_cancel_read_body_does_not_eat_body_bytes(self):
2028
# cancelling the expected body needs to finish the request, but not
2029
# read any more bytes.
2030
server_bytes = self.response_marker + "success\nok\n7\n1234567done\n"
2031
input = StringIO(server_bytes)
2033
client_medium = medium.SmartSimplePipesClientMedium(
2034
input, output, 'base')
2035
request = client_medium.get_request()
2036
smart_protocol = self.client_protocol_class(request)
2037
smart_protocol.call('foo')
2038
smart_protocol.read_response_tuple(True)
2039
smart_protocol.cancel_read_body()
2040
self.assertEqual(len(self.response_marker + 'success\nok\n'),
2043
errors.ReadingCompleted, smart_protocol.read_body_bytes)
2045
def test_client_read_body_bytes_interrupted_connection(self):
2046
server_bytes = (self.response_marker +
2047
"success\nok\n999\nincomplete body")
2048
input = StringIO(server_bytes)
2050
client_medium = medium.SmartSimplePipesClientMedium(
2051
input, output, 'base')
2052
request = client_medium.get_request()
2053
smart_protocol = self.client_protocol_class(request)
2054
smart_protocol.call('foo')
2055
smart_protocol.read_response_tuple(True)
2057
errors.ConnectionReset, smart_protocol.read_body_bytes)
2060
class TestSmartProtocolTwoSpecificsMixin(object):
2062
def assertBodyStreamSerialisation(self, expected_serialisation,
2064
"""Assert that body_stream is serialised as expected_serialisation."""
2065
out_stream = StringIO()
2066
protocol._send_stream(body_stream, out_stream.write)
2067
self.assertEqual(expected_serialisation, out_stream.getvalue())
2069
def assertBodyStreamRoundTrips(self, body_stream):
2070
"""Assert that body_stream is the same after being serialised and
2073
out_stream = StringIO()
2074
protocol._send_stream(body_stream, out_stream.write)
2075
decoder = protocol.ChunkedBodyDecoder()
2076
decoder.accept_bytes(out_stream.getvalue())
2077
decoded_stream = list(iter(decoder.read_next_chunk, None))
2078
self.assertEqual(body_stream, decoded_stream)
2080
def test_body_stream_serialisation_empty(self):
2081
"""A body_stream with no bytes can be serialised."""
2082
self.assertBodyStreamSerialisation('chunked\nEND\n', [])
2083
self.assertBodyStreamRoundTrips([])
2085
def test_body_stream_serialisation(self):
2086
stream = ['chunk one', 'chunk two', 'chunk three']
2087
self.assertBodyStreamSerialisation(
2088
'chunked\n' + '9\nchunk one' + '9\nchunk two' + 'b\nchunk three' +
2091
self.assertBodyStreamRoundTrips(stream)
2093
def test_body_stream_with_empty_element_serialisation(self):
2094
"""A body stream can include ''.
2096
The empty string can be transmitted like any other string.
2098
stream = ['', 'chunk']
2099
self.assertBodyStreamSerialisation(
2100
'chunked\n' + '0\n' + '5\nchunk' + 'END\n', stream)
2101
self.assertBodyStreamRoundTrips(stream)
2103
def test_body_stream_error_serialistion(self):
2104
stream = ['first chunk',
2105
_mod_request.FailedSmartServerResponse(
2106
('FailureName', 'failure arg'))]
2108
'chunked\n' + 'b\nfirst chunk' +
2109
'ERR\n' + 'b\nFailureName' + 'b\nfailure arg' +
2111
self.assertBodyStreamSerialisation(expected_bytes, stream)
2112
self.assertBodyStreamRoundTrips(stream)
2114
def test__send_response_includes_failure_marker(self):
2115
"""FailedSmartServerResponse have 'failed\n' after the version."""
2116
out_stream = StringIO()
2117
smart_protocol = protocol.SmartServerRequestProtocolTwo(
2118
None, out_stream.write)
2119
smart_protocol._send_response(
2120
_mod_request.FailedSmartServerResponse(('x',)))
2121
self.assertEqual(protocol.RESPONSE_VERSION_TWO + 'failed\nx\n',
2122
out_stream.getvalue())
2124
def test__send_response_includes_success_marker(self):
2125
"""SuccessfulSmartServerResponse have 'success\n' after the version."""
2126
out_stream = StringIO()
2127
smart_protocol = protocol.SmartServerRequestProtocolTwo(
2128
None, out_stream.write)
2129
smart_protocol._send_response(
2130
_mod_request.SuccessfulSmartServerResponse(('x',)))
2131
self.assertEqual(protocol.RESPONSE_VERSION_TWO + 'success\nx\n',
2132
out_stream.getvalue())
2134
def test__send_response_with_body_stream_sets_finished_reading(self):
2135
smart_protocol = protocol.SmartServerRequestProtocolTwo(
2136
None, lambda x: None)
2137
self.assertEqual(1, smart_protocol.next_read_size())
2138
smart_protocol._send_response(
2139
_mod_request.SuccessfulSmartServerResponse(('x',), body_stream=[]))
2140
self.assertEqual(0, smart_protocol.next_read_size())
2142
def test_streamed_body_bytes(self):
2143
body_header = 'chunked\n'
2144
two_body_chunks = "4\n1234" + "3\n567"
2145
body_terminator = "END\n"
2146
server_bytes = (protocol.RESPONSE_VERSION_TWO +
2147
"success\nok\n" + body_header + two_body_chunks +
2149
input = StringIO(server_bytes)
2151
client_medium = medium.SmartSimplePipesClientMedium(
2152
input, output, 'base')
2153
request = client_medium.get_request()
2154
smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
2155
smart_protocol.call('foo')
2156
smart_protocol.read_response_tuple(True)
2157
stream = smart_protocol.read_streamed_body()
2158
self.assertEqual(['1234', '567'], list(stream))
2160
def test_read_streamed_body_error(self):
2161
"""When a stream is interrupted by an error..."""
2162
body_header = 'chunked\n'
2163
a_body_chunk = '4\naaaa'
2164
err_signal = 'ERR\n'
2165
err_chunks = 'a\nerror arg1' + '4\narg2'
2167
body = body_header + a_body_chunk + err_signal + err_chunks + finish
2168
server_bytes = (protocol.RESPONSE_VERSION_TWO +
2169
"success\nok\n" + body)
2170
input = StringIO(server_bytes)
2172
client_medium = medium.SmartSimplePipesClientMedium(
2173
input, output, 'base')
2174
smart_request = client_medium.get_request()
2175
smart_protocol = protocol.SmartClientRequestProtocolTwo(smart_request)
2176
smart_protocol.call('foo')
2177
smart_protocol.read_response_tuple(True)
2180
_mod_request.FailedSmartServerResponse(('error arg1', 'arg2'))]
2181
stream = smart_protocol.read_streamed_body()
2182
self.assertEqual(expected_chunks, list(stream))
2184
def test_streamed_body_bytes_interrupted_connection(self):
2185
body_header = 'chunked\n'
2186
incomplete_body_chunk = "9999\nincomplete chunk"
2187
server_bytes = (protocol.RESPONSE_VERSION_TWO +
2188
"success\nok\n" + body_header + incomplete_body_chunk)
2189
input = StringIO(server_bytes)
2191
client_medium = medium.SmartSimplePipesClientMedium(
2192
input, output, 'base')
2193
request = client_medium.get_request()
2194
smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
2195
smart_protocol.call('foo')
2196
smart_protocol.read_response_tuple(True)
2197
stream = smart_protocol.read_streamed_body()
2198
self.assertRaises(errors.ConnectionReset, stream.next)
2200
def test_client_read_response_tuple_sets_response_status(self):
2201
server_bytes = protocol.RESPONSE_VERSION_TWO + "success\nok\n"
2202
input = StringIO(server_bytes)
2204
client_medium = medium.SmartSimplePipesClientMedium(
2205
input, output, 'base')
2206
request = client_medium.get_request()
2207
smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
2208
smart_protocol.call('foo')
2209
smart_protocol.read_response_tuple(False)
2210
self.assertEqual(True, smart_protocol.response_status)
2212
def test_client_read_response_tuple_raises_UnknownSmartMethod(self):
2213
"""read_response_tuple raises UnknownSmartMethod if the response says
2214
the server did not recognise the request.
2217
protocol.RESPONSE_VERSION_TWO +
2219
"error\x01Generic bzr smart protocol error: bad request 'foo'\n")
2220
input = StringIO(server_bytes)
2222
client_medium = medium.SmartSimplePipesClientMedium(
2223
input, output, 'base')
2224
request = client_medium.get_request()
2225
smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
2226
smart_protocol.call('foo')
2228
errors.UnknownSmartMethod, smart_protocol.read_response_tuple)
2229
# The request has been finished. There is no body to read, and
2230
# attempts to read one will fail.
2232
errors.ReadingCompleted, smart_protocol.read_body_bytes)
2235
class TestSmartProtocolTwoSpecifics(
2236
TestSmartProtocol, TestSmartProtocolTwoSpecificsMixin):
2237
"""Tests for aspects of smart protocol version two that are unique to
2240
Thus tests involving body streams and success/failure markers belong here.
2243
client_protocol_class = protocol.SmartClientRequestProtocolTwo
2244
server_protocol_class = protocol.SmartServerRequestProtocolTwo
2247
class TestVersionOneFeaturesInProtocolThree(
2248
TestSmartProtocol, CommonSmartProtocolTestMixin):
2249
"""Tests for version one smart protocol features as implemented by version
2253
request_encoder = protocol.ProtocolThreeRequester
2254
response_decoder = protocol.ProtocolThreeDecoder
2255
# build_server_protocol_three is a function, so we can't set it as a class
2256
# attribute directly, because then Python will assume it is actually a
2257
# method. So we make server_protocol_class be a static method, rather than
2259
# "server_protocol_class = protocol.build_server_protocol_three".
2260
server_protocol_class = staticmethod(protocol.build_server_protocol_three)
2263
super(TestVersionOneFeaturesInProtocolThree, self).setUp()
2264
self.response_marker = protocol.MESSAGE_VERSION_THREE
2265
self.request_marker = protocol.MESSAGE_VERSION_THREE
2267
def test_construct_version_three_server_protocol(self):
2268
smart_protocol = protocol.ProtocolThreeDecoder(None)
2269
self.assertEqual('', smart_protocol.unused_data)
2270
self.assertEqual([], smart_protocol._in_buffer_list)
2271
self.assertEqual(0, smart_protocol._in_buffer_len)
2272
self.assertFalse(smart_protocol._has_dispatched)
2273
# The protocol starts by expecting four bytes, a length prefix for the
2275
self.assertEqual(4, smart_protocol.next_read_size())
2278
class NoOpRequest(_mod_request.SmartServerRequest):
2281
return _mod_request.SuccessfulSmartServerResponse(())
2283
dummy_registry = {'ARG': NoOpRequest}
2286
class LoggingMessageHandler(object):
2291
def _log(self, *args):
2292
self.event_log.append(args)
2294
def headers_received(self, headers):
2295
self._log('headers', headers)
2297
def protocol_error(self, exception):
2298
self._log('protocol_error', exception)
2300
def byte_part_received(self, byte):
2301
self._log('byte', byte)
2303
def bytes_part_received(self, bytes):
2304
self._log('bytes', bytes)
2306
def structure_part_received(self, structure):
2307
self._log('structure', structure)
2309
def end_received(self):
2313
class TestProtocolThree(TestSmartProtocol):
2314
"""Tests for v3 of the server-side protocol."""
2316
request_encoder = protocol.ProtocolThreeRequester
2317
response_decoder = protocol.ProtocolThreeDecoder
2318
server_protocol_class = protocol.ProtocolThreeDecoder
2320
def test_trivial_request(self):
2321
"""Smoke test for the simplest possible v3 request: empty headers, no
2325
headers = '\0\0\0\x02de' # length-prefixed, bencoded empty dict
2327
request_bytes = headers + end
2328
smart_protocol = self.server_protocol_class(LoggingMessageHandler())
2329
smart_protocol.accept_bytes(request_bytes)
2330
self.assertEqual(0, smart_protocol.next_read_size())
2331
self.assertEqual('', smart_protocol.unused_data)
2333
def make_protocol_expecting_message_part(self):
2334
headers = '\0\0\0\x02de' # length-prefixed, bencoded empty dict
2335
message_handler = LoggingMessageHandler()
2336
smart_protocol = self.server_protocol_class(message_handler)
2337
smart_protocol.accept_bytes(headers)
2338
# Clear the event log
2339
del message_handler.event_log[:]
2340
return smart_protocol, message_handler.event_log
2342
def test_decode_one_byte(self):
2343
"""The protocol can decode a 'one byte' message part."""
2344
smart_protocol, event_log = self.make_protocol_expecting_message_part()
2345
smart_protocol.accept_bytes('ox')
2346
self.assertEqual([('byte', 'x')], event_log)
2348
def test_decode_bytes(self):
2349
"""The protocol can decode a 'bytes' message part."""
2350
smart_protocol, event_log = self.make_protocol_expecting_message_part()
2351
smart_protocol.accept_bytes(
2352
'b' # message part kind
2353
'\0\0\0\x07' # length prefix
2356
self.assertEqual([('bytes', 'payload')], event_log)
2358
def test_decode_structure(self):
2359
"""The protocol can decode a 'structure' message part."""
2360
smart_protocol, event_log = self.make_protocol_expecting_message_part()
2361
smart_protocol.accept_bytes(
2362
's' # message part kind
2363
'\0\0\0\x07' # length prefix
2366
self.assertEqual([('structure', ['ARG'])], event_log)
2368
def test_decode_multiple_bytes(self):
2369
"""The protocol can decode a multiple 'bytes' message parts."""
2370
smart_protocol, event_log = self.make_protocol_expecting_message_part()
2371
smart_protocol.accept_bytes(
2372
'b' # message part kind
2373
'\0\0\0\x05' # length prefix
2375
'b' # message part kind
2380
[('bytes', 'first'), ('bytes', 'second')], event_log)
2383
class TestConventionalResponseHandler(tests.TestCase):
2385
def make_response_handler(self, response_bytes):
2386
from bzrlib.smart.message import ConventionalResponseHandler
2387
response_handler = ConventionalResponseHandler()
2388
protocol_decoder = protocol.ProtocolThreeDecoder(response_handler)
2389
# put decoder in desired state (waiting for message parts)
2390
protocol_decoder.state_accept = protocol_decoder._state_accept_expecting_message_part
2392
client_medium = medium.SmartSimplePipesClientMedium(
2393
StringIO(response_bytes), output, 'base')
2394
medium_request = client_medium.get_request()
2395
medium_request.finished_writing()
2396
response_handler.setProtoAndMediumRequest(
2397
protocol_decoder, medium_request)
2398
return response_handler
2400
def test_body_stream_interrupted_by_error(self):
2401
interrupted_body_stream = (
2402
'oS' # successful response
2403
's\0\0\0\x02le' # empty args
2404
'b\0\0\0\x09chunk one' # first chunk
2405
'b\0\0\0\x09chunk two' # second chunk
2407
's\0\0\0\x0el5:error3:abce' # bencoded error
2410
response_handler = self.make_response_handler(interrupted_body_stream)
2411
stream = response_handler.read_streamed_body()
2412
self.assertEqual('chunk one', stream.next())
2413
self.assertEqual('chunk two', stream.next())
2414
exc = self.assertRaises(errors.ErrorFromSmartServer, stream.next)
2415
self.assertEqual(('error', 'abc'), exc.error_tuple)
2417
def test_body_stream_interrupted_by_connection_lost(self):
2418
interrupted_body_stream = (
2419
'oS' # successful response
2420
's\0\0\0\x02le' # empty args
2421
'b\0\0\xff\xffincomplete chunk')
2422
response_handler = self.make_response_handler(interrupted_body_stream)
2423
stream = response_handler.read_streamed_body()
2424
self.assertRaises(errors.ConnectionReset, stream.next)
2426
def test_read_body_bytes_interrupted_by_connection_lost(self):
2427
interrupted_body_stream = (
2428
'oS' # successful response
2429
's\0\0\0\x02le' # empty args
2430
'b\0\0\xff\xffincomplete chunk')
2431
response_handler = self.make_response_handler(interrupted_body_stream)
2433
errors.ConnectionReset, response_handler.read_body_bytes)
2436
class TestMessageHandlerErrors(tests.TestCase):
2437
"""Tests for v3 that unrecognised (but well-formed) requests/responses are
2438
still fully read off the wire, so that subsequent requests/responses on the
2439
same medium can be decoded.
2442
def test_non_conventional_request(self):
2443
"""ConventionalRequestHandler (the default message handler on the
2444
server side) will reject an unconventional message, but still consume
2445
all the bytes of that message and signal when it has done so.
2447
This is what allows a server to continue to accept requests after the
2448
client sends a completely unrecognised request.
2450
# Define an invalid request (but one that is a well-formed message).
2451
# This particular invalid request not only lacks the mandatory
2452
# verb+args tuple, it has a single-byte part, which is forbidden. In
2453
# fact it has that part twice, to trigger multiple errors.
2455
protocol.MESSAGE_VERSION_THREE + # protocol version marker
2456
'\0\0\0\x02de' + # empty headers
2457
'oX' + # a single byte part: 'X'. ConventionalRequestHandler will
2458
# error at this part.
2460
'e' # end of message
2463
to_server = StringIO(invalid_request)
2464
from_server = StringIO()
2465
transport = memory.MemoryTransport('memory:///')
2466
server = medium.SmartServerPipeStreamMedium(
2467
to_server, from_server, transport)
2468
proto = server._build_protocol()
2469
message_handler = proto.message_handler
2470
server._serve_one_request(proto)
2471
# All the bytes have been read from the medium...
2472
self.assertEqual('', to_server.read())
2473
# ...and the protocol decoder has consumed all the bytes, and has
2475
self.assertEqual('', proto.unused_data)
2476
self.assertEqual(0, proto.next_read_size())
2479
class InstrumentedRequestHandler(object):
2480
"""Test Double of SmartServerRequestHandler."""
2485
def body_chunk_received(self, chunk_bytes):
2486
self.calls.append(('body_chunk_received', chunk_bytes))
2488
def no_body_received(self):
2489
self.calls.append(('no_body_received',))
2491
def prefixed_body_received(self, body_bytes):
2492
self.calls.append(('prefixed_body_received', body_bytes))
2494
def end_received(self):
2495
self.calls.append(('end_received',))
2498
class StubRequest(object):
2500
def finished_reading(self):
2504
class TestClientDecodingProtocolThree(TestSmartProtocol):
    """Tests for v3 of the client-side protocol decoding."""

    def make_logging_response_decoder(self):
        """Make v3 response decoder using a test response handler.

        :return: a ``(decoder, response_handler)`` pair; the handler's
            ``event_log`` records the message parts the decoder reports.
        """
        response_handler = LoggingMessageHandler()
        decoder = protocol.ProtocolThreeDecoder(response_handler)
        return decoder, response_handler

    def make_conventional_response_decoder(self):
        """Make v3 response decoder using a conventional response handler.

        :return: a ``(decoder, response_handler)`` pair, with the handler
            already attached to the decoder and to a do-nothing StubRequest.
        """
        response_handler = message.ConventionalResponseHandler()
        decoder = protocol.ProtocolThreeDecoder(response_handler)
        response_handler.setProtoAndMediumRequest(decoder, StubRequest())
        return decoder, response_handler

    def test_trivial_response_decoding(self):
        """Smoke test for the simplest possible v3 response: empty headers,
        status byte, empty args, no body.
        """
        headers = '\0\0\0\x02de'  # length-prefixed, bencoded empty dict
        response_status = 'oS'  # success
        args = 's\0\0\0\x02le'  # length-prefixed, bencoded empty list
        end = 'e'  # end marker
        message_bytes = headers + response_status + args + end
        decoder, response_handler = self.make_logging_response_decoder()
        decoder.accept_bytes(message_bytes)
        # The protocol decoder has finished, and consumed all bytes
        self.assertEqual(0, decoder.next_read_size())
        self.assertEqual('', decoder.unused_data)
        # The message handler has been invoked with all the parts of the
        # trivial response: empty headers, status byte, no args, end.
        self.assertEqual(
            [('headers', {}), ('byte', 'S'), ('structure', []), ('end',)],
            response_handler.event_log)

    def test_incomplete_message(self):
        """A decoder will keep signalling that it needs more bytes via
        next_read_size() != 0 until it has seen a complete message, regardless
        of which state it is in.
        """
        # Define a simple response that uses all possible message parts.
        headers = '\0\0\0\x02de'  # length-prefixed, bencoded empty dict
        response_status = 'oS'  # success
        args = 's\0\0\0\x02le'  # length-prefixed, bencoded empty list
        body = 'b\0\0\0\x04BODY'  # a body: 'BODY'
        end = 'e'  # end marker
        simple_response = headers + response_status + args + body + end
        # Feed the response to the decoder one byte at a time.  The handler
        # is not inspected by this test.
        decoder, _ = self.make_logging_response_decoder()
        for byte in simple_response:
            self.assertNotEqual(0, decoder.next_read_size())
            decoder.accept_bytes(byte)
        # Now the response is complete
        self.assertEqual(0, decoder.next_read_size())

    def test_read_response_tuple_raises_UnknownSmartMethod(self):
        """read_response_tuple raises UnknownSmartMethod if the server replied
        with 'UnknownMethod'.
        """
        headers = '\0\0\0\x02de'  # length-prefixed, bencoded empty dict
        response_status = 'oE'  # error flag
        # args: ('UnknownMethod', 'method-name')
        args = 's\0\0\0\x20l13:UnknownMethod11:method-namee'
        end = 'e'  # end marker
        message_bytes = headers + response_status + args + end
        decoder, response_handler = self.make_conventional_response_decoder()
        decoder.accept_bytes(message_bytes)
        error = self.assertRaises(
            errors.UnknownSmartMethod, response_handler.read_response_tuple)
        self.assertEqual('method-name', error.verb)

    def test_read_response_tuple_error(self):
        """If the response has an error, it is raised as an exception."""
        headers = '\0\0\0\x02de'  # length-prefixed, bencoded empty dict
        response_status = 'oE'  # error
        args = 's\0\0\0\x1al9:first arg10:second arge'  # two args
        end = 'e'  # end marker
        message_bytes = headers + response_status + args + end
        decoder, response_handler = self.make_conventional_response_decoder()
        decoder.accept_bytes(message_bytes)
        error = self.assertRaises(
            errors.ErrorFromSmartServer, response_handler.read_response_tuple)
        self.assertEqual(('first arg', 'second arg'), error.error_tuple)
2590
class TestClientEncodingProtocolThree(TestSmartProtocol):
    """Tests for the bytes emitted by ProtocolThreeRequester."""

    request_encoder = protocol.ProtocolThreeRequester
    response_decoder = protocol.ProtocolThreeDecoder
    server_protocol_class = protocol.ProtocolThreeDecoder

    def make_client_encoder_and_output(self):
        """Return ``(requester, output)``, discarding the response handler."""
        requester, _, output = self.make_client_protocol_and_output()
        return requester, output

    def test_call_smoke_test(self):
        """A smoke test for ProtocolThreeRequester.call.

        This test checks that a particular simple invocation of call emits the
        correct bytes for that invocation.
        """
        requester, output = self.make_client_encoder_and_output()
        requester.set_headers({'header name': 'header value'})
        requester.call('one arg')
        expected_bytes = (
            'bzr message 3 (bzr 1.6)\n'  # protocol version
            '\x00\x00\x00\x1fd11:header name12:header valuee'  # headers
            's\x00\x00\x00\x0bl7:one arge'  # args
            'e')  # end
        self.assertEqual(expected_bytes, output.getvalue())

    def test_call_with_body_bytes_smoke_test(self):
        """A smoke test for ProtocolThreeRequester.call_with_body_bytes.

        This test checks that a particular simple invocation of
        call_with_body_bytes emits the correct bytes for that invocation.
        """
        requester, output = self.make_client_encoder_and_output()
        requester.set_headers({'header name': 'header value'})
        requester.call_with_body_bytes(('one arg',), 'body bytes')
        expected_bytes = (
            'bzr message 3 (bzr 1.6)\n'  # protocol version
            '\x00\x00\x00\x1fd11:header name12:header valuee'  # headers
            's\x00\x00\x00\x0bl7:one arge'  # args
            'b'  # there is a prefixed body
            '\x00\x00\x00\nbody bytes'  # the prefixed body
            'e')  # end
        self.assertEqual(expected_bytes, output.getvalue())

    def test_call_writes_just_once(self):
        """A bodyless request is written to the medium all at once."""
        medium_request = StubMediumRequest()
        encoder = protocol.ProtocolThreeRequester(medium_request)
        encoder.call('arg1', 'arg2', 'arg3')
        expected_calls = ['accept_bytes', 'finished_writing']
        self.assertEqual(expected_calls, medium_request.calls)

    def test_call_with_body_bytes_writes_just_once(self):
        """A request with body bytes is written to the medium all at once."""
        medium_request = StubMediumRequest()
        encoder = protocol.ProtocolThreeRequester(medium_request)
        encoder.call_with_body_bytes(('arg', 'arg'), 'body bytes')
        expected_calls = ['accept_bytes', 'finished_writing']
        self.assertEqual(expected_calls, medium_request.calls)
2652
class StubMediumRequest(object):
    """A stub medium request that records which of its methods were called.

    Each call appends the method's name to ``self.calls``; any bytes passed
    to ``accept_bytes`` are discarded.
    """

    def __init__(self):
        self.calls = []
        self._medium = 'dummy medium'

    def _note(self, method_name):
        # Shared recorder for both medium-request callbacks.
        self.calls.append(method_name)

    def accept_bytes(self, bytes):
        self._note('accept_bytes')

    def finished_writing(self):
        self._note('finished_writing')
2668
class TestResponseEncodingProtocolThree(tests.TestCase):
    """Tests for the bytes emitted by ProtocolThreeResponder."""

    def make_response_encoder(self):
        """Return a v3 responder that writes into a new StringIO, and the
        StringIO itself.
        """
        out_stream = StringIO()
        return protocol.ProtocolThreeResponder(out_stream.write), out_stream

    def test_send_error_unknown_method(self):
        encoder, out_stream = self.make_response_encoder()
        encoder.send_error(errors.UnknownSmartMethod('method name'))
        # Only the tail is compared, because the header varies with
        # bzrlib.__version__.
        written = out_stream.getvalue()
        expected_tail = (
            # error status
            'oE' +
            # tuple: 'UnknownMethod', 'method name'
            's\x00\x00\x00\x20l13:UnknownMethod11:method namee'
            # end of message
            'e')
        self.assertEndsWith(written, expected_tail)
2690
class TestResponseEncoderBufferingProtocolThree(tests.TestCase):
    """Tests for buffering of responses.

    We want to avoid doing many small writes when one would do, to avoid
    unnecessary network overhead.
    """

    def setUp(self):
        # The base TestCase.setUp must run for test isolation; it was
        # previously skipped here.
        super(TestResponseEncoderBufferingProtocolThree, self).setUp()
        # Each write made by the responder becomes one list element, so
        # len(self.writes) is the number of writes to the medium.
        self.writes = []
        self.responder = protocol.ProtocolThreeResponder(self.writes.append)

    def assertWriteCount(self, expected_count):
        """Assert the responder made exactly ``expected_count`` writes."""
        self.assertEqual(
            expected_count, len(self.writes),
            "Expected %d writes, got %d: %r"
            % (expected_count, len(self.writes), self.writes))

    def test_send_error_writes_just_once(self):
        """An error response is written to the medium all at once."""
        self.responder.send_error(Exception('An exception string.'))
        self.assertWriteCount(1)

    def test_send_response_writes_just_once(self):
        """A normal response with no body is written to the medium all at once.
        """
        response = _mod_request.SuccessfulSmartServerResponse(('arg', 'arg'))
        self.responder.send_response(response)
        self.assertWriteCount(1)

    def test_send_response_with_body_writes_just_once(self):
        """A normal response with a monolithic body is written to the medium
        all at once.
        """
        response = _mod_request.SuccessfulSmartServerResponse(
            ('arg', 'arg'), body='body bytes')
        self.responder.send_response(response)
        self.assertWriteCount(1)

    def test_send_response_with_body_stream_writes_once_per_chunk(self):
        """A normal response with a stream body writes to the medium once per
        chunk.
        """
        # Construct a response with stream with 2 chunks in it.
        response = _mod_request.SuccessfulSmartServerResponse(
            ('arg', 'arg'), body_stream=['chunk1', 'chunk2'])
        self.responder.send_response(response)
        # We will write 3 times: exactly once for each chunk, plus a final
        # write to end the response.
        self.assertWriteCount(3)
2740
class TestSmartClientUnicode(tests.TestCase):
    """_SmartClient tests for unicode arguments.

    Unicode arguments to call_with_body_bytes are not correct (remote method
    names, arguments, and bodies must all be expressed as byte strings), but
    _SmartClient should gracefully reject them, rather than getting into a
    broken state that prevents future correct calls from working.  That is, it
    should be possible to issue more requests on the medium afterwards, rather
    than allowing one bad call to call_with_body_bytes to cause later calls to
    mysteriously fail with TooManyConcurrentRequests.
    """

    def assertCallDoesNotBreakMedium(self, method, args, body):
        """Call a medium with the given method, args and body, then assert that
        the medium is left in a sane state, i.e. is capable of allowing further
        requests.
        """
        client_reads = StringIO("\n")
        client_writes = StringIO()
        client_medium = medium.SmartSimplePipesClientMedium(
            client_reads, client_writes, 'ignored base')
        smart_client = client._SmartClient(client_medium)
        self.assertRaises(
            TypeError, smart_client.call_with_body_bytes, method, args, body)
        # Nothing may have been sent, and no request may be left open.
        self.assertEqual("", client_writes.getvalue())
        self.assertEqual(None, client_medium._current_request)

    def test_call_with_body_bytes_unicode_method(self):
        self.assertCallDoesNotBreakMedium(u'method', ('args',), 'body')

    def test_call_with_body_bytes_unicode_args(self):
        for bad_args in [(u'args',), ('arg1', u'arg2')]:
            self.assertCallDoesNotBreakMedium('method', bad_args, 'body')

    def test_call_with_body_bytes_unicode_body(self):
        self.assertCallDoesNotBreakMedium('method', ('args',), u'body')
2778
class MockMedium(medium.SmartClientMedium):
    """A mock medium that can be used to test _SmartClient.

    It can be given a series of requests to expect (and responses it should
    return for them).  It can also be told when the client is expected to
    disconnect a medium.  Expectations must be satisfied in the order they are
    given, or else an AssertionError will be raised.

    Typical use looks like::

        medium = MockMedium()
        medium.expect_request(...)
        medium.expect_request(...)
        medium.expect_request(...)

    Every expected event is stored in ``_expected_events`` as a tuple whose
    first element names the kind of event.
    """

    def __init__(self):
        super(MockMedium, self).__init__('dummy base')
        self._mock_request = _MockMediumRequest(self)
        self._expected_events = []

    def expect_request(self, request_bytes, response_bytes,
                       allow_partial_read=False):
        """Expect 'request_bytes' to be sent, and reply with 'response_bytes'.

        No assumption is made about how many times accept_bytes should be
        called to send the request.  Similarly, no assumption is made about how
        many times read_bytes/read_line are called by protocol code to read a
        response.  e.g.::

            request.accept_bytes('ab')
            request.accept_bytes('cd')
            request.finished_writing()

        and::

            request.accept_bytes('abcd')
            request.finished_writing()

        Will both satisfy ``medium.expect_request('abcd', ...)``.  Thus tests
        using this should not break due to irrelevant changes in protocol
        implementations.

        :param allow_partial_read: if True, no assertion is raised if a
            response is not fully read.  Setting this is useful when the client
            is expected to disconnect without needing to read the complete
            response.  Default is False.
        """
        self._expected_events.append(('send request', request_bytes))
        if allow_partial_read:
            self._expected_events.append(
                ('read response (partial)', response_bytes))
        else:
            self._expected_events.append(('read response', response_bytes))

    def expect_disconnect(self):
        """Expect the client to call ``medium.disconnect()``."""
        # Stored as a 1-tuple rather than a bare string so that every event
        # has the same shape and event[0] is always the event kind.
        self._expected_events.append(('disconnect',))

    def _assertEvent(self, observed_event):
        """Raise AssertionError unless observed_event matches the next expected
        event.

        After a successful match, if the following expected event is a read,
        its response bytes are loaded into the mock request so that later
        read_bytes/read_line calls have data to return.

        :seealso: expect_request
        :seealso: expect_disconnect
        """
        if not self._expected_events:
            raise AssertionError(
                'Mock medium observed event %r, but no more events expected'
                % (observed_event,))
        expected_event = self._expected_events.pop(0)
        if expected_event[0] == 'read response (partial)':
            # For a partial read only the event kind is checked; the bytes
            # actually read are not compared.
            matched = observed_event[0] == 'read response'
        else:
            matched = observed_event == expected_event
        if not matched:
            raise AssertionError(
                'Mock medium observed event %r, but expected event %r'
                % (observed_event, expected_event))
        if self._expected_events:
            next_event = self._expected_events[0]
            if next_event[0].startswith('read response'):
                self._mock_request._response = next_event[1]

    def get_request(self):
        return self._mock_request

    def disconnect(self):
        # Bytes read but not yet reported via finished_reading() (e.g. a
        # partial read) are reported as a read event before the disconnect.
        if self._mock_request._read_bytes:
            self._assertEvent(('read response', self._mock_request._read_bytes))
            self._mock_request._read_bytes = ''
        self._assertEvent(('disconnect',))
2874
class _MockMediumRequest(object):
2875
"""A mock ClientMediumRequest used by MockMedium."""
2877
def __init__(self, mock_medium):
2878
self._medium = mock_medium
2879
self._written_bytes = ''
2880
self._read_bytes = ''
2881
self._response = None
2883
def accept_bytes(self, bytes):
2884
self._written_bytes += bytes
2886
def finished_writing(self):
2887
self._medium._assertEvent(('send request', self._written_bytes))
2888
self._written_bytes = ''
2890
def finished_reading(self):
2891
self._medium._assertEvent(('read response', self._read_bytes))
2892
self._read_bytes = ''
2894
def read_bytes(self, size):
2895
resp = self._response
2896
bytes, resp = resp[:size], resp[size:]
2897
self._response = resp
2898
self._read_bytes += bytes
2901
def read_line(self):
2902
resp = self._response
2904
line, resp = resp.split('\n', 1)
2907
line, resp = resp, ''
2908
self._response = resp
2909
self._read_bytes += line
2913
class Test_SmartClientVersionDetection(tests.TestCase):
    """Tests for _SmartClient's automatic protocol version detection.

    On the first remote call, _SmartClient will keep retrying the request with
    different protocol versions until it finds one that works.
    """

    def test_version_three_server(self):
        """With a protocol 3 server, only one request is needed."""
        mock_medium = MockMedium()
        smart_client = client._SmartClient(mock_medium, headers={})
        message_start = protocol.MESSAGE_VERSION_THREE + '\x00\x00\x00\x02de'
        mock_medium.expect_request(
            message_start +
            's\x00\x00\x00\x1el11:method-name5:arg 15:arg 2ee',
            message_start + 's\0\0\0\x13l14:response valueee')
        result = smart_client.call('method-name', 'arg 1', 'arg 2')
        # The call succeeded without raising any exceptions from the mock
        # medium, and the smart_client returns the response from the server.
        self.assertEqual(('response value',), result)
        self.assertEqual([], mock_medium._expected_events)
        # Also, if v3 works then the server should be assumed to support RPCs
        # introduced in 1.6.
        self.assertFalse(mock_medium._is_remote_before((1, 6)))

    def test_version_two_server(self):
        """If the server only speaks protocol 2, the client will first try
        version 3, then fallback to protocol 2.

        Further, _SmartClient caches the detection, so future requests will all
        use protocol 2 immediately.
        """
        mock_medium = MockMedium()
        smart_client = client._SmartClient(mock_medium, headers={})
        # First the client should send a v3 request, but the server will reply
        # with a v2 error.
        mock_medium.expect_request(
            'bzr message 3 (bzr 1.6)\n\x00\x00\x00\x02de' +
            's\x00\x00\x00\x1el11:method-name5:arg 15:arg 2ee',
            'bzr response 2\nfailed\n\n')
        # So then the client should disconnect to reset the connection, because
        # the client needs to assume the server cannot read any further
        # requests off the original connection.
        mock_medium.expect_disconnect()
        # The client should then retry the original request in v2
        mock_medium.expect_request(
            'bzr request 2\nmethod-name\x01arg 1\x01arg 2\n',
            'bzr response 2\nsuccess\nresponse value\n')
        result = smart_client.call('method-name', 'arg 1', 'arg 2')
        # The smart_client object will return the result of the successful
        # query.
        self.assertEqual(('response value',), result)

        # Now try another request, and this time the client will just use
        # protocol 2.  (i.e. the autodetection won't be repeated)
        mock_medium.expect_request(
            'bzr request 2\nanother-method\n',
            'bzr response 2\nsuccess\nanother response\n')
        result = smart_client.call('another-method')
        self.assertEqual(('another response',), result)
        self.assertEqual([], mock_medium._expected_events)

        # Also, because v3 is not supported, the client medium should assume
        # that RPCs introduced in 1.6 aren't supported either.
        self.assertTrue(mock_medium._is_remote_before((1, 6)))

    def test_unknown_version(self):
        """If the server does not use any known (or at least supported)
        protocol version, a SmartProtocolError is raised.
        """
        mock_medium = MockMedium()
        smart_client = client._SmartClient(mock_medium, headers={})
        unknown_protocol_bytes = 'Unknown protocol!'
        # The client will try v3 and v2 before eventually giving up.
        mock_medium.expect_request(
            'bzr message 3 (bzr 1.6)\n\x00\x00\x00\x02de' +
            's\x00\x00\x00\x1el11:method-name5:arg 15:arg 2ee',
            unknown_protocol_bytes)
        mock_medium.expect_disconnect()
        mock_medium.expect_request(
            'bzr request 2\nmethod-name\x01arg 1\x01arg 2\n',
            unknown_protocol_bytes)
        mock_medium.expect_disconnect()
        self.assertRaises(
            errors.SmartProtocolError,
            smart_client.call, 'method-name', 'arg 1', 'arg 2')
        self.assertEqual([], mock_medium._expected_events)

    def test_first_response_is_error(self):
        """If the server replies with an error, then the version detection
        should be complete.

        This test is very similar to test_version_two_server, but catches a bug
        we had in the case where the first reply was an error response.
        """
        mock_medium = MockMedium()
        smart_client = client._SmartClient(mock_medium, headers={})
        message_start = protocol.MESSAGE_VERSION_THREE + '\x00\x00\x00\x02de'
        # Issue a request that gets an error reply in a non-default protocol
        # version.
        mock_medium.expect_request(
            message_start +
            's\x00\x00\x00\x10l11:method-nameee',
            'bzr response 2\nfailed\n\n')
        mock_medium.expect_disconnect()
        mock_medium.expect_request(
            'bzr request 2\nmethod-name\n',
            'bzr response 2\nfailed\nFooBarError\n')
        err = self.assertRaises(
            errors.ErrorFromSmartServer,
            smart_client.call, 'method-name')
        self.assertEqual(('FooBarError',), err.error_tuple)
        # Now the medium should have remembered the protocol version, so
        # subsequent requests will use the remembered version immediately.
        mock_medium.expect_request(
            'bzr request 2\nmethod-name\n',
            'bzr response 2\nsuccess\nresponse value\n')
        result = smart_client.call('method-name')
        self.assertEqual(('response value',), result)
        self.assertEqual([], mock_medium._expected_events)
3035
class Test_SmartClient(tests.TestCase):
    """Tests for _SmartClient construction defaults."""

    def test_call_default_headers(self):
        """ProtocolThreeRequester.call by default sends a 'Software
        version' header.
        """
        smart_client = client._SmartClient('dummy medium')
        expected_version = bzrlib.__version__
        self.assertEqual(
            expected_version, smart_client._headers['Software version'])
        # XXX: need a test that smart_client._headers is passed to the request
        # encoder.
3048
class LengthPrefixedBodyDecoder(tests.TestCase):
    """Tests for protocol.LengthPrefixedBodyDecoder."""

    # XXX: TODO: make accept_reading_trailer invoke translate_response or
    # something similar to the ProtocolBase method.

    def assertDecoderState(self, decoder, finished, next_read_size,
                           pending_data, unused_data):
        """Check the decoder's observable state, in a fixed order.

        read_pending_data() is called here, and the tests below show it
        returns only bytes not returned by an earlier call, so call this
        helper once per decoding step.
        """
        if finished:
            self.assertTrue(decoder.finished_reading)
        else:
            self.assertFalse(decoder.finished_reading)
        self.assertEqual(next_read_size, decoder.next_read_size())
        self.assertEqual(pending_data, decoder.read_pending_data())
        self.assertEqual(unused_data, decoder.unused_data)

    def test_construct(self):
        decoder = protocol.LengthPrefixedBodyDecoder()
        self.assertDecoderState(decoder, False, 6, '', '')

    def test_accept_bytes(self):
        decoder = protocol.LengthPrefixedBodyDecoder()
        steps = [
            # (bytes fed, finished, next_read_size, pending data, unused data)
            ('', False, 6, '', ''),
            ('7', False, 6, '', ''),
            ('\na', False, 11, 'a', ''),
            ('bcdefgd', False, 4, 'bcdefg', ''),
            ('one', False, 1, '', ''),
            ('\nblarg', True, 1, '', 'blarg'),
        ]
        for fed, finished, read_size, pending, unused in steps:
            decoder.accept_bytes(fed)
            self.assertDecoderState(decoder, finished, read_size, pending,
                                    unused)

    def test_accept_bytes_all_at_once_with_excess(self):
        decoder = protocol.LengthPrefixedBodyDecoder()
        decoder.accept_bytes('1\nadone\nunused')
        self.assertDecoderState(decoder, True, 1, 'a', 'unused')

    def test_accept_bytes_exact_end_of_body(self):
        decoder = protocol.LengthPrefixedBodyDecoder()
        decoder.accept_bytes('1\na')
        self.assertDecoderState(decoder, False, 5, 'a', '')
        decoder.accept_bytes('done\n')
        self.assertDecoderState(decoder, True, 1, '', '')
3115
class TestChunkedBodyDecoder(tests.TestCase):
    """Tests for ChunkedBodyDecoder.

    This is the body decoder used for protocol version two.
    """

    # Leading header line and trailing end-of-body marker of a chunked body.
    HEADER = 'chunked\n'
    END_MARKER = 'END\n'

    def make_started_decoder(self):
        """Return a ChunkedBodyDecoder that has already been fed HEADER."""
        decoder = protocol.ChunkedBodyDecoder()
        decoder.accept_bytes(self.HEADER)
        return decoder

    def test_construct(self):
        decoder = protocol.ChunkedBodyDecoder()
        self.assertFalse(decoder.finished_reading)
        self.assertEqual(8, decoder.next_read_size())
        self.assertEqual(None, decoder.read_next_chunk())
        self.assertEqual('', decoder.unused_data)

    def test_empty_content(self):
        """'chunked\nEND\n' is the complete encoding of a zero-length body.
        """
        decoder = self.make_started_decoder()
        decoder.accept_bytes(self.END_MARKER)
        self.assertTrue(decoder.finished_reading)
        self.assertEqual(None, decoder.read_next_chunk())
        self.assertEqual('', decoder.unused_data)

    def test_one_chunk(self):
        """A body in a single chunk is decoded correctly."""
        decoder = self.make_started_decoder()
        chunk_length = 'f\n'
        chunk_content = '123456789abcdef'
        decoder.accept_bytes(chunk_length + chunk_content + self.END_MARKER)
        self.assertTrue(decoder.finished_reading)
        self.assertEqual(chunk_content, decoder.read_next_chunk())
        self.assertEqual('', decoder.unused_data)

    def test_incomplete_chunk(self):
        """When there are less bytes in the chunk than declared by the length,
        then we haven't finished reading yet.
        """
        decoder = self.make_started_decoder()
        chunk_length = '8\n'
        three_bytes = '123'
        decoder.accept_bytes(chunk_length + three_bytes)
        self.assertFalse(decoder.finished_reading)
        self.assertEqual(
            5 + 4, decoder.next_read_size(),
            "The next_read_size hint should be the number of missing bytes in "
            "this chunk plus 4 (the length of the end-of-body marker: "
            "'END\\n')")
        self.assertEqual(None, decoder.read_next_chunk())

    def test_incomplete_length(self):
        """A chunk length hasn't been read until a newline byte has been read.
        """
        decoder = self.make_started_decoder()
        decoder.accept_bytes('9')
        self.assertEqual(
            1, decoder.next_read_size(),
            "The next_read_size hint should be 1, because we don't know the "
            "length yet.")
        decoder.accept_bytes('\n')
        self.assertEqual(
            9 + 4, decoder.next_read_size(),
            "The next_read_size hint should be the length of the chunk plus 4 "
            "(the length of the end-of-body marker: 'END\\n')")
        self.assertFalse(decoder.finished_reading)
        self.assertEqual(None, decoder.read_next_chunk())

    def test_two_chunks(self):
        """Content from multiple chunks is concatenated."""
        decoder = self.make_started_decoder()
        chunk_one = '3\naaa'
        chunk_two = '5\nbbbbb'
        decoder.accept_bytes(chunk_one + chunk_two + self.END_MARKER)
        self.assertTrue(decoder.finished_reading)
        for expected_chunk in ['aaa', 'bbbbb', None]:
            self.assertEqual(expected_chunk, decoder.read_next_chunk())
        self.assertEqual('', decoder.unused_data)

    def test_excess_bytes(self):
        """Bytes after the chunked body are reported as unused bytes."""
        decoder = self.make_started_decoder()
        chunked_body = "5\naaaaaEND\n"
        excess_bytes = "excess bytes"
        decoder.accept_bytes(chunked_body + excess_bytes)
        self.assertTrue(decoder.finished_reading)
        self.assertEqual('aaaaa', decoder.read_next_chunk())
        self.assertEqual(excess_bytes, decoder.unused_data)
        self.assertEqual(
            1, decoder.next_read_size(),
            "next_read_size hint should be 1 when finished_reading.")

    def test_multidigit_length(self):
        """Lengths in the chunk prefixes can have multiple digits."""
        decoder = self.make_started_decoder()
        length = 0x123
        chunk_prefix = hex(length) + '\n'
        chunk_bytes = 'z' * length
        decoder.accept_bytes(chunk_prefix + chunk_bytes + self.END_MARKER)
        self.assertTrue(decoder.finished_reading)
        self.assertEqual(chunk_bytes, decoder.read_next_chunk())

    def test_byte_at_a_time(self):
        """A complete body fed to the decoder one byte at a time should not
        confuse the decoder.  That is, it should give the same result as if the
        bytes had been received in one batch.

        This test is the same as test_one_chunk apart from the way accept_bytes
        is invoked.
        """
        decoder = self.make_started_decoder()
        chunk_length = 'f\n'
        chunk_content = '123456789abcdef'
        for byte in (chunk_length + chunk_content + self.END_MARKER):
            decoder.accept_bytes(byte)
        self.assertTrue(decoder.finished_reading)
        self.assertEqual(chunk_content, decoder.read_next_chunk())
        self.assertEqual('', decoder.unused_data)

    def test_read_pending_data_resets(self):
        """read_pending_data does not return the same bytes twice."""
        decoder = self.make_started_decoder()
        chunk_one = '3\naaa'
        chunk_two = '3\nbbb'
        decoder.accept_bytes(chunk_one)
        self.assertEqual('aaa', decoder.read_next_chunk())
        decoder.accept_bytes(chunk_two)
        self.assertEqual('bbb', decoder.read_next_chunk())
        self.assertEqual(None, decoder.read_next_chunk())

    def test_decode_error(self):
        decoder = self.make_started_decoder()
        chunk_one = 'b\nfirst chunk'
        error_signal = 'ERR\n'
        error_chunks = '5\npart1' + '5\npart2'
        decoder.accept_bytes(
            chunk_one + error_signal + error_chunks + self.END_MARKER)
        self.assertTrue(decoder.finished_reading)
        self.assertEqual('first chunk', decoder.read_next_chunk())
        expected_failure = _mod_request.FailedSmartServerResponse(
            ('part1', 'part2'))
        self.assertEqual(expected_failure, decoder.read_next_chunk())

    def test_bad_header(self):
        """accept_bytes raises a SmartProtocolError if a chunked body does not
        start with the right header.
        """
        decoder = protocol.ChunkedBodyDecoder()
        self.assertRaises(
            errors.SmartProtocolError, decoder.accept_bytes, 'bad header\n')
3280
class TestSuccessfulSmartServerResponse(tests.TestCase):
    """Tests for _mod_request.SuccessfulSmartServerResponse."""

    def test_construct_no_body(self):
        response = _mod_request.SuccessfulSmartServerResponse(('foo', 'bar'))
        self.assertEqual(('foo', 'bar'), response.args)
        self.assertEqual(None, response.body)

    def test_construct_with_body(self):
        response = _mod_request.SuccessfulSmartServerResponse(
            ('foo', 'bar'), 'bytes')
        self.assertEqual(('foo', 'bar'), response.args)
        self.assertEqual('bytes', response.body)
        # repr(response) doesn't trigger exceptions.
        repr(response)

    def test_construct_with_body_stream(self):
        stream = ['abc']
        response = _mod_request.SuccessfulSmartServerResponse(
            ('foo', 'bar'), body_stream=stream)
        self.assertEqual(('foo', 'bar'), response.args)
        self.assertEqual(stream, response.body_stream)

    def test_construct_rejects_body_and_body_stream(self):
        """'body' and 'body_stream' are mutually exclusive."""
        self.assertRaises(
            errors.BzrError,
            _mod_request.SuccessfulSmartServerResponse, (), 'body', ['stream'])

    def test_is_successful(self):
        """is_successful should return True for SuccessfulSmartServerResponse."""
        response = _mod_request.SuccessfulSmartServerResponse(('error',))
        self.assertEqual(True, response.is_successful())
3314
class TestFailedSmartServerResponse(tests.TestCase):
    """Tests for _mod_request.FailedSmartServerResponse."""

    def test_construct(self):
        without_body = _mod_request.FailedSmartServerResponse(('foo', 'bar'))
        self.assertEqual(('foo', 'bar'), without_body.args)
        self.assertEqual(None, without_body.body)
        with_body = _mod_request.FailedSmartServerResponse(
            ('foo', 'bar'), 'bytes')
        self.assertEqual(('foo', 'bar'), with_body.args)
        self.assertEqual('bytes', with_body.body)
        # repr(response) doesn't trigger exceptions.
        repr(with_body)

    def test_is_successful(self):
        """is_successful should return False for FailedSmartServerResponse."""
        response = _mod_request.FailedSmartServerResponse(('error',))
        self.assertEqual(False, response.is_successful())
3332
class FakeHTTPMedium(object):
    """Stand-in for an HTTP smart medium that remembers the last request."""

    def __init__(self):
        self.written_request = self._current_request = None

    def send_http_smart_request(self, bytes):
        """Store ``bytes`` as the written request; no response is produced."""
        self.written_request = bytes
        return None
3341
class HTTPTunnellingSmokeTest(tests.TestCase):
    """Smoke tests for tunnelling smart requests over HTTP."""

    def setUp(self):
        super(HTTPTunnellingSmokeTest, self).setUp()
        # We use the VFS layer as part of HTTP tunnelling tests.
        self._captureVar('BZR_NO_SMART_VFS', None)

    def test_smart_http_medium_request_accept_bytes(self):
        fake_medium = FakeHTTPMedium()
        request = SmartClientHTTPMediumRequest(fake_medium)
        for piece in ['abc', 'def']:
            request.accept_bytes(piece)
        # Nothing is sent until the request is finished.
        self.assertEqual(None, fake_medium.written_request)
        request.finished_writing()
        self.assertEqual('abcdef', fake_medium.written_request)
3358
class RemoteHTTPTransportTestCase(tests.TestCase):
    """Tests for remote-path handling in RemoteHTTPTransport."""

    def test_remote_path_after_clone_child(self):
        # If a user enters "bzr+http://host/foo", we want to send all smart
        # requests for child URLs of that to the original URL.  i.e., we want
        # to POST to "bzr+http://host/foo/.bzr/smart" and never something like
        # "bzr+http://host/foo/.bzr/branch/.bzr/smart".  So, a cloned
        # RemoteHTTPTransport remembers the initial URL, and adjusts the
        # relpaths it sends in smart requests accordingly.
        base_transport = remote.RemoteHTTPTransport('bzr+http://host/path')
        new_transport = base_transport.clone('child_dir')
        self.assertEqual(base_transport._http_transport,
                         new_transport._http_transport)
        self.assertEqual('child_dir/foo', new_transport._remote_path('foo'))
        remote_path = new_transport._client.remote_path_from_transport(
            new_transport)
        self.assertEqual('child_dir/', remote_path)

    def test_remote_path_unnormal_base(self):
        # If the transport's base isn't normalised, the _remote_path should
        # still be calculated correctly.
        base_transport = remote.RemoteHTTPTransport('bzr+http://host/%7Ea/b')
        self.assertEqual('c', base_transport._remote_path('c'))

    def test_clone_unnormal_base(self):
        # If the transport's base isn't normalised, cloned transports should
        # still work correctly.
        base_transport = remote.RemoteHTTPTransport('bzr+http://host/%7Ea/b')
        new_transport = base_transport.clone('c')
        self.assertEqual('bzr+http://host/%7Ea/b/c/', new_transport.base)
        remote_path = new_transport._client.remote_path_from_transport(
            new_transport)
        self.assertEqual('c/', remote_path)
# TODO: Client feature that does get_bundle and then installs the bundle into
# a branch; this can be used in place of the regular pull/fetch operation when
# coming from a smart server.
#
# TODO: Eventually, we want to do a 'branch' command by fetching the whole
# history as one big bundle. How?
#
# The branch command does 'br_from.sprout', which tries to preserve the same
# format. We don't necessarily even want that.
#
# It might be simpler to handle cmd_pull first, which does a simpler fetch()
# operation from one branch into another. It already has some code for
# pulling from a bundle, which it does by checking whether the location being
# pulled from is a bundle file. So it seems the logic for pull ought to be:
#
#  - if it's a smart server, get a bundle from there and install that
#  - if it's a bundle, install that
#  - if it's a branch, pull from there
#
# Getting a bundle from a smart server is a bit different from reading a
# bundle from a URL:
#
#  - we can reasonably remember the URL we last read from
#  - you can specify a revision number to pull, and we need to pass it across
#    to the server as a limit on what will be requested
#
# TODO: Given a URL, determine whether it is a smart server or not (or perhaps
# otherwise whether it's a bundle?). Should this be a property or method of
# the transport? For the ssh protocol, we always know it's a smart server.
# For http, we potentially need to probe. But if we're explicitly given
# bzr+http:// then we can skip that for now.