# Copyright (C) 2006, 2007 Canonical Ltd
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA

"""Tests for smart transport"""

# all of this deals with byte strings so this is safe
from cStringIO import StringIO
import os
import socket
import threading

from bzrlib import (
        errors,
        osutils,
        tests,
        urlutils,
        )
from bzrlib.smart import (
        medium,
        protocol,
        request as _mod_request,
        )
from bzrlib.tests.test_smart import TestCaseWithSmartMedium
from bzrlib.transport import (
        local,
        memory,
        remote,
        )
from bzrlib.transport.http import SmartClientHTTPMediumRequest


class StringIOSSHVendor(object):
    """A SSH vendor that uses StringIO to buffer writes and answer reads.

    Every call made through the vendor (and through connections it hands
    out) is recorded as a tuple in ``self.calls`` so tests can assert on it.
    """

    def __init__(self, read_from, write_to):
        self.read_from = read_from
        self.write_to = write_to
        # Log of ('connect_ssh', ...) / ('close', ) tuples, in call order.
        self.calls = []

    def connect_ssh(self, username, password, host, port, command):
        """Record the connection request and return a fake connection."""
        self.calls.append(('connect_ssh', username, password, host, port,
            command))
        return StringIOSSHConnection(self)


class StringIOSSHConnection(object):
    """A SSH connection that uses StringIO to buffer writes and answer reads."""

    def __init__(self, vendor):
        # The vendor owns the file-like objects and the call log.
        self.vendor = vendor

    def close(self):
        """Record that the connection was closed in the vendor's call log."""
        self.vendor.calls.append(('close', ))

    def get_filelike_channels(self):
        """Return the vendor's (read_from, write_to) file-like objects."""
        return self.vendor.read_from, self.vendor.write_to


class _InvalidHostnameFeature(tests.Feature):
    """Does 'non_existent.invalid' fail to resolve?

    RFC 2606 states that .invalid is reserved for invalid domain names, and
    also underscores are not a valid character in domain names. Despite this,
    it's possible a badly misconfigured name server might decide to always
    return an address for any name, so this feature allows us to distinguish a
    broken system from a broken test.
    """

    def _probe(self):
        """Return True when 'non_existent.invalid' fails to resolve."""
        try:
            socket.gethostbyname('non_existent.invalid')
        except socket.gaierror:
            # The host name failed to resolve. Good.
            return True
        else:
            return False

    def feature_name(self):
        return 'invalid hostname'

InvalidHostnameFeature = _InvalidHostnameFeature()


class SmartClientMediumTests(tests.TestCase):
    """Tests for SmartClientMedium.

    We should create a test scenario for this: we need a server module that
    constructs the test-servers (like make_loopsocket_and_medium), and the list
    of SmartClientMedium classes to test.
    """

    def make_loopsocket_and_medium(self):
        """Create a loopback socket for testing, and a medium aimed at it."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.bind(('127.0.0.1', 0))
        # Listen so receive_bytes_on_server can accept() the medium's connect.
        sock.listen(1)
        port = sock.getsockname()[1]
        client_medium = medium.SmartTCPClientMedium('127.0.0.1', port, 'base')
        return sock, client_medium

    def receive_bytes_on_server(self, sock, bytes):
        """Accept a connection on sock and read 3 bytes.

        The bytes are appended to the list bytes.

        :return: a Thread which is running to do the accept and recv.
        """
        def _receive_bytes_on_server():
            connection, address = sock.accept()
            bytes.append(osutils.recv_all(connection, 3))
            connection.close()
        t = threading.Thread(target=_receive_bytes_on_server)
        t.start()
        return t

    def test_construct_smart_simple_pipes_client_medium(self):
        # the SimplePipes client medium takes two pipes:
        # readable pipe, writeable pipe.
        # Constructing one should just save these and do nothing.
        # We test this by passing in None.
        medium.SmartSimplePipesClientMedium(None, None, None)

    def test_simple_pipes_client_request_type(self):
        # SimplePipesClient should use SmartClientStreamMediumRequest's.
        client_medium = medium.SmartSimplePipesClientMedium(None, None, None)
        request = client_medium.get_request()
        self.assertIsInstance(request, medium.SmartClientStreamMediumRequest)

    def test_simple_pipes_client_get_concurrent_requests(self):
        # the simple_pipes client does not support pipelined requests:
        # but it does support serial requests: we construct one after
        # another is finished. This is a smoke test testing the integration
        # of the SmartClientStreamMediumRequest and the SmartClientStreamMedium
        # classes - as the sibling classes share this logic, they do not have
        # explicit tests for this.
        output = StringIO()
        client_medium = medium.SmartSimplePipesClientMedium(
            None, output, 'base')
        request = client_medium.get_request()
        request.finished_writing()
        request.finished_reading()
        request2 = client_medium.get_request()
        request2.finished_writing()
        request2.finished_reading()

    def test_simple_pipes_client__accept_bytes_writes_to_writable(self):
        # accept_bytes writes to the writeable pipe.
        output = StringIO()
        client_medium = medium.SmartSimplePipesClientMedium(
            None, output, 'base')
        client_medium._accept_bytes('abc')
        self.assertEqual('abc', output.getvalue())

    def test_simple_pipes_client_disconnect_does_nothing(self):
        # calling disconnect does nothing.
        input = StringIO()
        output = StringIO()
        client_medium = medium.SmartSimplePipesClientMedium(
            input, output, 'base')
        # send some bytes to ensure disconnecting after activity still does not
        # close.
        client_medium._accept_bytes('abc')
        client_medium.disconnect()
        self.assertFalse(input.closed)
        self.assertFalse(output.closed)

    def test_simple_pipes_client_accept_bytes_after_disconnect(self):
        # calling disconnect on the client does not alter the pipe that
        # accept_bytes writes to.
        input = StringIO()
        output = StringIO()
        client_medium = medium.SmartSimplePipesClientMedium(
            input, output, 'base')
        client_medium._accept_bytes('abc')
        client_medium.disconnect()
        client_medium._accept_bytes('abc')
        self.assertFalse(input.closed)
        self.assertFalse(output.closed)
        self.assertEqual('abcabc', output.getvalue())

    def test_simple_pipes_client_ignores_disconnect_when_not_connected(self):
        # Doing a disconnect on a new (and thus unconnected) SimplePipes medium
        # does not fail. It's ok to disconnect an unconnected medium.
        client_medium = medium.SmartSimplePipesClientMedium(None, None, 'base')
        client_medium.disconnect()

    def test_simple_pipes_client_can_always_read(self):
        # SmartSimplePipesClientMedium is never disconnected, so read_bytes
        # always tries to read from the underlying pipe.
        input = StringIO('abcdef')
        client_medium = medium.SmartSimplePipesClientMedium(input, None, 'base')
        self.assertEqual('abc', client_medium.read_bytes(3))
        client_medium.disconnect()
        self.assertEqual('def', client_medium.read_bytes(3))

    def test_simple_pipes_client_supports__flush(self):
        # invoking _flush on a SimplePipesClient should flush the output
        # pipe. We test this by creating an output pipe that records
        # flush calls made to it.
        from StringIO import StringIO # get regular StringIO
        input = StringIO()
        output = StringIO()
        flush_calls = []
        def logging_flush():
            flush_calls.append('flush')
        output.flush = logging_flush
        client_medium = medium.SmartSimplePipesClientMedium(
            input, output, 'base')
        # this call is here to ensure we only flush once, not on every
        # _accept_bytes call.
        client_medium._accept_bytes('abc')
        client_medium._flush()
        client_medium.disconnect()
        self.assertEqual(['flush'], flush_calls)

    def test_construct_smart_ssh_client_medium(self):
        # the SSH client medium takes:
        # host, port, username, password, vendor
        # Constructing one should just save these and do nothing.
        # we test this by creating an empty bound socket and constructing
        # a medium.
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Close the socket even if constructing the medium raises.
        self.addCleanup(sock.close)
        sock.bind(('127.0.0.1', 0))
        unopened_port = sock.getsockname()[1]
        # having vendor be invalid means that if it tries to connect via the
        # vendor it will blow up.
        medium.SmartSSHClientMedium('127.0.0.1', unopened_port,
            username=None, password=None, base='base', vendor="not a vendor",
            bzr_remote_path='bzr')

    def test_ssh_client_connects_on_first_use(self):
        # The only thing that initiates a connection from the medium is giving
        # it bytes.
        output = StringIO()
        vendor = StringIOSSHVendor(StringIO(), output)
        client_medium = medium.SmartSSHClientMedium(
            'a hostname', 'a port', 'a username', 'a password', 'base', vendor,
            bzr_remote_path='bzr')
        client_medium._accept_bytes('abc')
        self.assertEqual('abc', output.getvalue())
        self.assertEqual([('connect_ssh', 'a username', 'a password',
            'a hostname', 'a port',
            ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes'])],
            vendor.calls)

    def test_ssh_client_changes_command_when_BZR_REMOTE_PATH_is_set(self):
        # The only thing that initiates a connection from the medium is giving
        # it bytes.
        output = StringIO()
        vendor = StringIOSSHVendor(StringIO(), output)
        orig_bzr_remote_path = os.environ.get('BZR_REMOTE_PATH')
        def cleanup_environ():
            osutils.set_or_unset_env('BZR_REMOTE_PATH', orig_bzr_remote_path)
        self.addCleanup(cleanup_environ)
        os.environ['BZR_REMOTE_PATH'] = 'fugly'
        client_medium = self.callDeprecated(
            ['bzr_remote_path is required as of bzr 0.92'],
            medium.SmartSSHClientMedium, 'a hostname', 'a port', 'a username',
            'a password', 'base', vendor)
        client_medium._accept_bytes('abc')
        self.assertEqual('abc', output.getvalue())
        self.assertEqual([('connect_ssh', 'a username', 'a password',
            'a hostname', 'a port',
            ['fugly', 'serve', '--inet', '--directory=/', '--allow-writes'])],
            vendor.calls)

    def test_ssh_client_changes_command_when_bzr_remote_path_passed(self):
        # The only thing that initiates a connection from the medium is giving
        # it bytes.
        output = StringIO()
        vendor = StringIOSSHVendor(StringIO(), output)
        client_medium = medium.SmartSSHClientMedium('a hostname', 'a port',
            'a username', 'a password', 'base', vendor, bzr_remote_path='fugly')
        client_medium._accept_bytes('abc')
        self.assertEqual('abc', output.getvalue())
        self.assertEqual([('connect_ssh', 'a username', 'a password',
            'a hostname', 'a port',
            ['fugly', 'serve', '--inet', '--directory=/', '--allow-writes'])],
            vendor.calls)

    def test_ssh_client_disconnect_does_so(self):
        # calling disconnect should disconnect both the read_from and write_to
        # file-like object it from the ssh connection.
        input = StringIO()
        output = StringIO()
        vendor = StringIOSSHVendor(input, output)
        client_medium = medium.SmartSSHClientMedium(
            'a hostname', base='base', vendor=vendor, bzr_remote_path='bzr')
        client_medium._accept_bytes('abc')
        client_medium.disconnect()
        self.assertTrue(input.closed)
        self.assertTrue(output.closed)
        self.assertEqual([
            ('connect_ssh', None, None, 'a hostname', None,
            ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
            ('close', ),
            ],
            vendor.calls)

    def test_ssh_client_disconnect_allows_reconnection(self):
        # calling disconnect on the client terminates the connection, but should
        # not prevent additional connections occurring.
        # we test this by initiating a second connection after doing a
        # disconnect.
        input = StringIO()
        output = StringIO()
        vendor = StringIOSSHVendor(input, output)
        client_medium = medium.SmartSSHClientMedium(
            'a hostname', base='base', vendor=vendor, bzr_remote_path='bzr')
        client_medium._accept_bytes('abc')
        client_medium.disconnect()
        # the disconnect has closed output, so we need a new output for the
        # new connection to write to.
        input2 = StringIO()
        output2 = StringIO()
        vendor.read_from = input2
        vendor.write_to = output2
        client_medium._accept_bytes('abc')
        client_medium.disconnect()
        self.assertTrue(input.closed)
        self.assertTrue(output.closed)
        self.assertTrue(input2.closed)
        self.assertTrue(output2.closed)
        self.assertEqual([
            ('connect_ssh', None, None, 'a hostname', None,
            ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
            ('close', ),
            ('connect_ssh', None, None, 'a hostname', None,
            ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
            ('close', ),
            ],
            vendor.calls)

    def test_ssh_client_ignores_disconnect_when_not_connected(self):
        # Doing a disconnect on a new (and thus unconnected) SSH medium
        # does not fail. It's ok to disconnect an unconnected medium.
        client_medium = medium.SmartSSHClientMedium(
            None, base='base', bzr_remote_path='bzr')
        client_medium.disconnect()

    def test_ssh_client_raises_on_read_when_not_connected(self):
        # Doing a read on a new (and thus unconnected) SSH medium raises
        # MediumNotConnected.
        client_medium = medium.SmartSSHClientMedium(
            None, base='base', bzr_remote_path='bzr')
        self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes,
                          0)
        self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes,
                          1)

    def test_ssh_client_supports__flush(self):
        # invoking _flush on a SSHClientMedium should flush the output
        # pipe. We test this by creating an output pipe that records
        # flush calls made to it.
        from StringIO import StringIO # get regular StringIO
        input = StringIO()
        output = StringIO()
        flush_calls = []
        def logging_flush():
            flush_calls.append('flush')
        output.flush = logging_flush
        vendor = StringIOSSHVendor(input, output)
        client_medium = medium.SmartSSHClientMedium(
            'a hostname', base='base', vendor=vendor, bzr_remote_path='bzr')
        # this call is here to ensure we only flush once, not on every
        # _accept_bytes call.
        client_medium._accept_bytes('abc')
        client_medium._flush()
        client_medium.disconnect()
        self.assertEqual(['flush'], flush_calls)

    def test_construct_smart_tcp_client_medium(self):
        # the TCP client medium takes a host and a port. Constructing it won't
        # connect to anything.
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Close the socket even if constructing the medium raises.
        self.addCleanup(sock.close)
        sock.bind(('127.0.0.1', 0))
        unopened_port = sock.getsockname()[1]
        medium.SmartTCPClientMedium('127.0.0.1', unopened_port, 'base')

    def test_tcp_client_connects_on_first_use(self):
        # The only thing that initiates a connection from the medium is giving
        # it bytes.
        # (Named client_medium so the `medium` module is not shadowed.)
        sock, client_medium = self.make_loopsocket_and_medium()
        received = []
        t = self.receive_bytes_on_server(sock, received)
        client_medium.accept_bytes('abc')
        t.join()
        sock.close()
        self.assertEqual(['abc'], received)

    def test_tcp_client_disconnect_does_so(self):
        # calling disconnect on the client terminates the connection.
        # we test this by forcing a short read during a socket.MSG_WAITALL
        # call: write 2 bytes, try to read 3, and then the client disconnects.
        sock, client_medium = self.make_loopsocket_and_medium()
        received = []
        t = self.receive_bytes_on_server(sock, received)
        client_medium.accept_bytes('ab')
        client_medium.disconnect()
        t.join()
        sock.close()
        self.assertEqual(['ab'], received)
        # now disconnect again: this should not do anything, if disconnection
        # really did disconnect.
        client_medium.disconnect()

    def test_tcp_client_ignores_disconnect_when_not_connected(self):
        # Doing a disconnect on a new (and thus unconnected) TCP medium
        # does not fail. It's ok to disconnect an unconnected medium.
        client_medium = medium.SmartTCPClientMedium(None, None, None)
        client_medium.disconnect()

    def test_tcp_client_raises_on_read_when_not_connected(self):
        # Doing a read on a new (and thus unconnected) TCP medium raises
        # MediumNotConnected.
        client_medium = medium.SmartTCPClientMedium(None, None, None)
        self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes, 0)
        self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes, 1)

    def test_tcp_client_supports__flush(self):
        # invoking _flush on a TCPClientMedium should do something useful.
        # RBC 20060922 not sure how to test/tell in this case.
        sock, client_medium = self.make_loopsocket_and_medium()
        received = []
        t = self.receive_bytes_on_server(sock, received)
        # try with nothing buffered
        client_medium._flush()
        client_medium._accept_bytes('ab')
        # and with something sent.
        client_medium._flush()
        client_medium.disconnect()
        t.join()
        sock.close()
        self.assertEqual(['ab'], received)
        # now disconnect again : this should not do anything, if disconnection
        # really did disconnect.
        client_medium.disconnect()

    def test_tcp_client_host_unknown_connection_error(self):
        self.requireFeature(InvalidHostnameFeature)
        client_medium = medium.SmartTCPClientMedium(
            'non_existent.invalid', 4155, 'base')
        self.assertRaises(
            errors.ConnectionError, client_medium._ensure_connection)


class TestSmartClientStreamMediumRequest(tests.TestCase):
    """Tests for SmartClientStreamMediumRequest.

    SmartClientStreamMediumRequest is a helper for the three stream based
    mediums: TCP, SSH, SimplePipes, so we only test it once, and then test that
    those three mediums implement the interface it expects.
    """

    def test_accept_bytes_after_finished_writing_errors(self):
        # calling accept_bytes after calling finished_writing raises
        # WritingCompleted to prevent bad assumptions on stream environments
        # breaking the needs of message-based environments.
        output = StringIO()
        client_medium = medium.SmartSimplePipesClientMedium(
            None, output, 'base')
        request = medium.SmartClientStreamMediumRequest(client_medium)
        request.finished_writing()
        self.assertRaises(errors.WritingCompleted, request.accept_bytes, None)

    def test_accept_bytes(self):
        # accept bytes should invoke _accept_bytes on the stream medium.
        # we test this by using the SimplePipes medium - the most trivial one
        # and checking that the pipes get the data.
        input = StringIO()
        output = StringIO()
        client_medium = medium.SmartSimplePipesClientMedium(
            input, output, 'base')
        request = medium.SmartClientStreamMediumRequest(client_medium)
        request.accept_bytes('123')
        request.finished_writing()
        request.finished_reading()
        self.assertEqual('', input.getvalue())
        self.assertEqual('123', output.getvalue())

    def test_construct_sets_stream_request(self):
        # constructing a SmartClientStreamMediumRequest on a StreamMedium sets
        # the current request to the new SmartClientStreamMediumRequest
        output = StringIO()
        client_medium = medium.SmartSimplePipesClientMedium(
            None, output, 'base')
        request = medium.SmartClientStreamMediumRequest(client_medium)
        self.assertIs(client_medium._current_request, request)

    def test_construct_while_another_request_active_throws(self):
        # constructing a SmartClientStreamMediumRequest on a StreamMedium with
        # a non-None _current_request raises TooManyConcurrentRequests.
        output = StringIO()
        client_medium = medium.SmartSimplePipesClientMedium(
            None, output, 'base')
        client_medium._current_request = "a"
        self.assertRaises(errors.TooManyConcurrentRequests,
            medium.SmartClientStreamMediumRequest, client_medium)

    def test_finished_read_clears_current_request(self):
        # calling finished_reading clears the current request from the requests
        # medium
        output = StringIO()
        client_medium = medium.SmartSimplePipesClientMedium(
            None, output, 'base')
        request = medium.SmartClientStreamMediumRequest(client_medium)
        request.finished_writing()
        request.finished_reading()
        self.assertEqual(None, client_medium._current_request)

    def test_finished_read_before_finished_write_errors(self):
        # calling finished_reading before calling finished_writing triggers a
        # WritingNotComplete error.
        client_medium = medium.SmartSimplePipesClientMedium(
            None, None, 'base')
        request = medium.SmartClientStreamMediumRequest(client_medium)
        self.assertRaises(errors.WritingNotComplete, request.finished_reading)

    def test_read_bytes(self):
        # read bytes should invoke _read_bytes on the stream medium.
        # we test this by using the SimplePipes medium - the most trivial one
        # and checking that the data is supplied. It's possible that a
        # faulty implementation could poke at the pipe variables themselves,
        # but we trust that this will be caught as it will break the integration
        # smoke tests.
        input = StringIO('321')
        output = StringIO()
        client_medium = medium.SmartSimplePipesClientMedium(
            input, output, 'base')
        request = medium.SmartClientStreamMediumRequest(client_medium)
        request.finished_writing()
        self.assertEqual('321', request.read_bytes(3))
        request.finished_reading()
        self.assertEqual('', input.read())
        self.assertEqual('', output.getvalue())

    def test_read_bytes_before_finished_write_errors(self):
        # calling read_bytes before calling finished_writing triggers a
        # WritingNotComplete error because the Smart protocol is designed to be
        # compatible with strict message based protocols like HTTP where the
        # request cannot be submitted until the writing has completed.
        client_medium = medium.SmartSimplePipesClientMedium(None, None, 'base')
        request = medium.SmartClientStreamMediumRequest(client_medium)
        self.assertRaises(errors.WritingNotComplete, request.read_bytes, None)

    def test_read_bytes_after_finished_reading_errors(self):
        # calling read_bytes after calling finished_reading raises
        # ReadingCompleted to prevent bad assumptions on stream environments
        # breaking the needs of message-based environments.
        output = StringIO()
        client_medium = medium.SmartSimplePipesClientMedium(
            None, output, 'base')
        request = medium.SmartClientStreamMediumRequest(client_medium)
        request.finished_writing()
        request.finished_reading()
        self.assertRaises(errors.ReadingCompleted, request.read_bytes, None)


class RemoteTransportTests(TestCaseWithSmartMedium):
    """Smoke tests for the transport served over the smart medium."""

    def test_plausible_url(self):
        # assertTrue replaces the deprecated assert_ alias.
        self.assertTrue(self.get_url().startswith('bzr://'))

    def test_probe_transport(self):
        t = self.get_transport()
        self.assertIsInstance(t, remote.RemoteTransport)

    def test_get_medium_from_transport(self):
        """Remote transport has a medium always, which it can return."""
        t = self.get_transport()
        client_medium = t.get_smart_medium()
        self.assertIsInstance(client_medium, medium.SmartClientMedium)


class ErrorRaisingProtocol(object):
    """A fake server protocol whose next_read_size raises a given exception.

    Used to exercise the server medium's error handling paths.
    """

    def __init__(self, exception):
        self.exception = exception

    def next_read_size(self):
        raise self.exception


class SampleRequest(object):
    """A fake server protocol that consumes bytes until it has seen a prefix.

    Bytes are accumulated in ``accepted_bytes``. Once they start with
    ``expected_bytes`` the request is finished, and anything received beyond
    that prefix is exposed as ``unused_data``.
    """

    def __init__(self, expected_bytes):
        self.accepted_bytes = ''
        self._finished_reading = False
        self.expected_bytes = expected_bytes
        self.unused_data = ''

    def accept_bytes(self, bytes):
        self.accepted_bytes += bytes
        if self.accepted_bytes.startswith(self.expected_bytes):
            self._finished_reading = True
            self.unused_data = self.accepted_bytes[len(self.expected_bytes):]

    def next_read_size(self):
        """Return 0 once the expected bytes have been seen, else 1."""
        if self._finished_reading:
            return 0
        else:
            return 1


class TestSmartServerStreamMedium(tests.TestCase):
630
super(TestSmartServerStreamMedium, self).setUp()
631
self._captureVar('BZR_NO_SMART_VFS', None)
633
def portable_socket_pair(self):
634
"""Return a pair of TCP sockets connected to each other.
636
Unlike socket.socketpair, this should work on Windows.
638
listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
639
listen_sock.bind(('127.0.0.1', 0))
640
listen_sock.listen(1)
641
client_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
642
client_sock.connect(listen_sock.getsockname())
643
server_sock, addr = listen_sock.accept()
645
return server_sock, client_sock
647
def test_smart_query_version(self):
648
"""Feed a canned query version to a server"""
649
# wire-to-wire, using the whole stack
650
to_server = StringIO('hello\n')
651
from_server = StringIO()
652
transport = local.LocalTransport(urlutils.local_path_to_url('/'))
653
server = medium.SmartServerPipeStreamMedium(
654
to_server, from_server, transport)
655
smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
657
server._serve_one_request(smart_protocol)
658
self.assertEqual('ok\0012\n',
659
from_server.getvalue())
661
def test_response_to_canned_get(self):
662
transport = memory.MemoryTransport('memory:///')
663
transport.put_bytes('testfile', 'contents\nof\nfile\n')
664
to_server = StringIO('get\001./testfile\n')
665
from_server = StringIO()
666
server = medium.SmartServerPipeStreamMedium(
667
to_server, from_server, transport)
668
smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
670
server._serve_one_request(smart_protocol)
671
self.assertEqual('ok\n'
673
'contents\nof\nfile\n'
675
from_server.getvalue())
677
def test_response_to_canned_get_of_utf8(self):
678
# wire-to-wire, using the whole stack, with a UTF-8 filename.
679
transport = memory.MemoryTransport('memory:///')
680
utf8_filename = u'testfile\N{INTERROBANG}'.encode('utf-8')
681
transport.put_bytes(utf8_filename, 'contents\nof\nfile\n')
682
to_server = StringIO('get\001' + utf8_filename + '\n')
683
from_server = StringIO()
684
server = medium.SmartServerPipeStreamMedium(
685
to_server, from_server, transport)
686
smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
688
server._serve_one_request(smart_protocol)
689
self.assertEqual('ok\n'
691
'contents\nof\nfile\n'
693
from_server.getvalue())
695
def test_pipe_like_stream_with_bulk_data(self):
696
sample_request_bytes = 'command\n9\nbulk datadone\n'
697
to_server = StringIO(sample_request_bytes)
698
from_server = StringIO()
699
server = medium.SmartServerPipeStreamMedium(
700
to_server, from_server, None)
701
sample_protocol = SampleRequest(expected_bytes=sample_request_bytes)
702
server._serve_one_request(sample_protocol)
703
self.assertEqual('', from_server.getvalue())
704
self.assertEqual(sample_request_bytes, sample_protocol.accepted_bytes)
705
self.assertFalse(server.finished)
707
def test_socket_stream_with_bulk_data(self):
708
sample_request_bytes = 'command\n9\nbulk datadone\n'
709
server_sock, client_sock = self.portable_socket_pair()
710
server = medium.SmartServerSocketStreamMedium(
712
sample_protocol = SampleRequest(expected_bytes=sample_request_bytes)
713
client_sock.sendall(sample_request_bytes)
714
server._serve_one_request(sample_protocol)
716
self.assertEqual('', client_sock.recv(1))
717
self.assertEqual(sample_request_bytes, sample_protocol.accepted_bytes)
718
self.assertFalse(server.finished)
720
def test_pipe_like_stream_shutdown_detection(self):
721
to_server = StringIO('')
722
from_server = StringIO()
723
server = medium.SmartServerPipeStreamMedium(to_server, from_server, None)
724
server._serve_one_request(SampleRequest('x'))
725
self.assertTrue(server.finished)
727
def test_socket_stream_shutdown_detection(self):
728
server_sock, client_sock = self.portable_socket_pair()
730
server = medium.SmartServerSocketStreamMedium(
732
server._serve_one_request(SampleRequest('x'))
733
self.assertTrue(server.finished)
735
def test_socket_stream_incomplete_request(self):
736
"""The medium should still construct the right protocol version even if
737
the initial read only reads part of the request.
739
Specifically, it should correctly read the protocol version line even
740
if the partial read doesn't end in a newline. An older, naive
741
implementation of _get_line in the server used to have a bug in that
744
incomplete_request_bytes = protocol.REQUEST_VERSION_TWO + 'hel'
745
rest_of_request_bytes = 'lo\n'
746
expected_response = (
747
protocol.RESPONSE_VERSION_TWO + 'success\nok\x012\n')
748
server_sock, client_sock = self.portable_socket_pair()
749
server = medium.SmartServerSocketStreamMedium(
751
client_sock.sendall(incomplete_request_bytes)
752
server_protocol = server._build_protocol()
753
client_sock.sendall(rest_of_request_bytes)
754
server._serve_one_request(server_protocol)
756
self.assertEqual(expected_response, client_sock.recv(50),
757
"Not a version 2 response to 'hello' request.")
758
self.assertEqual('', client_sock.recv(1))
760
def test_pipe_stream_incomplete_request(self):
761
"""The medium should still construct the right protocol version even if
762
the initial read only reads part of the request.
764
Specifically, it should correctly read the protocol version line even
765
if the partial read doesn't end in a newline. An older, naive
766
implementation of _get_line in the server used to have a bug in that
769
incomplete_request_bytes = protocol.REQUEST_VERSION_TWO + 'hel'
770
rest_of_request_bytes = 'lo\n'
771
expected_response = (
772
protocol.RESPONSE_VERSION_TWO + 'success\nok\x012\n')
773
# Make a pair of pipes, to and from the server
774
to_server, to_server_w = os.pipe()
775
from_server_r, from_server = os.pipe()
776
to_server = os.fdopen(to_server, 'r', 0)
777
to_server_w = os.fdopen(to_server_w, 'w', 0)
778
from_server_r = os.fdopen(from_server_r, 'r', 0)
779
from_server = os.fdopen(from_server, 'w', 0)
780
server = medium.SmartServerPipeStreamMedium(
781
to_server, from_server, None)
782
# Like test_socket_stream_incomplete_request, write an incomplete
783
# request (that does not end in '\n') and build a protocol from it.
784
to_server_w.write(incomplete_request_bytes)
785
server_protocol = server._build_protocol()
786
# Send the rest of the request, and finish serving it.
787
to_server_w.write(rest_of_request_bytes)
788
server._serve_one_request(server_protocol)
791
self.assertEqual(expected_response, from_server_r.read(),
792
"Not a version 2 response to 'hello' request.")
793
self.assertEqual('', from_server_r.read(1))
794
from_server_r.close()
797
def test_pipe_like_stream_with_two_requests(self):
798
# If two requests are read in one go, then two calls to
799
# _serve_one_request should still process both of them as if they had
800
# been received seperately.
801
sample_request_bytes = 'command\n'
802
to_server = StringIO(sample_request_bytes * 2)
803
from_server = StringIO()
804
server = medium.SmartServerPipeStreamMedium(
805
to_server, from_server, None)
806
first_protocol = SampleRequest(expected_bytes=sample_request_bytes)
807
server._serve_one_request(first_protocol)
808
self.assertEqual(0, first_protocol.next_read_size())
809
self.assertEqual('', from_server.getvalue())
810
self.assertFalse(server.finished)
811
# Make a new protocol, call _serve_one_request with it to collect the
813
second_protocol = SampleRequest(expected_bytes=sample_request_bytes)
814
server._serve_one_request(second_protocol)
815
self.assertEqual('', from_server.getvalue())
816
self.assertEqual(sample_request_bytes, second_protocol.accepted_bytes)
817
self.assertFalse(server.finished)
819
def test_socket_stream_with_two_requests(self):
820
# If two requests are read in one go, then two calls to
821
# _serve_one_request should still process both of them as if they had
822
# been received seperately.
823
sample_request_bytes = 'command\n'
824
server_sock, client_sock = self.portable_socket_pair()
825
server = medium.SmartServerSocketStreamMedium(
827
first_protocol = SampleRequest(expected_bytes=sample_request_bytes)
828
# Put two whole requests on the wire.
829
client_sock.sendall(sample_request_bytes * 2)
830
server._serve_one_request(first_protocol)
831
self.assertEqual(0, first_protocol.next_read_size())
832
self.assertFalse(server.finished)
833
# Make a new protocol, call _serve_one_request with it to collect the
835
second_protocol = SampleRequest(expected_bytes=sample_request_bytes)
836
stream_still_open = server._serve_one_request(second_protocol)
837
self.assertEqual(sample_request_bytes, second_protocol.accepted_bytes)
838
self.assertFalse(server.finished)
840
self.assertEqual('', client_sock.recv(1))
842
def test_pipe_like_stream_error_handling(self):
843
# Use plain python StringIO so we can monkey-patch the close method to
844
# not discard the contents.
845
from StringIO import StringIO
846
to_server = StringIO('')
847
from_server = StringIO()
851
from_server.close = close
852
server = medium.SmartServerPipeStreamMedium(
853
to_server, from_server, None)
854
fake_protocol = ErrorRaisingProtocol(Exception('boom'))
855
server._serve_one_request(fake_protocol)
856
self.assertEqual('', from_server.getvalue())
857
self.assertTrue(self.closed)
858
self.assertTrue(server.finished)
860
def test_socket_stream_error_handling(self):
861
server_sock, client_sock = self.portable_socket_pair()
862
server = medium.SmartServerSocketStreamMedium(
864
fake_protocol = ErrorRaisingProtocol(Exception('boom'))
865
server._serve_one_request(fake_protocol)
866
# recv should not block, because the other end of the socket has been
868
self.assertEqual('', client_sock.recv(1))
869
self.assertTrue(server.finished)
871
def test_pipe_like_stream_keyboard_interrupt_handling(self):
872
to_server = StringIO('')
873
from_server = StringIO()
874
server = medium.SmartServerPipeStreamMedium(
875
to_server, from_server, None)
876
fake_protocol = ErrorRaisingProtocol(KeyboardInterrupt('boom'))
878
KeyboardInterrupt, server._serve_one_request, fake_protocol)
879
self.assertEqual('', from_server.getvalue())
881
# Checks that a KeyboardInterrupt raised by the protocol on a socket medium
# escapes _serve_one_request and that client_sock.recv(1) then returns ''.
# NOTE(review): damaged copy -- the numbering skips 884 (rest of the
# SmartServerSocketStreamMedium( call), 886 (the line opening the call that
# takes `KeyboardInterrupt, server._serve_one_request, fake_protocol`) and
# 888.
def test_socket_stream_keyboard_interrupt_handling(self):
882
server_sock, client_sock = self.portable_socket_pair()
883
server = medium.SmartServerSocketStreamMedium(
885
fake_protocol = ErrorRaisingProtocol(KeyboardInterrupt('boom'))
887
KeyboardInterrupt, server._serve_one_request, fake_protocol)
889
self.assertEqual('', client_sock.recv(1))
891
def build_protocol_pipe_like(self, bytes):
    """Return the protocol a pipe-like server medium builds for *bytes*.

    *bytes* is offered on the medium's input stream; the output stream is a
    fresh StringIO that is never inspected.
    """
    pipe_server = medium.SmartServerPipeStreamMedium(
        StringIO(bytes), StringIO(), None)
    return pipe_server._build_protocol()
898
# Helper: sends *bytes* from the client end of a socket pair and returns the
# protocol object that the socket server medium builds from them.
# NOTE(review): damaged copy -- the numbering skips 901 (the remaining
# arguments of the SmartServerSocketStreamMedium( call) and 903.
def build_protocol_socket(self, bytes):
899
server_sock, client_sock = self.portable_socket_pair()
900
server = medium.SmartServerSocketStreamMedium(
902
client_sock.sendall(bytes)
904
return server._build_protocol()
906
# Asserts that server_protocol's exact type is SmartServerRequestProtocolOne.
# An isinstance check would not do, as the comment below explains.
# NOTE(review): damaged copy -- the numbering skips 910, the line opening the
# call that takes `type(server_protocol),
# protocol.SmartServerRequestProtocolOne` (the comment names assertIs).
def assertProtocolOne(self, server_protocol):
907
# Use assertIs because assertIsInstance will wrongly pass
908
# SmartServerRequestProtocolTwo (because it subclasses
909
# SmartServerRequestProtocolOne).
911
type(server_protocol), protocol.SmartServerRequestProtocolOne)
913
def assertProtocolTwo(self, server_protocol):
    # Any instance of SmartServerRequestProtocolTwo, including subclasses,
    # satisfies this check.
    expected_class = protocol.SmartServerRequestProtocolTwo
    self.assertIsInstance(server_protocol, expected_class)
917
def test_pipe_like_build_protocol_empty_bytes(self):
918
# Any empty request (i.e. no bytes) is detected as protocol version one.
919
server_protocol = self.build_protocol_pipe_like('')
920
self.assertProtocolOne(server_protocol)
922
def test_socket_like_build_protocol_empty_bytes(self):
923
# Any empty request (i.e. no bytes) is detected as protocol version one.
924
server_protocol = self.build_protocol_socket('')
925
self.assertProtocolOne(server_protocol)
927
def test_pipe_like_build_protocol_non_two(self):
928
# A request that doesn't start with "bzr request 2\n" is version one.
929
server_protocol = self.build_protocol_pipe_like('abc\n')
930
self.assertProtocolOne(server_protocol)
932
def test_socket_build_protocol_non_two(self):
933
# A request that doesn't start with "bzr request 2\n" is version one.
934
server_protocol = self.build_protocol_socket('abc\n')
935
self.assertProtocolOne(server_protocol)
937
def test_pipe_like_build_protocol_two(self):
938
# A request that starts with "bzr request 2\n" is version two.
939
server_protocol = self.build_protocol_pipe_like('bzr request 2\n')
940
self.assertProtocolTwo(server_protocol)
942
def test_socket_build_protocol_two(self):
943
# A request that starts with "bzr request 2\n" is version two.
944
server_protocol = self.build_protocol_socket('bzr request 2\n')
945
self.assertProtocolTwo(server_protocol)
948
# Tests for medium._get_protocol_factory_for_bytes. Each test unpacks the
# result as (protocol_factory, remainder) and checks both, for a version-three
# marker, a version-two marker and unmarked (version-one) bytes.
# NOTE(review): damaged copy -- the numbering skips 953-954 (which must have
# held the closing quotes of the class docstring, as none appear below) and
# 959, 967, 975 (the line opening each `..., protocol_factory)` comparison).
# Restore from version control.
class TestGetProtocolFactoryForBytes(tests.TestCase):
949
"""_get_protocol_factory_for_bytes identifies the protocol factory a server
950
should use to decode a given request. Any bytes not part of the version
951
marker string (and thus part of the actual request) are returned alongside
952
the protocol factory.
955
def test_version_three(self):
956
result = medium._get_protocol_factory_for_bytes(
957
'bzr message 3 (bzr 1.6)\nextra bytes')
958
protocol_factory, remainder = result
960
protocol.build_server_protocol_three, protocol_factory)
961
self.assertEqual('extra bytes', remainder)
963
def test_version_two(self):
964
result = medium._get_protocol_factory_for_bytes(
965
'bzr request 2\nextra bytes')
966
protocol_factory, remainder = result
968
protocol.SmartServerRequestProtocolTwo, protocol_factory)
969
self.assertEqual('extra bytes', remainder)
971
def test_version_one(self):
972
"""Version one requests have no version markers."""
973
result = medium._get_protocol_factory_for_bytes('anything\n')
974
protocol_factory, remainder = result
976
protocol.SmartServerRequestProtocolOne, protocol_factory)
977
self.assertEqual('anything\n', remainder)
980
# Starts a SmartTCPServer whose backing transport raises a plain Exception
# from get_bytes, and checks that a client `get` surfaces it as a
# TransportError whose text contains 'some random exception'.
# NOTE(review): damaged copy -- the numbering skips 986, 988, 993, 995, 999
# and 1002. Those lines held the rest of FlakyTransport (external_url has no
# visible body) and presumably the try/else/finally lines around the visible
# `except errors.TransportError, e:` clause and the cleanup calls.
class TestSmartTCPServer(tests.TestCase):
982
def test_get_error_unexpected(self):
983
"""Error reported by server with no specific representation"""
984
self._captureVar('BZR_NO_SMART_VFS', None)
985
class FlakyTransport(object):
987
def external_url(self):
989
def get_bytes(self, path):
990
raise Exception("some random exception from inside server")
991
smart_server = server.SmartTCPServer(backing_transport=FlakyTransport())
992
smart_server.start_background_thread('-' + self.id())
994
transport = remote.RemoteTCPTransport(smart_server.get_url())
996
transport.get('something')
997
except errors.TransportError, e:
998
self.assertContainsRe(str(e), 'some random exception')
1000
self.fail("get did not raise expected error")
1001
transport.disconnect()
1003
smart_server.stop_background_thread()
1006
# Base class for TCP end-to-end tests.
# setUpServer() picks a backing transport (a MemoryTransport when none is
# given), optionally wraps it in a "readonly+" URL, starts a SmartTCPServer on
# a background thread, opens a RemoteTCPTransport to it, and registers
# tearDownServer via addCleanup.
# tearDownServer() disconnects the transport and stops the server if they
# exist.
# NOTE(review): damaged copy -- the numbering skips 1008, 1011, 1013-1014,
# 1017, 1019, 1022, 1024, 1035 and 1038. This includes the closing quotes of
# both docstrings and presumably the `else:` / `if readonly:` lines of
# setUpServer; without them the readonly wrapping below reads as
# unconditional.
class SmartTCPTests(tests.TestCase):
1007
"""Tests for connection/end to end behaviour using the TCP server.
1009
All of these tests are run with a server running on another thread serving
1010
a MemoryTransport, and a connection to it already open.
1012
the server is obtained by calling self.setUpServer(readonly=False).
1015
def setUpServer(self, readonly=False, backing_transport=None):
1016
"""Setup the server.
1018
:param readonly: Create a readonly server.
1020
if not backing_transport:
1021
self.backing_transport = memory.MemoryTransport()
1023
self.backing_transport = backing_transport
1025
self.real_backing_transport = self.backing_transport
1026
self.backing_transport = get_transport("readonly+" + self.backing_transport.abspath('.'))
1027
self.server = server.SmartTCPServer(self.backing_transport)
1028
self.server.start_background_thread('-' + self.id())
1029
self.transport = remote.RemoteTCPTransport(self.server.get_url())
1030
self.addCleanup(self.tearDownServer)
1032
def tearDownServer(self):
1033
if getattr(self, 'transport', None):
1034
self.transport.disconnect()
1036
if getattr(self, 'server', None):
1037
self.server.stop_background_thread()
1041
# Checks that after tearDownServer() a new RemoteTCPTransport to the old URL
# fails with ConnectionError. This is checked both with no prior request and
# after one `has` request.
# NOTE(review): damaged copy -- the numbering skips 1045 and 1053, the first
# statement of each test. These are presumably self.setUpServer() calls,
# since self.server is read immediately afterwards -- confirm.
class TestServerSocketUsage(SmartTCPTests):
1043
def test_server_setup_teardown(self):
1044
"""It should be safe to teardown the server with no requests."""
1046
server = self.server
1047
transport = remote.RemoteTCPTransport(self.server.get_url())
1048
self.tearDownServer()
1049
self.assertRaises(errors.ConnectionError, transport.has, '.')
1051
def test_server_closes_listening_sock_on_shutdown_after_request(self):
1052
"""The server should close its listening socket when it's stopped."""
1054
server = self.server
1055
self.transport.has('.')
1056
self.tearDownServer()
1057
# if the listening socket has closed, we should get a BADFD error
1058
# when connecting, rather than a hang.
1059
transport = remote.RemoteTCPTransport(server.get_url())
1060
self.assertRaises(errors.ConnectionError, transport.has, '.')
1063
# End-to-end tests against a writable backing transport. They cover:
# - the server URL format;
# - has/get over the smart protocol;
# - NoSuchFile keeping the client's exact path;
# - clone sharing the smart medium and changing base;
# - _remote_path;
# - opening a bzrdir created on the backing transport.
# NOTE(review): damaged copy -- the numbering skips 1065-1066 and 1068-1069
# (the `def setUp` line and presumably a setUpServer() call around the
# visible super().setUp()). It also skips 1092, 1094 and 1098 (rest of a
# comment and presumably the try:/else: of test_get_error_enoent) and
# 1117-1118 (the end of test_clone_changes_base's assertEquals call).
class WritableEndToEndTests(SmartTCPTests):
1064
"""Client to server tests that require a writable transport."""
1067
super(WritableEndToEndTests, self).setUp()
1070
def test_start_tcp_server(self):
1071
url = self.server.get_url()
1072
self.assertContainsRe(url, r'^bzr://127\.0\.0\.1:[0-9]{2,}/')
1074
def test_smart_transport_has(self):
1075
"""Checking for file existence over smart."""
1076
self._captureVar('BZR_NO_SMART_VFS', None)
1077
self.backing_transport.put_bytes("foo", "contents of foo\n")
1078
self.assertTrue(self.transport.has("foo"))
1079
self.assertFalse(self.transport.has("non-foo"))
1081
def test_smart_transport_get(self):
1082
"""Read back a file over smart."""
1083
self._captureVar('BZR_NO_SMART_VFS', None)
1084
self.backing_transport.put_bytes("foo", "contents\nof\nfoo\n")
1085
fp = self.transport.get("foo")
1086
self.assertEqual('contents\nof\nfoo\n', fp.read())
1088
def test_get_error_enoent(self):
1089
"""Error reported from server getting nonexistent file."""
1090
# The path in a raised NoSuchFile exception should be the precise path
1091
# asked for by the client. This gives meaningful and unsurprising errors
1093
self._captureVar('BZR_NO_SMART_VFS', None)
1095
self.transport.get('not%20a%20file')
1096
except errors.NoSuchFile, e:
1097
self.assertEqual('not%20a%20file', e.path)
1099
self.fail("get did not raise expected error")
1101
def test_simple_clone_conn(self):
1102
"""Test that cloning reuses the same connection."""
1103
# we create a real connection not a loopback one, but it will use the
1104
# same server and pipes
1105
conn2 = self.transport.clone('.')
1106
self.assertIs(self.transport.get_smart_medium(),
1107
conn2.get_smart_medium())
1109
def test__remote_path(self):
1110
self.assertEquals('/foo/bar',
1111
self.transport._remote_path('foo/bar'))
1113
def test_clone_changes_base(self):
1114
"""Cloning transport produces one with a new base location"""
1115
conn2 = self.transport.clone('subdir')
1116
self.assertEquals(self.transport.base + 'subdir/',
1119
def test_open_dir(self):
1120
"""Test changing directory"""
1121
self._captureVar('BZR_NO_SMART_VFS', None)
1122
transport = self.transport
1123
self.backing_transport.mkdir('toffee')
1124
self.backing_transport.mkdir('toffee/apple')
1125
self.assertEquals('/toffee', transport._remote_path('toffee'))
1126
toffee_trans = transport.clone('toffee')
1127
# Check that each transport has only the contents of its directory
1128
# directly visible. If state was being held in the wrong object, it's
1129
# conceivable that cloning a transport would alter the state of the
1130
# cloned-from transport.
1131
self.assertTrue(transport.has('toffee'))
1132
self.assertFalse(toffee_trans.has('toffee'))
1133
self.assertFalse(transport.has('apple'))
1134
self.assertTrue(toffee_trans.has('apple'))
1136
def test_open_bzrdir(self):
1137
"""Open an existing bzrdir over smart transport"""
1138
transport = self.transport
1139
t = self.backing_transport
1140
bzrdir.BzrDirFormat.get_default_format().initialize_on_transport(t)
1141
result_dir = bzrdir.BzrDir.open_containing_from_transport(transport)
1144
# Checks that mkdir through a server started with setUpServer(readonly=True)
# raises TransportNotPossible on the client.
# NOTE(review): damaged copy -- the assertRaises call below is left open; the
# numbering stops at 1151, so its final argument line(s) (from 1152 on) are
# missing.
class ReadOnlyEndToEndTests(SmartTCPTests):
1145
"""Tests from the client to the server using a readonly backing transport."""
1147
def test_mkdir_error_readonly(self):
1148
"""TransportNotPossible should be preserved from the backing transport."""
1149
self._captureVar('BZR_NO_SMART_VFS', None)
1150
self.setUpServer(readonly=True)
1151
self.assertRaises(errors.TransportNotPossible, self.transport.mkdir,
1155
# Tests for the SmartTCPServer 'server_started' and 'server_stopped' hooks.
# capture_server_call records each (backing_urls, public_url) firing in
# self.hook_calls. The tests compare that list against the expected URLs for
# memory-backed and file-backed ('.') servers; the stopped hook must not fire
# until tearDownServer().
# NOTE(review): damaged copy -- the numbering skips 1166, 1171, 1173-1174,
# 1185, 1189-1190, 1196, 1214 and 1226. This includes presumably the
# setUpServer() calls in the memory variants, the closing arguments of two
# assertEqual calls, and the `result = [(` line before the
# `[self.backing_transport.base, ...]` continuation.
class TestServerHooks(SmartTCPTests):
1157
def capture_server_call(self, backing_urls, public_url):
1158
"""Record a server_started|stopped hook firing."""
1159
self.hook_calls.append((backing_urls, public_url))
1161
def test_server_started_hook_memory(self):
1162
"""The server_started hook fires when the server is started."""
1163
self.hook_calls = []
1164
server.SmartTCPServer.hooks.install_named_hook('server_started',
1165
self.capture_server_call, None)
1167
# at this point, the server will be starting a thread up.
1168
# there is no indicator at the moment, so bodge it by doing a request.
1169
self.transport.has('.')
1170
# The default test server uses MemoryTransport and that has no external
1172
self.assertEqual([([self.backing_transport.base], self.transport.base)],
1175
def test_server_started_hook_file(self):
1176
"""The server_started hook fires when the server is started."""
1177
self.hook_calls = []
1178
server.SmartTCPServer.hooks.install_named_hook('server_started',
1179
self.capture_server_call, None)
1180
self.setUpServer(backing_transport=get_transport("."))
1181
# at this point, the server will be starting a thread up.
1182
# there is no indicator at the moment, so bodge it by doing a request.
1183
self.transport.has('.')
1184
# The default test server uses MemoryTransport and that has no external
1186
self.assertEqual([([
1187
self.backing_transport.base, self.backing_transport.external_url()],
1188
self.transport.base)],
1191
def test_server_stopped_hook_simple_memory(self):
1192
"""The server_stopped hook fires when the server is stopped."""
1193
self.hook_calls = []
1194
server.SmartTCPServer.hooks.install_named_hook('server_stopped',
1195
self.capture_server_call, None)
1197
result = [([self.backing_transport.base], self.transport.base)]
1198
# check the stopping message isn't emitted up front.
1199
self.assertEqual([], self.hook_calls)
1200
# nor after a single message
1201
self.transport.has('.')
1202
self.assertEqual([], self.hook_calls)
1203
# clean up the server
1204
self.tearDownServer()
1205
# now it should have fired.
1206
self.assertEqual(result, self.hook_calls)
1208
def test_server_stopped_hook_simple_file(self):
1209
"""The server_stopped hook fires when the server is stopped."""
1210
self.hook_calls = []
1211
server.SmartTCPServer.hooks.install_named_hook('server_stopped',
1212
self.capture_server_call, None)
1213
self.setUpServer(backing_transport=get_transport("."))
1215
[self.backing_transport.base, self.backing_transport.external_url()]
1216
, self.transport.base)]
1217
# check the stopping message isn't emitted up front.
1218
self.assertEqual([], self.hook_calls)
1219
# nor after a single message
1220
self.transport.has('.')
1221
self.assertEqual([], self.hook_calls)
1222
# clean up the server
1223
self.tearDownServer()
1224
# now it should have fired.
1225
self.assertEqual(result, self.hook_calls)
1227
# TODO: test that when the server suffers an exception that it calls the
1228
# server-stopped hook.
1231
# Calls request objects directly:
# - HelloRequest.execute() must return args ('ok', '2') and no body;
# - GetBundleRequest.execute('.', rev_id) must return empty args and a body
#   that serializer.read_bundle can parse.
# NOTE(review): damaged copy -- the numbering skips 1234 and 1236-1238 (the
# rest of the class docstring, including its closing quotes), as well as
# 1244, 1249 and 1251.
class SmartServerCommandTests(tests.TestCaseWithTransport):
1232
"""Tests that call directly into the command objects, bypassing the network
1233
and the request dispatching.
1235
Note: these tests are rudimentary versions of the command object tests in
1239
def test_hello(self):
1240
cmd = _mod_request.HelloRequest(None, '/')
1241
response = cmd.execute()
1242
self.assertEqual(('ok', '2'), response.args)
1243
self.assertEqual(None, response.body)
1245
def test_get_bundle(self):
1246
from bzrlib.bundle import serializer
1247
wt = self.make_branch_and_tree('.')
1248
self.build_tree_contents([('hello', 'hello world')])
1250
rev_id = wt.commit('add hello')
1252
cmd = _mod_request.GetBundleRequest(self.get_transport(), '/')
1253
response = cmd.execute('.', rev_id)
1254
bundle = serializer.read_bundle(StringIO(response.body))
1255
self.assertEqual((), response.args)
1258
# Drives SmartServerRequestHandler directly via dispatch_command /
# accept_body / end_of_body. It checks:
# - finished_reading transitions;
# - response args/body for hello, put_non_atomic and readv (including the
#   ShortReadvError response);
# - ReadOnlyError mapping;
# - DisabledMethod when BZR_NO_SMART_VFS is set.
# NOTE(review): damaged copy -- the numbering skips 1260-1261 (the `def
# setUp` line around the visible super().setUp()), plus 1264, 1269, 1275,
# 1281, 1287, 1290, 1293, 1305, 1312, 1325 and 1340. 1287 and 1290 fall
# inside the comment block before os.environ is set.
class SmartServerRequestHandlerTests(tests.TestCaseWithTransport):
1259
"""Test that call directly into the handler logic, bypassing the network."""
1262
super(SmartServerRequestHandlerTests, self).setUp()
1263
self._captureVar('BZR_NO_SMART_VFS', None)
1265
def build_handler(self, transport):
1266
"""Returns a handler for the commands in protocol version one."""
1267
return _mod_request.SmartServerRequestHandler(
1268
transport, _mod_request.request_handlers, '/')
1270
def test_construct_request_handler(self):
1271
"""Constructing a request handler should be easy and set defaults."""
1272
handler = _mod_request.SmartServerRequestHandler(None, commands=None,
1273
root_client_path='/')
1274
self.assertFalse(handler.finished_reading)
1276
def test_hello(self):
1277
handler = self.build_handler(None)
1278
handler.dispatch_command('hello', ())
1279
self.assertEqual(('ok', '2'), handler.response.args)
1280
self.assertEqual(None, handler.response.body)
1282
def test_disable_vfs_handler_classes_via_environment(self):
1283
# VFS handler classes will raise an error from "execute" if
1284
# BZR_NO_SMART_VFS is set.
1285
handler = vfs.HasRequest(None, '/')
1286
# set environment variable after construction to make sure it's
1288
# Note that we can safely clobber BZR_NO_SMART_VFS here, because setUp
1289
# has called _captureVar, so it will be restored to the right state
1291
os.environ['BZR_NO_SMART_VFS'] = ''
1292
self.assertRaises(errors.DisabledMethod, handler.execute)
1294
def test_readonly_exception_becomes_transport_not_possible(self):
1295
"""The response for a read-only error is ('ReadOnlyError')."""
1296
handler = self.build_handler(self.get_readonly_transport())
1297
# send a mkdir for foo, with no explicit mode - should fail.
1298
handler.dispatch_command('mkdir', ('foo', ''))
1299
# and the failure should be an explicit ReadOnlyError
1300
self.assertEqual(("ReadOnlyError", ), handler.response.args)
1301
# XXX: TODO: test that other TransportNotPossible errors are
1302
# presented as TransportNotPossible - not possible to do that
1303
# until I figure out how to trigger that relatively cleanly via
1304
# the api. RBC 20060918
1306
def test_hello_has_finished_body_on_dispatch(self):
1307
"""The 'hello' command should set finished_reading."""
1308
handler = self.build_handler(None)
1309
handler.dispatch_command('hello', ())
1310
self.assertTrue(handler.finished_reading)
1311
self.assertNotEqual(None, handler.response)
1313
def test_put_bytes_non_atomic(self):
1314
"""'put_...' should set finished_reading after reading the bytes."""
1315
handler = self.build_handler(self.get_transport())
1316
handler.dispatch_command('put_non_atomic', ('a-file', '', 'F', ''))
1317
self.assertFalse(handler.finished_reading)
1318
handler.accept_body('1234')
1319
self.assertFalse(handler.finished_reading)
1320
handler.accept_body('5678')
1321
handler.end_of_body()
1322
self.assertTrue(handler.finished_reading)
1323
self.assertEqual(('ok', ), handler.response.args)
1324
self.assertEqual(None, handler.response.body)
1326
def test_readv_accept_body(self):
1327
"""'readv' should set finished_reading after reading offsets."""
1328
self.build_tree(['a-file'])
1329
handler = self.build_handler(self.get_readonly_transport())
1330
handler.dispatch_command('readv', ('a-file', ))
1331
self.assertFalse(handler.finished_reading)
1332
handler.accept_body('2,')
1333
self.assertFalse(handler.finished_reading)
1334
handler.accept_body('3')
1335
handler.end_of_body()
1336
self.assertTrue(handler.finished_reading)
1337
self.assertEqual(('readv', ), handler.response.args)
1338
# co - nte - nt of a-file is the file contents we are extracting from.
1339
self.assertEqual('nte', handler.response.body)
1341
def test_readv_short_read_response_contents(self):
1342
"""'readv' when a short read occurs sets the response appropriately."""
1343
self.build_tree(['a-file'])
1344
handler = self.build_handler(self.get_readonly_transport())
1345
handler.dispatch_command('readv', ('a-file', ))
1346
# read beyond the end of the file.
1347
handler.accept_body('100,1')
1348
handler.end_of_body()
1349
self.assertTrue(handler.finished_reading)
1350
self.assertEqual(('ShortReadvError', './a-file', '100', '1', '0'),
1351
handler.response.args)
1352
self.assertEqual(None, handler.response.body)
1355
# Checks that get_transport maps bzr+ssh:// URLs to RemoteSSHTransport (with
# the host parsed) and bzr+https:// URLs to RemoteHTTPTransport.
# NOTE(review): damaged copy -- the numbering skips 1361, and the
# assertStartsWith call below is left open: its final argument line(s) from
# 1368 on are missing.
class RemoteTransportRegistration(tests.TestCase):
1357
def test_registration(self):
1358
t = get_transport('bzr+ssh://example.com/path')
1359
self.assertIsInstance(t, remote.RemoteSSHTransport)
1360
self.assertEqual('example.com', t._host)
1362
def test_bzr_https(self):
1363
# https://bugs.launchpad.net/bzr/+bug/128456
1364
t = get_transport('bzr+https://example.com/path')
1365
self.assertIsInstance(t, remote.RemoteHTTPTransport)
1366
self.assertStartsWith(
1367
t._http_transport.base,
1371
# Tests RemoteTransport with an explicitly supplied
# SmartSimplePipesClientMedium:
# - nothing is sent or read until the first call;
# - get_bytes writes 'get\x01/foo\n' and consumes exactly 13 input bytes;
# - _translate_error maps ReadOnlyError to TransportNotPossible.
# NOTE(review): damaged copy -- the numbering skips 1376, which presumably
# defined the `output` stream used below (it is never assigned in this copy),
# plus 1383, 1388 and 1397.
class TestRemoteTransport(tests.TestCase):
1373
def test_use_connection_factory(self):
1374
# We want to be able to pass a client as a parameter to RemoteTransport.
1375
input = StringIO('ok\n3\nbardone\n')
1377
client_medium = medium.SmartSimplePipesClientMedium(
1378
input, output, 'base')
1379
transport = remote.RemoteTransport(
1380
'bzr://localhost/', medium=client_medium)
1381
# Disable version detection.
1382
client_medium._protocol_version = 1
1384
# We want to make sure the client is used when the first remote
1385
# method is called. No data should have been sent, or read.
1386
self.assertEqual(0, input.tell())
1387
self.assertEqual('', output.getvalue())
1389
# Now call a method that should result in one request: as the
1390
# transport makes its own protocol instances, we check on the wire.
1391
# XXX: TODO: give the transport a protocol factory, which can make
1392
# an instrumented protocol for us.
1393
self.assertEqual('bar', transport.get_bytes('foo'))
1394
# only the needed data should have been sent/received.
1395
self.assertEqual(13, input.tell())
1396
self.assertEqual('get\x01/foo\n', output.getvalue())
1398
def test__translate_error_readonly(self):
1399
"""Sending a ReadOnlyError to _translate_error raises TransportNotPossible."""
1400
client_medium = medium.SmartSimplePipesClientMedium(None, None, 'base')
1401
transport = remote.RemoteTransport(
1402
'bzr://localhost/', medium=client_medium)
1403
self.assertRaises(errors.TransportNotPossible,
1404
transport._translate_error, ("ReadOnlyError", ))
1407
# Base class for protocol tests. Subclasses set client_protocol_class /
# server_protocol_class, or request_encoder / response_decoder. Helpers:
# - make_client_protocol*: build a client over an in-memory pipe medium;
# - make_server_protocol: build a server protocol writing to a StringIO;
# - assertOffsetSerialisation: round-trip readv offsets;
# - build_protocol_waiting_for_body: a server protocol with a FakeCommand
#   expecting body 'abcdefg';
# - assertServerToClientEncoding: check response tuples encode to and
#   decode from the given bytes.
# NOTE(review): damaged copy -- the numbering skips many lines, including:
# - 1420-1421 (closing quotes of the class docstring);
# - 1428-1430 and 1435-1436 (presumably a docstring and the `input`/`output`
#   StringIO setup in make_client_protocol_and_output);
# - 1445 (presumably `else:`);
# - 1465-1466 (the `def setUp` line);
# - 1474, 1476, 1479-1480 and 1483 (the rest of assertOffsetSerialisation's
#   signature and docstring);
# - 1511, 1514 and 1516 (the rest of assertServerToClientEncoding's
#   signature and docstring).
# Restore from version control.
class TestSmartProtocol(tests.TestCase):
1408
"""Base class for smart protocol tests.
1410
Each test case gets a smart_server and smart_client created during setUp().
1412
It is planned that the client can be called with self.call_client() giving
1413
it an expected server response, which will be fed into it when it tries to
1414
read. Likewise, self.call_server will call a servers method with a canned
1415
serialised client request. Output done by the client or server for these
1416
calls will be captured to self.to_server and self.to_client. Each element
1417
in the list is a write call from the client or server respectively.
1419
Subclasses can override client_protocol_class and server_protocol_class.
1422
request_encoder = None
1423
response_decoder = None
1424
server_protocol_class = None
1425
client_protocol_class = None
1427
def make_client_protocol_and_output(self, input_bytes=None):
1431
# This is very similar to
1432
# bzrlib.smart.client._SmartClient._build_client_protocol
1433
# XXX: make this use _SmartClient!
1434
if input_bytes is None:
1437
input = StringIO(input_bytes)
1439
client_medium = medium.SmartSimplePipesClientMedium(
1440
input, output, 'base')
1441
request = client_medium.get_request()
1442
if self.client_protocol_class is not None:
1443
client_protocol = self.client_protocol_class(request)
1444
return client_protocol, client_protocol, output
1446
self.assertNotEqual(None, self.request_encoder)
1447
self.assertNotEqual(None, self.response_decoder)
1448
requester = self.request_encoder(request)
1449
response_handler = message.ConventionalResponseHandler()
1450
response_protocol = self.response_decoder(
1451
response_handler, expect_version_marker=True)
1452
response_handler.setProtoAndMediumRequest(
1453
response_protocol, request)
1454
return requester, response_handler, output
1456
def make_client_protocol(self, input_bytes=None):
1457
result = self.make_client_protocol_and_output(input_bytes=input_bytes)
1458
requester, response_handler, output = result
1459
return requester, response_handler
1461
def make_server_protocol(self):
1462
out_stream = StringIO()
1463
smart_protocol = self.server_protocol_class(None, out_stream.write)
1464
return smart_protocol, out_stream
1467
super(TestSmartProtocol, self).setUp()
1468
self.response_marker = getattr(
1469
self.client_protocol_class, 'response_marker', None)
1470
self.request_marker = getattr(
1471
self.client_protocol_class, 'request_marker', None)
1473
def assertOffsetSerialisation(self, expected_offsets, expected_serialised,
1475
"""Check that smart (de)serialises offsets as expected.
1477
We check both serialisation and deserialisation at the same time
1478
to ensure that the round tripping cannot skew: both directions should
1481
:param expected_offsets: a readv offset list.
1482
:param expected_seralised: an expected serial form of the offsets.
1484
# XXX: '_deserialise_offsets' should be a method of the
1485
# SmartServerRequestProtocol in future.
1486
readv_cmd = vfs.ReadvRequest(None, '/')
1487
offsets = readv_cmd._deserialise_offsets(expected_serialised)
1488
self.assertEqual(expected_offsets, offsets)
1489
serialised = requester._serialise_offsets(offsets)
1490
self.assertEqual(expected_serialised, serialised)
1492
def build_protocol_waiting_for_body(self):
1493
smart_protocol, out_stream = self.make_server_protocol()
1494
smart_protocol._has_dispatched = True
1495
smart_protocol.request = _mod_request.SmartServerRequestHandler(
1496
None, _mod_request.request_handlers, '/')
1497
class FakeCommand(object):
1498
def do_body(cmd, body_bytes):
1499
self.end_received = True
1500
self.assertEqual('abcdefg', body_bytes)
1501
return _mod_request.SuccessfulSmartServerResponse(('ok', ))
1502
smart_protocol.request._command = FakeCommand()
1503
# Call accept_bytes to make sure that internal state like _body_decoder
1504
# is initialised. This test should probably be given a clearer
1505
# interface to work with that will not cause this inconsistency.
1506
# -- Andrew Bennetts, 2006-09-28
1507
smart_protocol.accept_bytes('')
1508
return smart_protocol
1510
def assertServerToClientEncoding(self, expected_bytes, expected_tuple,
1512
"""Assert that each input_tuple serialises as expected_bytes, and the
1513
bytes deserialise as expected_tuple.
1515
# check the encoding of the server for all input_tuples matches
1517
for input_tuple in input_tuples:
1518
server_protocol, server_output = self.make_server_protocol()
1519
server_protocol._send_response(
1520
_mod_request.SuccessfulSmartServerResponse(input_tuple))
1521
self.assertEqual(expected_bytes, server_output.getvalue())
1522
# check the decoding of the client smart_protocol from expected_bytes:
1523
requester, response_handler = self.make_client_protocol(expected_bytes)
1524
requester.call('foo')
1525
self.assertEqual(expected_tuple, response_handler.read_response_tuple())
1528
# Mixin of protocol-independent tests. It relies on make_client_protocol and
# assertOffsetSerialisation from the TestCase it is mixed into. It checks
# the ConnectionReset message when no response arrives, and readv offset
# (de)serialisation for empty, single, out-of-order and multi-range lists.
# NOTE(review): damaged copy -- the numbering skips 1538, 1541 (presumably the
# closing quotes of the test_server_offset_serialisation docstring), 1545 and
# 1550 (the final argument line of the `[(10,40), (0,5)]` assertion).
class CommonSmartProtocolTestMixin(object):
1530
def test_connection_closed_reporting(self):
1531
requester, response_handler = self.make_client_protocol()
1532
requester.call('hello')
1533
ex = self.assertRaises(errors.ConnectionReset,
1534
response_handler.read_response_tuple)
1535
self.assertEqual("Connection closed: "
1536
"please check connectivity and permissions "
1537
"(and try -Dhpss if further diagnosis is required)", str(ex))
1539
def test_server_offset_serialisation(self):
1540
"""The Smart protocol serialises offsets as a comma and \n string.
1542
We check a number of boundary cases are as expected: empty, one offset,
1543
one with the order of reads not increasing (an out of order read), and
1544
one that should coalesce.
1546
requester, response_handler = self.make_client_protocol()
1547
self.assertOffsetSerialisation([], '', requester)
1548
self.assertOffsetSerialisation([(1,2)], '1,2', requester)
1549
self.assertOffsetSerialisation([(10,40), (0,5)], '10,40\n0,5',
1551
self.assertOffsetSerialisation([(1,2), (3,4), (100, 200)],
1552
'1,2\n3,4\n100,200', requester)
1555
class TestVersionOneFeaturesInProtocolOne(
1556
TestSmartProtocol, CommonSmartProtocolTestMixin):
1557
"""Tests for version one smart protocol features as implemeted by version
1560
client_protocol_class = protocol.SmartClientRequestProtocolOne
1561
server_protocol_class = protocol.SmartServerRequestProtocolOne
1563
def test_construct_version_one_server_protocol(self):
1564
smart_protocol = protocol.SmartServerRequestProtocolOne(None, None)
1565
self.assertEqual('', smart_protocol.unused_data)
1566
self.assertEqual('', smart_protocol.in_buffer)
1567
self.assertFalse(smart_protocol._has_dispatched)
1568
self.assertEqual(1, smart_protocol.next_read_size())
1570
def test_construct_version_one_client_protocol(self):
1571
# we can construct a client protocol from a client medium request
1573
client_medium = medium.SmartSimplePipesClientMedium(
1574
None, output, 'base')
1575
request = client_medium.get_request()
1576
client_protocol = protocol.SmartClientRequestProtocolOne(request)
1578
def test_accept_bytes_of_bad_request_to_protocol(self):
1579
out_stream = StringIO()
1580
smart_protocol = protocol.SmartServerRequestProtocolOne(
1581
None, out_stream.write)
1582
smart_protocol.accept_bytes('abc')
1583
self.assertEqual('abc', smart_protocol.in_buffer)
1584
smart_protocol.accept_bytes('\n')
1586
"error\x01Generic bzr smart protocol error: bad request 'abc'\n",
1587
out_stream.getvalue())
1588
self.assertTrue(smart_protocol._has_dispatched)
1589
self.assertEqual(0, smart_protocol.next_read_size())
1591
def test_accept_body_bytes_to_protocol(self):
1592
protocol = self.build_protocol_waiting_for_body()
1593
self.assertEqual(6, protocol.next_read_size())
1594
protocol.accept_bytes('7\nabc')
1595
self.assertEqual(9, protocol.next_read_size())
1596
protocol.accept_bytes('defgd')
1597
protocol.accept_bytes('one\n')
1598
self.assertEqual(0, protocol.next_read_size())
1599
self.assertTrue(self.end_received)
1601
def test_accept_request_and_body_all_at_once(self):
1602
self._captureVar('BZR_NO_SMART_VFS', None)
1603
mem_transport = memory.MemoryTransport()
1604
mem_transport.put_bytes('foo', 'abcdefghij')
1605
out_stream = StringIO()
1606
smart_protocol = protocol.SmartServerRequestProtocolOne(mem_transport,
1608
smart_protocol.accept_bytes('readv\x01foo\n3\n3,3done\n')
1609
self.assertEqual(0, smart_protocol.next_read_size())
1610
self.assertEqual('readv\n3\ndefdone\n', out_stream.getvalue())
1611
self.assertEqual('', smart_protocol.unused_data)
1612
self.assertEqual('', smart_protocol.in_buffer)
1614
def test_accept_excess_bytes_are_preserved(self):
1615
out_stream = StringIO()
1616
smart_protocol = protocol.SmartServerRequestProtocolOne(
1617
None, out_stream.write)
1618
smart_protocol.accept_bytes('hello\nhello\n')
1619
self.assertEqual("ok\x012\n", out_stream.getvalue())
1620
self.assertEqual("hello\n", smart_protocol.unused_data)
1621
self.assertEqual("", smart_protocol.in_buffer)
1623
def test_accept_excess_bytes_after_body(self):
1624
protocol = self.build_protocol_waiting_for_body()
1625
protocol.accept_bytes('7\nabcdefgdone\nX')
1626
self.assertTrue(self.end_received)
1627
self.assertEqual("X", protocol.unused_data)
1628
self.assertEqual("", protocol.in_buffer)
1629
protocol.accept_bytes('Y')
1630
self.assertEqual("XY", protocol.unused_data)
1631
self.assertEqual("", protocol.in_buffer)
1633
def test_accept_excess_bytes_after_dispatch(self):
1634
out_stream = StringIO()
1635
smart_protocol = protocol.SmartServerRequestProtocolOne(
1636
None, out_stream.write)
1637
smart_protocol.accept_bytes('hello\n')
1638
self.assertEqual("ok\x012\n", out_stream.getvalue())
1639
smart_protocol.accept_bytes('hel')
1640
self.assertEqual("hel", smart_protocol.unused_data)
1641
smart_protocol.accept_bytes('lo\n')
1642
self.assertEqual("hello\n", smart_protocol.unused_data)
1643
self.assertEqual("", smart_protocol.in_buffer)
1645
def test__send_response_sets_finished_reading(self):
1646
smart_protocol = protocol.SmartServerRequestProtocolOne(
1647
None, lambda x: None)
1648
self.assertEqual(1, smart_protocol.next_read_size())
1649
smart_protocol._send_response(
1650
_mod_request.SuccessfulSmartServerResponse(('x',)))
1651
self.assertEqual(0, smart_protocol.next_read_size())
1653
def test__send_response_errors_with_base_response(self):
1654
"""Ensure that only the Successful/Failed subclasses are used."""
1655
smart_protocol = protocol.SmartServerRequestProtocolOne(
1656
None, lambda x: None)
1657
self.assertRaises(AttributeError, smart_protocol._send_response,
1658
_mod_request.SmartServerResponse(('x',)))
1660
def test_query_version(self):
1661
"""query_version on a SmartClientProtocolOne should return a number.
1663
The protocol provides the query_version because the domain level clients
1664
may all need to be able to probe for capabilities.
1666
# What we really want to test here is that SmartClientProtocolOne calls
1667
# accept_bytes(tuple_based_encoding_of_hello) and reads and parses the
1668
# response of tuple-encoded (ok, 1). Also, seperately we should test
1669
# the error if the response is a non-understood version.
1670
input = StringIO('ok\x012\n')
1672
client_medium = medium.SmartSimplePipesClientMedium(
1673
input, output, 'base')
1674
request = client_medium.get_request()
1675
smart_protocol = protocol.SmartClientRequestProtocolOne(request)
1676
self.assertEqual(2, smart_protocol.query_version())
1678
def test_client_call_empty_response(self):
1679
# protocol.call() can get back an empty tuple as a response. This occurs
1680
# when the parsed line is an empty line, and results in a tuple with
1681
# one element - an empty string.
1682
self.assertServerToClientEncoding('\n', ('', ), [(), ('', )])
1684
def test_client_call_three_element_response(self):
1685
# protocol.call() can get back tuples of other lengths. A three element
1686
# tuple should be unpacked as three strings.
1687
self.assertServerToClientEncoding('a\x01b\x0134\n', ('a', 'b', '34'),
1690
def test_client_call_with_body_bytes_uploads(self):
1691
# protocol.call_with_body_bytes should length-prefix the bytes onto the
1693
expected_bytes = "foo\n7\nabcdefgdone\n"
1694
input = StringIO("\n")
1696
client_medium = medium.SmartSimplePipesClientMedium(
1697
input, output, 'base')
1698
request = client_medium.get_request()
1699
smart_protocol = protocol.SmartClientRequestProtocolOne(request)
1700
smart_protocol.call_with_body_bytes(('foo', ), "abcdefg")
1701
self.assertEqual(expected_bytes, output.getvalue())
1703
def test_client_call_with_body_readv_array(self):
1704
# protocol.call_with_upload should encode the readv array and then
1705
# length-prefix the bytes onto the wire.
1706
expected_bytes = "foo\n7\n1,2\n5,6done\n"
1707
input = StringIO("\n")
1709
client_medium = medium.SmartSimplePipesClientMedium(
1710
input, output, 'base')
1711
request = client_medium.get_request()
1712
smart_protocol = protocol.SmartClientRequestProtocolOne(request)
1713
smart_protocol.call_with_body_readv_array(('foo', ), [(1,2),(5,6)])
1714
self.assertEqual(expected_bytes, output.getvalue())
1716
def _test_client_read_response_tuple_raises_UnknownSmartMethod(self,
1718
input = StringIO(server_bytes)
1720
client_medium = medium.SmartSimplePipesClientMedium(
1721
input, output, 'base')
1722
request = client_medium.get_request()
1723
smart_protocol = protocol.SmartClientRequestProtocolOne(request)
1724
smart_protocol.call('foo')
1726
errors.UnknownSmartMethod, smart_protocol.read_response_tuple)
1727
# The request has been finished. There is no body to read, and
1728
# attempts to read one will fail.
1730
errors.ReadingCompleted, smart_protocol.read_body_bytes)
1732
def test_client_read_response_tuple_raises_UnknownSmartMethod(self):
1733
"""read_response_tuple raises UnknownSmartMethod if the response says
1734
the server did not recognise the request.
1737
"error\x01Generic bzr smart protocol error: bad request 'foo'\n")
1738
self._test_client_read_response_tuple_raises_UnknownSmartMethod(
1741
def test_client_read_response_tuple_raises_UnknownSmartMethod_0_11(self):
1742
"""read_response_tuple also raises UnknownSmartMethod if the response
1743
from a bzr 0.11 says the server did not recognise the request.
1745
(bzr 0.11 sends a slightly different error message to later versions.)
1748
"error\x01Generic bzr smart protocol error: bad request u'foo'\n")
1749
self._test_client_read_response_tuple_raises_UnknownSmartMethod(
1752
def test_client_read_body_bytes_all(self):
1753
# read_body_bytes should decode the body bytes from the wire into
1755
expected_bytes = "1234567"
1756
server_bytes = "ok\n7\n1234567done\n"
1757
input = StringIO(server_bytes)
1759
client_medium = medium.SmartSimplePipesClientMedium(
1760
input, output, 'base')
1761
request = client_medium.get_request()
1762
smart_protocol = protocol.SmartClientRequestProtocolOne(request)
1763
smart_protocol.call('foo')
1764
smart_protocol.read_response_tuple(True)
1765
self.assertEqual(expected_bytes, smart_protocol.read_body_bytes())
1767
def test_client_read_body_bytes_incremental(self):
1768
# test reading a few bytes at a time from the body
1769
# XXX: possibly we should test dribbling the bytes into the stringio
1770
# to make the state machine work harder: however, as we use the
1771
# LengthPrefixedBodyDecoder that is already well tested - we can skip
1773
expected_bytes = "1234567"
1774
server_bytes = "ok\n7\n1234567done\n"
1775
input = StringIO(server_bytes)
1777
client_medium = medium.SmartSimplePipesClientMedium(
1778
input, output, 'base')
1779
request = client_medium.get_request()
1780
smart_protocol = protocol.SmartClientRequestProtocolOne(request)
1781
smart_protocol.call('foo')
1782
smart_protocol.read_response_tuple(True)
1783
self.assertEqual(expected_bytes[0:2], smart_protocol.read_body_bytes(2))
1784
self.assertEqual(expected_bytes[2:4], smart_protocol.read_body_bytes(2))
1785
self.assertEqual(expected_bytes[4:6], smart_protocol.read_body_bytes(2))
1786
self.assertEqual(expected_bytes[6], smart_protocol.read_body_bytes())
1788
def test_client_cancel_read_body_does_not_eat_body_bytes(self):
1789
# cancelling the expected body needs to finish the request, but not
1790
# read any more bytes.
1791
expected_bytes = "1234567"
1792
server_bytes = "ok\n7\n1234567done\n"
1793
input = StringIO(server_bytes)
1795
client_medium = medium.SmartSimplePipesClientMedium(
1796
input, output, 'base')
1797
request = client_medium.get_request()
1798
smart_protocol = protocol.SmartClientRequestProtocolOne(request)
1799
smart_protocol.call('foo')
1800
smart_protocol.read_response_tuple(True)
1801
smart_protocol.cancel_read_body()
1802
self.assertEqual(3, input.tell())
1804
errors.ReadingCompleted, smart_protocol.read_body_bytes)
1806
def test_client_read_body_bytes_interrupted_connection(self):
1807
server_bytes = "ok\n999\nincomplete body"
1808
input = StringIO(server_bytes)
1810
client_medium = medium.SmartSimplePipesClientMedium(
1811
input, output, 'base')
1812
request = client_medium.get_request()
1813
smart_protocol = self.client_protocol_class(request)
1814
smart_protocol.call('foo')
1815
smart_protocol.read_response_tuple(True)
1817
errors.ConnectionReset, smart_protocol.read_body_bytes)
1820
class TestVersionOneFeaturesInProtocolTwo(
1821
TestSmartProtocol, CommonSmartProtocolTestMixin):
1822
"""Tests for version one smart protocol features as implemeted by version
1826
client_protocol_class = protocol.SmartClientRequestProtocolTwo
1827
server_protocol_class = protocol.SmartServerRequestProtocolTwo
1829
def test_construct_version_two_server_protocol(self):
1830
smart_protocol = protocol.SmartServerRequestProtocolTwo(None, None)
1831
self.assertEqual('', smart_protocol.unused_data)
1832
self.assertEqual('', smart_protocol.in_buffer)
1833
self.assertFalse(smart_protocol._has_dispatched)
1834
self.assertEqual(1, smart_protocol.next_read_size())
1836
def test_construct_version_two_client_protocol(self):
1837
# we can construct a client protocol from a client medium request
1839
client_medium = medium.SmartSimplePipesClientMedium(
1840
None, output, 'base')
1841
request = client_medium.get_request()
1842
client_protocol = protocol.SmartClientRequestProtocolTwo(request)
1844
def test_accept_bytes_of_bad_request_to_protocol(self):
1845
out_stream = StringIO()
1846
smart_protocol = self.server_protocol_class(None, out_stream.write)
1847
smart_protocol.accept_bytes('abc')
1848
self.assertEqual('abc', smart_protocol.in_buffer)
1849
smart_protocol.accept_bytes('\n')
1851
self.response_marker +
1852
"failed\nerror\x01Generic bzr smart protocol error: bad request 'abc'\n",
1853
out_stream.getvalue())
1854
self.assertTrue(smart_protocol._has_dispatched)
1855
self.assertEqual(0, smart_protocol.next_read_size())
1857
def test_accept_body_bytes_to_protocol(self):
1858
protocol = self.build_protocol_waiting_for_body()
1859
self.assertEqual(6, protocol.next_read_size())
1860
protocol.accept_bytes('7\nabc')
1861
self.assertEqual(9, protocol.next_read_size())
1862
protocol.accept_bytes('defgd')
1863
protocol.accept_bytes('one\n')
1864
self.assertEqual(0, protocol.next_read_size())
1865
self.assertTrue(self.end_received)
1867
def test_accept_request_and_body_all_at_once(self):
1868
self._captureVar('BZR_NO_SMART_VFS', None)
1869
mem_transport = memory.MemoryTransport()
1870
mem_transport.put_bytes('foo', 'abcdefghij')
1871
out_stream = StringIO()
1872
smart_protocol = self.server_protocol_class(
1873
mem_transport, out_stream.write)
1874
smart_protocol.accept_bytes('readv\x01foo\n3\n3,3done\n')
1875
self.assertEqual(0, smart_protocol.next_read_size())
1876
self.assertEqual(self.response_marker +
1877
'success\nreadv\n3\ndefdone\n',
1878
out_stream.getvalue())
1879
self.assertEqual('', smart_protocol.unused_data)
1880
self.assertEqual('', smart_protocol.in_buffer)
1882
def test_accept_excess_bytes_are_preserved(self):
1883
out_stream = StringIO()
1884
smart_protocol = self.server_protocol_class(None, out_stream.write)
1885
smart_protocol.accept_bytes('hello\nhello\n')
1886
self.assertEqual(self.response_marker + "success\nok\x012\n",
1887
out_stream.getvalue())
1888
self.assertEqual("hello\n", smart_protocol.unused_data)
1889
self.assertEqual("", smart_protocol.in_buffer)
1891
def test_accept_excess_bytes_after_body(self):
1892
# The excess bytes look like the start of another request.
1893
server_protocol = self.build_protocol_waiting_for_body()
1894
server_protocol.accept_bytes('7\nabcdefgdone\n' + self.response_marker)
1895
self.assertTrue(self.end_received)
1896
self.assertEqual(self.response_marker,
1897
server_protocol.unused_data)
1898
self.assertEqual("", server_protocol.in_buffer)
1899
server_protocol.accept_bytes('Y')
1900
self.assertEqual(self.response_marker + "Y",
1901
server_protocol.unused_data)
1902
self.assertEqual("", server_protocol.in_buffer)
1904
def test_accept_excess_bytes_after_dispatch(self):
1905
out_stream = StringIO()
1906
smart_protocol = self.server_protocol_class(None, out_stream.write)
1907
smart_protocol.accept_bytes('hello\n')
1908
self.assertEqual(self.response_marker + "success\nok\x012\n",
1909
out_stream.getvalue())
1910
smart_protocol.accept_bytes(self.request_marker + 'hel')
1911
self.assertEqual(self.request_marker + "hel",
1912
smart_protocol.unused_data)
1913
smart_protocol.accept_bytes('lo\n')
1914
self.assertEqual(self.request_marker + "hello\n",
1915
smart_protocol.unused_data)
1916
self.assertEqual("", smart_protocol.in_buffer)
1918
def test__send_response_sets_finished_reading(self):
1919
smart_protocol = self.server_protocol_class(None, lambda x: None)
1920
self.assertEqual(1, smart_protocol.next_read_size())
1921
smart_protocol._send_response(
1922
_mod_request.SuccessfulSmartServerResponse(('x',)))
1923
self.assertEqual(0, smart_protocol.next_read_size())
1925
def test__send_response_errors_with_base_response(self):
1926
"""Ensure that only the Successful/Failed subclasses are used."""
1927
smart_protocol = self.server_protocol_class(None, lambda x: None)
1928
self.assertRaises(AttributeError, smart_protocol._send_response,
1929
_mod_request.SmartServerResponse(('x',)))
1931
def test_query_version(self):
1932
"""query_version on a SmartClientProtocolTwo should return a number.
1934
The protocol provides the query_version because the domain level clients
1935
may all need to be able to probe for capabilities.
1937
# What we really want to test here is that SmartClientProtocolTwo calls
1938
# accept_bytes(tuple_based_encoding_of_hello) and reads and parses the
1939
# response of tuple-encoded (ok, 1). Also, seperately we should test
1940
# the error if the response is a non-understood version.
1941
input = StringIO(self.response_marker + 'success\nok\x012\n')
1943
client_medium = medium.SmartSimplePipesClientMedium(
1944
input, output, 'base')
1945
request = client_medium.get_request()
1946
smart_protocol = self.client_protocol_class(request)
1947
self.assertEqual(2, smart_protocol.query_version())
1949
def test_client_call_empty_response(self):
1950
# protocol.call() can get back an empty tuple as a response. This occurs
1951
# when the parsed line is an empty line, and results in a tuple with
1952
# one element - an empty string.
1953
self.assertServerToClientEncoding(
1954
self.response_marker + 'success\n\n', ('', ), [(), ('', )])
1956
def test_client_call_three_element_response(self):
1957
# protocol.call() can get back tuples of other lengths. A three element
1958
# tuple should be unpacked as three strings.
1959
self.assertServerToClientEncoding(
1960
self.response_marker + 'success\na\x01b\x0134\n',
1964
def test_client_call_with_body_bytes_uploads(self):
1965
# protocol.call_with_body_bytes should length-prefix the bytes onto the
1967
expected_bytes = self.request_marker + "foo\n7\nabcdefgdone\n"
1968
input = StringIO("\n")
1970
client_medium = medium.SmartSimplePipesClientMedium(
1971
input, output, 'base')
1972
request = client_medium.get_request()
1973
smart_protocol = self.client_protocol_class(request)
1974
smart_protocol.call_with_body_bytes(('foo', ), "abcdefg")
1975
self.assertEqual(expected_bytes, output.getvalue())
1977
def test_client_call_with_body_readv_array(self):
1978
# protocol.call_with_upload should encode the readv array and then
1979
# length-prefix the bytes onto the wire.
1980
expected_bytes = self.request_marker + "foo\n7\n1,2\n5,6done\n"
1981
input = StringIO("\n")
1983
client_medium = medium.SmartSimplePipesClientMedium(
1984
input, output, 'base')
1985
request = client_medium.get_request()
1986
smart_protocol = self.client_protocol_class(request)
1987
smart_protocol.call_with_body_readv_array(('foo', ), [(1,2),(5,6)])
1988
self.assertEqual(expected_bytes, output.getvalue())
1990
def test_client_read_body_bytes_all(self):
1991
# read_body_bytes should decode the body bytes from the wire into
1993
expected_bytes = "1234567"
1994
server_bytes = (self.response_marker +
1995
"success\nok\n7\n1234567done\n")
1996
input = StringIO(server_bytes)
1998
client_medium = medium.SmartSimplePipesClientMedium(
1999
input, output, 'base')
2000
request = client_medium.get_request()
2001
smart_protocol = self.client_protocol_class(request)
2002
smart_protocol.call('foo')
2003
smart_protocol.read_response_tuple(True)
2004
self.assertEqual(expected_bytes, smart_protocol.read_body_bytes())
2006
def test_client_read_body_bytes_incremental(self):
2007
# test reading a few bytes at a time from the body
2008
# XXX: possibly we should test dribbling the bytes into the stringio
2009
# to make the state machine work harder: however, as we use the
2010
# LengthPrefixedBodyDecoder that is already well tested - we can skip
2012
expected_bytes = "1234567"
2013
server_bytes = self.response_marker + "success\nok\n7\n1234567done\n"
2014
input = StringIO(server_bytes)
2016
client_medium = medium.SmartSimplePipesClientMedium(
2017
input, output, 'base')
2018
request = client_medium.get_request()
2019
smart_protocol = self.client_protocol_class(request)
2020
smart_protocol.call('foo')
2021
smart_protocol.read_response_tuple(True)
2022
self.assertEqual(expected_bytes[0:2], smart_protocol.read_body_bytes(2))
2023
self.assertEqual(expected_bytes[2:4], smart_protocol.read_body_bytes(2))
2024
self.assertEqual(expected_bytes[4:6], smart_protocol.read_body_bytes(2))
2025
self.assertEqual(expected_bytes[6], smart_protocol.read_body_bytes())
2027
def test_client_cancel_read_body_does_not_eat_body_bytes(self):
2028
# cancelling the expected body needs to finish the request, but not
2029
# read any more bytes.
2030
server_bytes = self.response_marker + "success\nok\n7\n1234567done\n"
2031
input = StringIO(server_bytes)
2033
client_medium = medium.SmartSimplePipesClientMedium(
2034
input, output, 'base')
2035
request = client_medium.get_request()
2036
smart_protocol = self.client_protocol_class(request)
2037
smart_protocol.call('foo')
2038
smart_protocol.read_response_tuple(True)
2039
smart_protocol.cancel_read_body()
2040
self.assertEqual(len(self.response_marker + 'success\nok\n'),
2043
errors.ReadingCompleted, smart_protocol.read_body_bytes)
2045
def test_client_read_body_bytes_interrupted_connection(self):
2046
server_bytes = (self.response_marker +
2047
"success\nok\n999\nincomplete body")
2048
input = StringIO(server_bytes)
2050
client_medium = medium.SmartSimplePipesClientMedium(
2051
input, output, 'base')
2052
request = client_medium.get_request()
2053
smart_protocol = self.client_protocol_class(request)
2054
smart_protocol.call('foo')
2055
smart_protocol.read_response_tuple(True)
2057
errors.ConnectionReset, smart_protocol.read_body_bytes)
2060
class TestSmartProtocolTwoSpecificsMixin(object):
2062
def assertBodyStreamSerialisation(self, expected_serialisation,
2064
"""Assert that body_stream is serialised as expected_serialisation."""
2065
out_stream = StringIO()
2066
protocol._send_stream(body_stream, out_stream.write)
2067
self.assertEqual(expected_serialisation, out_stream.getvalue())
2069
def assertBodyStreamRoundTrips(self, body_stream):
2070
"""Assert that body_stream is the same after being serialised and
2073
out_stream = StringIO()
2074
protocol._send_stream(body_stream, out_stream.write)
2075
decoder = protocol.ChunkedBodyDecoder()
2076
decoder.accept_bytes(out_stream.getvalue())
2077
decoded_stream = list(iter(decoder.read_next_chunk, None))
2078
self.assertEqual(body_stream, decoded_stream)
2080
def test_body_stream_serialisation_empty(self):
2081
"""A body_stream with no bytes can be serialised."""
2082
self.assertBodyStreamSerialisation('chunked\nEND\n', [])
2083
self.assertBodyStreamRoundTrips([])
2085
def test_body_stream_serialisation(self):
2086
stream = ['chunk one', 'chunk two', 'chunk three']
2087
self.assertBodyStreamSerialisation(
2088
'chunked\n' + '9\nchunk one' + '9\nchunk two' + 'b\nchunk three' +
2091
self.assertBodyStreamRoundTrips(stream)
2093
def test_body_stream_with_empty_element_serialisation(self):
2094
"""A body stream can include ''.
2096
The empty string can be transmitted like any other string.
2098
stream = ['', 'chunk']
2099
self.assertBodyStreamSerialisation(
2100
'chunked\n' + '0\n' + '5\nchunk' + 'END\n', stream)
2101
self.assertBodyStreamRoundTrips(stream)
2103
def test_body_stream_error_serialistion(self):
2104
stream = ['first chunk',
2105
_mod_request.FailedSmartServerResponse(
2106
('FailureName', 'failure arg'))]
2108
'chunked\n' + 'b\nfirst chunk' +
2109
'ERR\n' + 'b\nFailureName' + 'b\nfailure arg' +
2111
self.assertBodyStreamSerialisation(expected_bytes, stream)
2112
self.assertBodyStreamRoundTrips(stream)
2114
def test__send_response_includes_failure_marker(self):
2115
"""FailedSmartServerResponse have 'failed\n' after the version."""
2116
out_stream = StringIO()
2117
smart_protocol = protocol.SmartServerRequestProtocolTwo(
2118
None, out_stream.write)
2119
smart_protocol._send_response(
2120
_mod_request.FailedSmartServerResponse(('x',)))
2121
self.assertEqual(protocol.RESPONSE_VERSION_TWO + 'failed\nx\n',
2122
out_stream.getvalue())
2124
def test__send_response_includes_success_marker(self):
2125
"""SuccessfulSmartServerResponse have 'success\n' after the version."""
2126
out_stream = StringIO()
2127
smart_protocol = protocol.SmartServerRequestProtocolTwo(
2128
None, out_stream.write)
2129
smart_protocol._send_response(
2130
_mod_request.SuccessfulSmartServerResponse(('x',)))
2131
self.assertEqual(protocol.RESPONSE_VERSION_TWO + 'success\nx\n',
2132
out_stream.getvalue())
2134
def test__send_response_with_body_stream_sets_finished_reading(self):
2135
smart_protocol = protocol.SmartServerRequestProtocolTwo(
2136
None, lambda x: None)
2137
self.assertEqual(1, smart_protocol.next_read_size())
2138
smart_protocol._send_response(
2139
_mod_request.SuccessfulSmartServerResponse(('x',), body_stream=[]))
2140
self.assertEqual(0, smart_protocol.next_read_size())
2142
def test_streamed_body_bytes(self):
2143
body_header = 'chunked\n'
2144
two_body_chunks = "4\n1234" + "3\n567"
2145
body_terminator = "END\n"
2146
server_bytes = (protocol.RESPONSE_VERSION_TWO +
2147
"success\nok\n" + body_header + two_body_chunks +
2149
input = StringIO(server_bytes)
2151
client_medium = medium.SmartSimplePipesClientMedium(
2152
input, output, 'base')
2153
request = client_medium.get_request()
2154
smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
2155
smart_protocol.call('foo')
2156
smart_protocol.read_response_tuple(True)
2157
stream = smart_protocol.read_streamed_body()
2158
self.assertEqual(['1234', '567'], list(stream))
2160
def test_read_streamed_body_error(self):
2161
"""When a stream is interrupted by an error..."""
2162
body_header = 'chunked\n'
2163
a_body_chunk = '4\naaaa'
2164
err_signal = 'ERR\n'
2165
err_chunks = 'a\nerror arg1' + '4\narg2'
2167
body = body_header + a_body_chunk + err_signal + err_chunks + finish
2168
server_bytes = (protocol.RESPONSE_VERSION_TWO +
2169
"success\nok\n" + body)
2170
input = StringIO(server_bytes)
2172
client_medium = medium.SmartSimplePipesClientMedium(
2173
input, output, 'base')
2174
smart_request = client_medium.get_request()
2175
smart_protocol = protocol.SmartClientRequestProtocolTwo(smart_request)
2176
smart_protocol.call('foo')
2177
smart_protocol.read_response_tuple(True)
2180
_mod_request.FailedSmartServerResponse(('error arg1', 'arg2'))]
2181
stream = smart_protocol.read_streamed_body()
2182
self.assertEqual(expected_chunks, list(stream))
2184
def test_streamed_body_bytes_interrupted_connection(self):
2185
body_header = 'chunked\n'
2186
incomplete_body_chunk = "9999\nincomplete chunk"
2187
server_bytes = (protocol.RESPONSE_VERSION_TWO +
2188
"success\nok\n" + body_header + incomplete_body_chunk)
2189
input = StringIO(server_bytes)
2191
client_medium = medium.SmartSimplePipesClientMedium(
2192
input, output, 'base')
2193
request = client_medium.get_request()
2194
smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
2195
smart_protocol.call('foo')
2196
smart_protocol.read_response_tuple(True)
2197
stream = smart_protocol.read_streamed_body()
2198
self.assertRaises(errors.ConnectionReset, stream.next)
2200
def test_client_read_response_tuple_sets_response_status(self):
2201
server_bytes = protocol.RESPONSE_VERSION_TWO + "success\nok\n"
2202
input = StringIO(server_bytes)
2204
client_medium = medium.SmartSimplePipesClientMedium(
2205
input, output, 'base')
2206
request = client_medium.get_request()
2207
smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
2208
smart_protocol.call('foo')
2209
smart_protocol.read_response_tuple(False)
2210
self.assertEqual(True, smart_protocol.response_status)
2212
def test_client_read_response_tuple_raises_UnknownSmartMethod(self):
2213
"""read_response_tuple raises UnknownSmartMethod if the response says
2214
the server did not recognise the request.
2217
protocol.RESPONSE_VERSION_TWO +
2219
"error\x01Generic bzr smart protocol error: bad request 'foo'\n")
2220
input = StringIO(server_bytes)
2222
client_medium = medium.SmartSimplePipesClientMedium(
2223
input, output, 'base')
2224
request = client_medium.get_request()
2225
smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
2226
smart_protocol.call('foo')
2228
errors.UnknownSmartMethod, smart_protocol.read_response_tuple)
2229
# The request has been finished. There is no body to read, and
2230
# attempts to read one will fail.
2232
errors.ReadingCompleted, smart_protocol.read_body_bytes)
2235
class TestSmartProtocolTwoSpecifics(
2236
TestSmartProtocol, TestSmartProtocolTwoSpecificsMixin):
2237
"""Tests for aspects of smart protocol version two that are unique to
2240
Thus tests involving body streams and success/failure markers belong here.
2243
client_protocol_class = protocol.SmartClientRequestProtocolTwo
2244
server_protocol_class = protocol.SmartServerRequestProtocolTwo
2247
class TestVersionOneFeaturesInProtocolThree(
2248
TestSmartProtocol, CommonSmartProtocolTestMixin):
2249
"""Tests for version one smart protocol features as implemented by version
2253
request_encoder = protocol.ProtocolThreeRequester
2254
response_decoder = protocol.ProtocolThreeDecoder
2255
# build_server_protocol_three is a function, so we can't set it as a class
2256
# attribute directly, because then Python will assume it is actually a
2257
# method. So we make server_protocol_class be a static method, rather than
2259
# "server_protocol_class = protocol.build_server_protocol_three".
2260
server_protocol_class = staticmethod(protocol.build_server_protocol_three)
2263
super(TestVersionOneFeaturesInProtocolThree, self).setUp()
2264
self.response_marker = protocol.MESSAGE_VERSION_THREE
2265
self.request_marker = protocol.MESSAGE_VERSION_THREE
2267
def test_construct_version_three_server_protocol(self):
2268
smart_protocol = protocol.ProtocolThreeDecoder(None)
2269
self.assertEqual('', smart_protocol.unused_data)
2270
self.assertEqual('', smart_protocol._in_buffer)
2271
self.assertFalse(smart_protocol._has_dispatched)
2272
# The protocol starts by expecting four bytes, a length prefix for the
2274
self.assertEqual(4, smart_protocol.next_read_size())
2277
class NoOpRequest(_mod_request.SmartServerRequest):
2280
return _mod_request.SuccessfulSmartServerResponse(())
2282
dummy_registry = {'ARG': NoOpRequest}
2285
class LoggingMessageHandler(object):
2290
def _log(self, *args):
2291
self.event_log.append(args)
2293
def headers_received(self, headers):
2294
self._log('headers', headers)
2296
def protocol_error(self, exception):
2297
self._log('protocol_error', exception)
2299
def byte_part_received(self, byte):
2300
self._log('byte', byte)
2302
def bytes_part_received(self, bytes):
2303
self._log('bytes', bytes)
2305
def structure_part_received(self, structure):
2306
self._log('structure', structure)
2308
def end_received(self):
2312
class TestProtocolThree(TestSmartProtocol):
2313
"""Tests for v3 of the server-side protocol."""
2315
request_encoder = protocol.ProtocolThreeRequester
2316
response_decoder = protocol.ProtocolThreeDecoder
2317
server_protocol_class = protocol.ProtocolThreeDecoder
2319
def test_trivial_request(self):
2320
"""Smoke test for the simplest possible v3 request: empty headers, no
2324
headers = '\0\0\0\x02de' # length-prefixed, bencoded empty dict
2326
request_bytes = headers + end
2327
smart_protocol = self.server_protocol_class(LoggingMessageHandler())
2328
smart_protocol.accept_bytes(request_bytes)
2329
self.assertEqual(0, smart_protocol.next_read_size())
2330
self.assertEqual('', smart_protocol.unused_data)
2332
def make_protocol_expecting_message_part(self):
2333
headers = '\0\0\0\x02de' # length-prefixed, bencoded empty dict
2334
message_handler = LoggingMessageHandler()
2335
smart_protocol = self.server_protocol_class(message_handler)
2336
smart_protocol.accept_bytes(headers)
2337
# Clear the event log
2338
del message_handler.event_log[:]
2339
return smart_protocol, message_handler.event_log
2341
def test_decode_one_byte(self):
2342
"""The protocol can decode a 'one byte' message part."""
2343
smart_protocol, event_log = self.make_protocol_expecting_message_part()
2344
smart_protocol.accept_bytes('ox')
2345
self.assertEqual([('byte', 'x')], event_log)
2347
def test_decode_bytes(self):
2348
"""The protocol can decode a 'bytes' message part."""
2349
smart_protocol, event_log = self.make_protocol_expecting_message_part()
2350
smart_protocol.accept_bytes(
2351
'b' # message part kind
2352
'\0\0\0\x07' # length prefix
2355
self.assertEqual([('bytes', 'payload')], event_log)
2357
def test_decode_structure(self):
2358
"""The protocol can decode a 'structure' message part."""
2359
smart_protocol, event_log = self.make_protocol_expecting_message_part()
2360
smart_protocol.accept_bytes(
2361
's' # message part kind
2362
'\0\0\0\x07' # length prefix
2365
self.assertEqual([('structure', ['ARG'])], event_log)
2367
def test_decode_multiple_bytes(self):
2368
"""The protocol can decode a multiple 'bytes' message parts."""
2369
smart_protocol, event_log = self.make_protocol_expecting_message_part()
2370
smart_protocol.accept_bytes(
2371
'b' # message part kind
2372
'\0\0\0\x05' # length prefix
2374
'b' # message part kind
2379
[('bytes', 'first'), ('bytes', 'second')], event_log)
2382
class TestConventionalResponseHandler(tests.TestCase):
2384
def make_response_handler(self, response_bytes):
2385
from bzrlib.smart.message import ConventionalResponseHandler
2386
response_handler = ConventionalResponseHandler()
2387
protocol_decoder = protocol.ProtocolThreeDecoder(response_handler)
2388
# put decoder in desired state (waiting for message parts)
2389
protocol_decoder.state_accept = protocol_decoder._state_accept_expecting_message_part
2391
client_medium = medium.SmartSimplePipesClientMedium(
2392
StringIO(response_bytes), output, 'base')
2393
medium_request = client_medium.get_request()
2394
medium_request.finished_writing()
2395
response_handler.setProtoAndMediumRequest(
2396
protocol_decoder, medium_request)
2397
return response_handler
2399
def test_body_stream_interrupted_by_error(self):
2400
interrupted_body_stream = (
2401
'oS' # successful response
2402
's\0\0\0\x02le' # empty args
2403
'b\0\0\0\x09chunk one' # first chunk
2404
'b\0\0\0\x09chunk two' # second chunk
2406
's\0\0\0\x0el5:error3:abce' # bencoded error
2409
response_handler = self.make_response_handler(interrupted_body_stream)
2410
stream = response_handler.read_streamed_body()
2411
self.assertEqual('chunk one', stream.next())
2412
self.assertEqual('chunk two', stream.next())
2413
exc = self.assertRaises(errors.ErrorFromSmartServer, stream.next)
2414
self.assertEqual(('error', 'abc'), exc.error_tuple)
2416
def test_body_stream_interrupted_by_connection_lost(self):
2417
interrupted_body_stream = (
2418
'oS' # successful response
2419
's\0\0\0\x02le' # empty args
2420
'b\0\0\xff\xffincomplete chunk')
2421
response_handler = self.make_response_handler(interrupted_body_stream)
2422
stream = response_handler.read_streamed_body()
2423
self.assertRaises(errors.ConnectionReset, stream.next)
2425
def test_read_body_bytes_interrupted_by_connection_lost(self):
2426
interrupted_body_stream = (
2427
'oS' # successful response
2428
's\0\0\0\x02le' # empty args
2429
'b\0\0\xff\xffincomplete chunk')
2430
response_handler = self.make_response_handler(interrupted_body_stream)
2432
errors.ConnectionReset, response_handler.read_body_bytes)
2435
class TestMessageHandlerErrors(tests.TestCase):
2436
"""Tests for v3 that unrecognised (but well-formed) requests/responses are
2437
still fully read off the wire, so that subsequent requests/responses on the
2438
same medium can be decoded.
2441
def test_non_conventional_request(self):
2442
"""ConventionalRequestHandler (the default message handler on the
2443
server side) will reject an unconventional message, but still consume
2444
all the bytes of that message and signal when it has done so.
2446
This is what allows a server to continue to accept requests after the
2447
client sends a completely unrecognised request.
2449
# Define an invalid request (but one that is a well-formed message).
2450
# This particular invalid request not only lacks the mandatory
2451
# verb+args tuple, it has a single-byte part, which is forbidden. In
2452
# fact it has that part twice, to trigger multiple errors.
2454
protocol.MESSAGE_VERSION_THREE + # protocol version marker
2455
'\0\0\0\x02de' + # empty headers
2456
'oX' + # a single byte part: 'X'. ConventionalRequestHandler will
2457
# error at this part.
2459
'e' # end of message
2462
to_server = StringIO(invalid_request)
2463
from_server = StringIO()
2464
transport = memory.MemoryTransport('memory:///')
2465
server = medium.SmartServerPipeStreamMedium(
2466
to_server, from_server, transport)
2467
proto = server._build_protocol()
2468
message_handler = proto.message_handler
2469
server._serve_one_request(proto)
2470
# All the bytes have been read from the medium...
2471
self.assertEqual('', to_server.read())
2472
# ...and the protocol decoder has consumed all the bytes, and has
2474
self.assertEqual('', proto.unused_data)
2475
self.assertEqual(0, proto.next_read_size())
2478
class InstrumentedRequestHandler(object):
2479
"""Test Double of SmartServerRequestHandler."""
2484
def body_chunk_received(self, chunk_bytes):
2485
self.calls.append(('body_chunk_received', chunk_bytes))
2487
def no_body_received(self):
2488
self.calls.append(('no_body_received',))
2490
def prefixed_body_received(self, body_bytes):
2491
self.calls.append(('prefixed_body_received', body_bytes))
2493
def end_received(self):
2494
self.calls.append(('end_received',))
2497
class StubRequest(object):
2499
def finished_reading(self):
2503
class TestClientDecodingProtocolThree(TestSmartProtocol):
2504
"""Tests for v3 of the client-side protocol decoding."""
2506
def make_logging_response_decoder(self):
2507
"""Make v3 response decoder using a test response handler."""
2508
response_handler = LoggingMessageHandler()
2509
decoder = protocol.ProtocolThreeDecoder(response_handler)
2510
return decoder, response_handler
2512
def make_conventional_response_decoder(self):
2513
"""Make v3 response decoder using a conventional response handler."""
2514
response_handler = message.ConventionalResponseHandler()
2515
decoder = protocol.ProtocolThreeDecoder(response_handler)
2516
response_handler.setProtoAndMediumRequest(decoder, StubRequest())
2517
return decoder, response_handler
2519
def test_trivial_response_decoding(self):
2520
"""Smoke test for the simplest possible v3 response: empty headers,
2521
status byte, empty args, no body.
2523
headers = '\0\0\0\x02de' # length-prefixed, bencoded empty dict
2524
response_status = 'oS' # success
2525
args = 's\0\0\0\x02le' # length-prefixed, bencoded empty list
2526
end = 'e' # end marker
2527
message_bytes = headers + response_status + args + end
2528
decoder, response_handler = self.make_logging_response_decoder()
2529
decoder.accept_bytes(message_bytes)
2530
# The protocol decoder has finished, and consumed all bytes
2531
self.assertEqual(0, decoder.next_read_size())
2532
self.assertEqual('', decoder.unused_data)
2533
# The message handler has been invoked with all the parts of the
2534
# trivial response: empty headers, status byte, no args, end.
2536
[('headers', {}), ('byte', 'S'), ('structure', []), ('end',)],
2537
response_handler.event_log)
2539
def test_incomplete_message(self):
2540
"""A decoder will keep signalling that it needs more bytes via
2541
next_read_size() != 0 until it has seen a complete message, regardless
2542
which state it is in.
2544
# Define a simple response that uses all possible message parts.
2545
headers = '\0\0\0\x02de' # length-prefixed, bencoded empty dict
2546
response_status = 'oS' # success
2547
args = 's\0\0\0\x02le' # length-prefixed, bencoded empty list
2548
body = 'b\0\0\0\x04BODY' # a body: 'BODY'
2549
end = 'e' # end marker
2550
simple_response = headers + response_status + args + body + end
2551
# Feed the request to the decoder one byte at a time.
2552
decoder, response_handler = self.make_logging_response_decoder()
2553
for byte in simple_response:
2554
self.assertNotEqual(0, decoder.next_read_size())
2555
decoder.accept_bytes(byte)
2556
# Now the response is complete
2557
self.assertEqual(0, decoder.next_read_size())
2559
def test_read_response_tuple_raises_UnknownSmartMethod(self):
2560
"""read_response_tuple raises UnknownSmartMethod if the server replied
2561
with 'UnknownMethod'.
2563
headers = '\0\0\0\x02de' # length-prefixed, bencoded empty dict
2564
response_status = 'oE' # error flag
2565
# args: ('UnknownMethod', 'method-name')
2566
args = 's\0\0\0\x20l13:UnknownMethod11:method-namee'
2567
end = 'e' # end marker
2568
message_bytes = headers + response_status + args + end
2569
decoder, response_handler = self.make_conventional_response_decoder()
2570
decoder.accept_bytes(message_bytes)
2571
error = self.assertRaises(
2572
errors.UnknownSmartMethod, response_handler.read_response_tuple)
2573
self.assertEqual('method-name', error.verb)
2575
def test_read_response_tuple_error(self):
2576
"""If the response has an error, it is raised as an exception."""
2577
headers = '\0\0\0\x02de' # length-prefixed, bencoded empty dict
2578
response_status = 'oE' # error
2579
args = 's\0\0\0\x1al9:first arg10:second arge' # two args
2580
end = 'e' # end marker
2581
message_bytes = headers + response_status + args + end
2582
decoder, response_handler = self.make_conventional_response_decoder()
2583
decoder.accept_bytes(message_bytes)
2584
error = self.assertRaises(
2585
errors.ErrorFromSmartServer, response_handler.read_response_tuple)
2586
self.assertEqual(('first arg', 'second arg'), error.error_tuple)
2589
class TestClientEncodingProtocolThree(TestSmartProtocol):
2591
request_encoder = protocol.ProtocolThreeRequester
2592
response_decoder = protocol.ProtocolThreeDecoder
2593
server_protocol_class = protocol.ProtocolThreeDecoder
2595
def make_client_encoder_and_output(self):
2596
result = self.make_client_protocol_and_output()
2597
requester, response_handler, output = result
2598
return requester, output
2600
def test_call_smoke_test(self):
2601
"""A smoke test for ProtocolThreeRequester.call.
2603
This test checks that a particular simple invocation of call emits the
2604
correct bytes for that invocation.
2606
requester, output = self.make_client_encoder_and_output()
2607
requester.set_headers({'header name': 'header value'})
2608
requester.call('one arg')
2610
'bzr message 3 (bzr 1.6)\n' # protocol version
2611
'\x00\x00\x00\x1fd11:header name12:header valuee' # headers
2612
's\x00\x00\x00\x0bl7:one arge' # args
2616
def test_call_with_body_bytes_smoke_test(self):
2617
"""A smoke test for ProtocolThreeRequester.call_with_body_bytes.
2619
This test checks that a particular simple invocation of
2620
call_with_body_bytes emits the correct bytes for that invocation.
2622
requester, output = self.make_client_encoder_and_output()
2623
requester.set_headers({'header name': 'header value'})
2624
requester.call_with_body_bytes(('one arg',), 'body bytes')
2626
'bzr message 3 (bzr 1.6)\n' # protocol version
2627
'\x00\x00\x00\x1fd11:header name12:header valuee' # headers
2628
's\x00\x00\x00\x0bl7:one arge' # args
2629
'b' # there is a prefixed body
2630
'\x00\x00\x00\nbody bytes' # the prefixed body
2634
def test_call_writes_just_once(self):
2635
"""A bodyless request is written to the medium all at once."""
2636
medium_request = StubMediumRequest()
2637
encoder = protocol.ProtocolThreeRequester(medium_request)
2638
encoder.call('arg1', 'arg2', 'arg3')
2640
['accept_bytes', 'finished_writing'], medium_request.calls)
2642
def test_call_with_body_bytes_writes_just_once(self):
2643
"""A request with body bytes is written to the medium all at once."""
2644
medium_request = StubMediumRequest()
2645
encoder = protocol.ProtocolThreeRequester(medium_request)
2646
encoder.call_with_body_bytes(('arg', 'arg'), 'body bytes')
2648
['accept_bytes', 'finished_writing'], medium_request.calls)
2651
class StubMediumRequest(object):
2652
"""A stub medium request that tracks the number of times accept_bytes is
2658
self._medium = 'dummy medium'
2660
def accept_bytes(self, bytes):
2661
self.calls.append('accept_bytes')
2663
def finished_writing(self):
2664
self.calls.append('finished_writing')
2667
class TestResponseEncodingProtocolThree(tests.TestCase):
2669
def make_response_encoder(self):
2670
out_stream = StringIO()
2671
response_encoder = protocol.ProtocolThreeResponder(out_stream.write)
2672
return response_encoder, out_stream
2674
def test_send_error_unknown_method(self):
2675
encoder, out_stream = self.make_response_encoder()
2676
encoder.send_error(errors.UnknownSmartMethod('method name'))
2677
# Use assertEndsWith so that we don't compare the header, which varies
2678
# by bzrlib.__version__.
2679
self.assertEndsWith(
2680
out_stream.getvalue(),
2683
# tuple: 'UnknownMethod', 'method name'
2684
's\x00\x00\x00\x20l13:UnknownMethod11:method namee'
2689
class TestResponseEncoderBufferingProtocolThree(tests.TestCase):
2690
"""Tests for buffering of responses.
2692
We want to avoid doing many small writes when one would do, to avoid
2693
unnecessary network overhead.
2698
self.responder = protocol.ProtocolThreeResponder(self.writes.append)
2700
def assertWriteCount(self, expected_count):
2702
expected_count, len(self.writes),
2703
"Too many writes: %r" % (self.writes,))
2705
def test_send_error_writes_just_once(self):
2706
"""An error response is written to the medium all at once."""
2707
self.responder.send_error(Exception('An exception string.'))
2708
self.assertWriteCount(1)
2710
def test_send_response_writes_just_once(self):
2711
"""A normal response with no body is written to the medium all at once.
2713
response = _mod_request.SuccessfulSmartServerResponse(('arg', 'arg'))
2714
self.responder.send_response(response)
2715
self.assertWriteCount(1)
2717
def test_send_response_with_body_writes_just_once(self):
2718
"""A normal response with a monolithic body is written to the medium
2721
response = _mod_request.SuccessfulSmartServerResponse(
2722
('arg', 'arg'), body='body bytes')
2723
self.responder.send_response(response)
2724
self.assertWriteCount(1)
2726
def test_send_response_with_body_stream_writes_once_per_chunk(self):
2727
"""A normal response with a stream body is written to the medium
2728
writes to the medium once per chunk.
2730
# Construct a response with stream with 2 chunks in it.
2731
response = _mod_request.SuccessfulSmartServerResponse(
2732
('arg', 'arg'), body_stream=['chunk1', 'chunk2'])
2733
self.responder.send_response(response)
2734
# We will write 3 times: exactly once for each chunk, plus a final
2735
# write to end the response.
2736
self.assertWriteCount(3)
2739
class TestSmartClientUnicode(tests.TestCase):
2740
"""_SmartClient tests for unicode arguments.
2742
Unicode arguments to call_with_body_bytes are not correct (remote method
2743
names, arguments, and bodies must all be expressed as byte strings), but
2744
_SmartClient should gracefully reject them, rather than getting into a
2745
broken state that prevents future correct calls from working. That is, it
2746
should be possible to issue more requests on the medium afterwards, rather
2747
than allowing one bad call to call_with_body_bytes to cause later calls to
2748
mysteriously fail with TooManyConcurrentRequests.
2751
def assertCallDoesNotBreakMedium(self, method, args, body):
2752
"""Call a medium with the given method, args and body, then assert that
2753
the medium is left in a sane state, i.e. is capable of allowing further
2756
input = StringIO("\n")
2758
client_medium = medium.SmartSimplePipesClientMedium(
2759
input, output, 'ignored base')
2760
smart_client = client._SmartClient(client_medium)
2761
self.assertRaises(TypeError,
2762
smart_client.call_with_body_bytes, method, args, body)
2763
self.assertEqual("", output.getvalue())
2764
self.assertEqual(None, client_medium._current_request)
2766
def test_call_with_body_bytes_unicode_method(self):
2767
self.assertCallDoesNotBreakMedium(u'method', ('args',), 'body')
2769
def test_call_with_body_bytes_unicode_args(self):
2770
self.assertCallDoesNotBreakMedium('method', (u'args',), 'body')
2771
self.assertCallDoesNotBreakMedium('method', ('arg1', u'arg2'), 'body')
2773
def test_call_with_body_bytes_unicode_body(self):
2774
self.assertCallDoesNotBreakMedium('method', ('args',), u'body')
2777
class MockMedium(medium.SmartClientMedium):
2778
"""A mock medium that can be used to test _SmartClient.
2780
It can be given a series of requests to expect (and responses it should
2781
return for them). It can also be told when the client is expected to
2782
disconnect a medium. Expectations must be satisfied in the order they are
2783
given, or else an AssertionError will be raised.
2785
Typical use looks like::
2787
medium = MockMedium()
2788
medium.expect_request(...)
2789
medium.expect_request(...)
2790
medium.expect_request(...)
2794
super(MockMedium, self).__init__('dummy base')
2795
self._mock_request = _MockMediumRequest(self)
2796
self._expected_events = []
2798
def expect_request(self, request_bytes, response_bytes,
2799
allow_partial_read=False):
2800
"""Expect 'request_bytes' to be sent, and reply with 'response_bytes'.
2802
No assumption is made about how many times accept_bytes should be
2803
called to send the request. Similarly, no assumption is made about how
2804
many times read_bytes/read_line are called by protocol code to read a
2807
request.accept_bytes('ab')
2808
request.accept_bytes('cd')
2809
request.finished_writing()
2813
request.accept_bytes('abcd')
2814
request.finished_writing()
2816
Will both satisfy ``medium.expect_request('abcd', ...)``. Thus tests
2817
using this should not break due to irrelevant changes in protocol
2820
:param allow_partial_read: if True, no assertion is raised if a
2821
response is not fully read. Setting this is useful when the client
2822
is expected to disconnect without needing to read the complete
2823
response. Default is False.
2825
self._expected_events.append(('send request', request_bytes))
2826
if allow_partial_read:
2827
self._expected_events.append(
2828
('read response (partial)', response_bytes))
2830
self._expected_events.append(('read response', response_bytes))
2832
def expect_disconnect(self):
2833
"""Expect the client to call ``medium.disconnect()``."""
2834
self._expected_events.append('disconnect')
2836
def _assertEvent(self, observed_event):
2837
"""Raise AssertionError unless observed_event matches the next expected
2840
:seealso: expect_request
2841
:seealso: expect_disconnect
2844
expected_event = self._expected_events.pop(0)
2846
raise AssertionError(
2847
'Mock medium observed event %r, but no more events expected'
2848
% (observed_event,))
2849
if expected_event[0] == 'read response (partial)':
2850
if observed_event[0] != 'read response':
2851
raise AssertionError(
2852
'Mock medium observed event %r, but expected event %r'
2853
% (observed_event, expected_event))
2854
elif observed_event != expected_event:
2855
raise AssertionError(
2856
'Mock medium observed event %r, but expected event %r'
2857
% (observed_event, expected_event))
2858
if self._expected_events:
2859
next_event = self._expected_events[0]
2860
if next_event[0].startswith('read response'):
2861
self._mock_request._response = next_event[1]
2863
def get_request(self):
2864
return self._mock_request
2866
def disconnect(self):
2867
if self._mock_request._read_bytes:
2868
self._assertEvent(('read response', self._mock_request._read_bytes))
2869
self._mock_request._read_bytes = ''
2870
self._assertEvent('disconnect')
2873
class _MockMediumRequest(object):
2874
"""A mock ClientMediumRequest used by MockMedium."""
2876
def __init__(self, mock_medium):
2877
self._medium = mock_medium
2878
self._written_bytes = ''
2879
self._read_bytes = ''
2880
self._response = None
2882
def accept_bytes(self, bytes):
2883
self._written_bytes += bytes
2885
def finished_writing(self):
2886
self._medium._assertEvent(('send request', self._written_bytes))
2887
self._written_bytes = ''
2889
def finished_reading(self):
2890
self._medium._assertEvent(('read response', self._read_bytes))
2891
self._read_bytes = ''
2893
def read_bytes(self, size):
2894
resp = self._response
2895
bytes, resp = resp[:size], resp[size:]
2896
self._response = resp
2897
self._read_bytes += bytes
2900
def read_line(self):
2901
resp = self._response
2903
line, resp = resp.split('\n', 1)
2906
line, resp = resp, ''
2907
self._response = resp
2908
self._read_bytes += line
2912
class Test_SmartClientVersionDetection(tests.TestCase):
2913
"""Tests for _SmartClient's automatic protocol version detection.
2915
On the first remote call, _SmartClient will keep retrying the request with
2916
different protocol versions until it finds one that works.
2919
def test_version_three_server(self):
2920
"""With a protocol 3 server, only one request is needed."""
2921
medium = MockMedium()
2922
smart_client = client._SmartClient(medium, headers={})
2923
message_start = protocol.MESSAGE_VERSION_THREE + '\x00\x00\x00\x02de'
2924
medium.expect_request(
2926
's\x00\x00\x00\x1el11:method-name5:arg 15:arg 2ee',
2927
message_start + 's\0\0\0\x13l14:response valueee')
2928
result = smart_client.call('method-name', 'arg 1', 'arg 2')
2929
# The call succeeded without raising any exceptions from the mock
2930
# medium, and the smart_client returns the response from the server.
2931
self.assertEqual(('response value',), result)
2932
self.assertEqual([], medium._expected_events)
2933
# Also, the v3 works then the server should be assumed to support RPCs
2934
# introduced in 1.6.
2935
self.assertFalse(medium._is_remote_before((1, 6)))
2937
def test_version_two_server(self):
2938
"""If the server only speaks protocol 2, the client will first try
2939
version 3, then fallback to protocol 2.
2941
Further, _SmartClient caches the detection, so future requests will all
2942
use protocol 2 immediately.
2944
medium = MockMedium()
2945
smart_client = client._SmartClient(medium, headers={})
2946
# First the client should send a v3 request, but the server will reply
2948
medium.expect_request(
2949
'bzr message 3 (bzr 1.6)\n\x00\x00\x00\x02de' +
2950
's\x00\x00\x00\x1el11:method-name5:arg 15:arg 2ee',
2951
'bzr response 2\nfailed\n\n')
2952
# So then the client should disconnect to reset the connection, because
2953
# the client needs to assume the server cannot read any further
2954
# requests off the original connection.
2955
medium.expect_disconnect()
2956
# The client should then retry the original request in v2
2957
medium.expect_request(
2958
'bzr request 2\nmethod-name\x01arg 1\x01arg 2\n',
2959
'bzr response 2\nsuccess\nresponse value\n')
2960
result = smart_client.call('method-name', 'arg 1', 'arg 2')
2961
# The smart_client object will return the result of the successful
2963
self.assertEqual(('response value',), result)
2965
# Now try another request, and this time the client will just use
2966
# protocol 2. (i.e. the autodetection won't be repeated)
2967
medium.expect_request(
2968
'bzr request 2\nanother-method\n',
2969
'bzr response 2\nsuccess\nanother response\n')
2970
result = smart_client.call('another-method')
2971
self.assertEqual(('another response',), result)
2972
self.assertEqual([], medium._expected_events)
2974
# Also, because v3 is not supported, the client medium should assume
2975
# that RPCs introduced in 1.6 aren't supported either.
2976
self.assertTrue(medium._is_remote_before((1, 6)))
2978
def test_unknown_version(self):
2979
"""If the server does not use any known (or at least supported)
2980
protocol version, a SmartProtocolError is raised.
2982
medium = MockMedium()
2983
smart_client = client._SmartClient(medium, headers={})
2984
unknown_protocol_bytes = 'Unknown protocol!'
2985
# The client will try v3 and v2 before eventually giving up.
2986
medium.expect_request(
2987
'bzr message 3 (bzr 1.6)\n\x00\x00\x00\x02de' +
2988
's\x00\x00\x00\x1el11:method-name5:arg 15:arg 2ee',
2989
unknown_protocol_bytes)
2990
medium.expect_disconnect()
2991
medium.expect_request(
2992
'bzr request 2\nmethod-name\x01arg 1\x01arg 2\n',
2993
unknown_protocol_bytes)
2994
medium.expect_disconnect()
2996
errors.SmartProtocolError,
2997
smart_client.call, 'method-name', 'arg 1', 'arg 2')
2998
self.assertEqual([], medium._expected_events)
3000
def test_first_response_is_error(self):
3001
"""If the server replies with an error, then the version detection
3004
This test is very similar to test_version_two_server, but catches a bug
3005
we had in the case where the first reply was an error response.
3007
medium = MockMedium()
3008
smart_client = client._SmartClient(medium, headers={})
3009
message_start = protocol.MESSAGE_VERSION_THREE + '\x00\x00\x00\x02de'
3010
# Issue a request that gets an error reply in a non-default protocol
3012
medium.expect_request(
3014
's\x00\x00\x00\x10l11:method-nameee',
3015
'bzr response 2\nfailed\n\n')
3016
medium.expect_disconnect()
3017
medium.expect_request(
3018
'bzr request 2\nmethod-name\n',
3019
'bzr response 2\nfailed\nFooBarError\n')
3020
err = self.assertRaises(
3021
errors.ErrorFromSmartServer,
3022
smart_client.call, 'method-name')
3023
self.assertEqual(('FooBarError',), err.error_tuple)
3024
# Now the medium should have remembered the protocol version, so
3025
# subsequent requests will use the remembered version immediately.
3026
medium.expect_request(
3027
'bzr request 2\nmethod-name\n',
3028
'bzr response 2\nsuccess\nresponse value\n')
3029
result = smart_client.call('method-name')
3030
self.assertEqual(('response value',), result)
3031
self.assertEqual([], medium._expected_events)
3034
class Test_SmartClient(tests.TestCase):
3036
def test_call_default_headers(self):
3037
"""ProtocolThreeRequester.call by default sends a 'Software
3040
smart_client = client._SmartClient('dummy medium')
3042
bzrlib.__version__, smart_client._headers['Software version'])
3043
# XXX: need a test that smart_client._headers is passed to the request
3047
class LengthPrefixedBodyDecoder(tests.TestCase):
3049
# XXX: TODO: make accept_reading_trailer invoke translate_response or
3050
# something similar to the ProtocolBase method.
3052
def test_construct(self):
3053
decoder = protocol.LengthPrefixedBodyDecoder()
3054
self.assertFalse(decoder.finished_reading)
3055
self.assertEqual(6, decoder.next_read_size())
3056
self.assertEqual('', decoder.read_pending_data())
3057
self.assertEqual('', decoder.unused_data)
3059
def test_accept_bytes(self):
3060
decoder = protocol.LengthPrefixedBodyDecoder()
3061
decoder.accept_bytes('')
3062
self.assertFalse(decoder.finished_reading)
3063
self.assertEqual(6, decoder.next_read_size())
3064
self.assertEqual('', decoder.read_pending_data())
3065
self.assertEqual('', decoder.unused_data)
3066
decoder.accept_bytes('7')
3067
self.assertFalse(decoder.finished_reading)
3068
self.assertEqual(6, decoder.next_read_size())
3069
self.assertEqual('', decoder.read_pending_data())
3070
self.assertEqual('', decoder.unused_data)
3071
decoder.accept_bytes('\na')
3072
self.assertFalse(decoder.finished_reading)
3073
self.assertEqual(11, decoder.next_read_size())
3074
self.assertEqual('a', decoder.read_pending_data())
3075
self.assertEqual('', decoder.unused_data)
3076
decoder.accept_bytes('bcdefgd')
3077
self.assertFalse(decoder.finished_reading)
3078
self.assertEqual(4, decoder.next_read_size())
3079
self.assertEqual('bcdefg', decoder.read_pending_data())
3080
self.assertEqual('', decoder.unused_data)
3081
decoder.accept_bytes('one')
3082
self.assertFalse(decoder.finished_reading)
3083
self.assertEqual(1, decoder.next_read_size())
3084
self.assertEqual('', decoder.read_pending_data())
3085
self.assertEqual('', decoder.unused_data)
3086
decoder.accept_bytes('\nblarg')
3087
self.assertTrue(decoder.finished_reading)
3088
self.assertEqual(1, decoder.next_read_size())
3089
self.assertEqual('', decoder.read_pending_data())
3090
self.assertEqual('blarg', decoder.unused_data)
3092
def test_accept_bytes_all_at_once_with_excess(self):
3093
decoder = protocol.LengthPrefixedBodyDecoder()
3094
decoder.accept_bytes('1\nadone\nunused')
3095
self.assertTrue(decoder.finished_reading)
3096
self.assertEqual(1, decoder.next_read_size())
3097
self.assertEqual('a', decoder.read_pending_data())
3098
self.assertEqual('unused', decoder.unused_data)
3100
def test_accept_bytes_exact_end_of_body(self):
3101
decoder = protocol.LengthPrefixedBodyDecoder()
3102
decoder.accept_bytes('1\na')
3103
self.assertFalse(decoder.finished_reading)
3104
self.assertEqual(5, decoder.next_read_size())
3105
self.assertEqual('a', decoder.read_pending_data())
3106
self.assertEqual('', decoder.unused_data)
3107
decoder.accept_bytes('done\n')
3108
self.assertTrue(decoder.finished_reading)
3109
self.assertEqual(1, decoder.next_read_size())
3110
self.assertEqual('', decoder.read_pending_data())
3111
self.assertEqual('', decoder.unused_data)
3114
class TestChunkedBodyDecoder(tests.TestCase):
3115
"""Tests for ChunkedBodyDecoder.
3117
This is the body decoder used for protocol version two.
3120
def test_construct(self):
3121
decoder = protocol.ChunkedBodyDecoder()
3122
self.assertFalse(decoder.finished_reading)
3123
self.assertEqual(8, decoder.next_read_size())
3124
self.assertEqual(None, decoder.read_next_chunk())
3125
self.assertEqual('', decoder.unused_data)
3127
def test_empty_content(self):
3128
"""'chunked\nEND\n' is the complete encoding of a zero-length body.
3130
decoder = protocol.ChunkedBodyDecoder()
3131
decoder.accept_bytes('chunked\n')
3132
decoder.accept_bytes('END\n')
3133
self.assertTrue(decoder.finished_reading)
3134
self.assertEqual(None, decoder.read_next_chunk())
3135
self.assertEqual('', decoder.unused_data)
3137
def test_one_chunk(self):
3138
"""A body in a single chunk is decoded correctly."""
3139
decoder = protocol.ChunkedBodyDecoder()
3140
decoder.accept_bytes('chunked\n')
3141
chunk_length = 'f\n'
3142
chunk_content = '123456789abcdef'
3144
decoder.accept_bytes(chunk_length + chunk_content + finish)
3145
self.assertTrue(decoder.finished_reading)
3146
self.assertEqual(chunk_content, decoder.read_next_chunk())
3147
self.assertEqual('', decoder.unused_data)
3149
def test_incomplete_chunk(self):
3150
"""When there are less bytes in the chunk than declared by the length,
3151
then we haven't finished reading yet.
3153
decoder = protocol.ChunkedBodyDecoder()
3154
decoder.accept_bytes('chunked\n')
3155
chunk_length = '8\n'
3157
decoder.accept_bytes(chunk_length + three_bytes)
3158
self.assertFalse(decoder.finished_reading)
3160
5 + 4, decoder.next_read_size(),
3161
"The next_read_size hint should be the number of missing bytes in "
3162
"this chunk plus 4 (the length of the end-of-body marker: "
3164
self.assertEqual(None, decoder.read_next_chunk())
3166
def test_incomplete_length(self):
3167
"""A chunk length hasn't been read until a newline byte has been read.
3169
decoder = protocol.ChunkedBodyDecoder()
3170
decoder.accept_bytes('chunked\n')
3171
decoder.accept_bytes('9')
3173
1, decoder.next_read_size(),
3174
"The next_read_size hint should be 1, because we don't know the "
3176
decoder.accept_bytes('\n')
3178
9 + 4, decoder.next_read_size(),
3179
"The next_read_size hint should be the length of the chunk plus 4 "
3180
"(the length of the end-of-body marker: 'END\\n')")
3181
self.assertFalse(decoder.finished_reading)
3182
self.assertEqual(None, decoder.read_next_chunk())
3184
def test_two_chunks(self):
3185
"""Content from multiple chunks is concatenated."""
3186
decoder = protocol.ChunkedBodyDecoder()
3187
decoder.accept_bytes('chunked\n')
3188
chunk_one = '3\naaa'
3189
chunk_two = '5\nbbbbb'
3191
decoder.accept_bytes(chunk_one + chunk_two + finish)
3192
self.assertTrue(decoder.finished_reading)
3193
self.assertEqual('aaa', decoder.read_next_chunk())
3194
self.assertEqual('bbbbb', decoder.read_next_chunk())
3195
self.assertEqual(None, decoder.read_next_chunk())
3196
self.assertEqual('', decoder.unused_data)
3198
def test_excess_bytes(self):
3199
"""Bytes after the chunked body are reported as unused bytes."""
3200
decoder = protocol.ChunkedBodyDecoder()
3201
decoder.accept_bytes('chunked\n')
3202
chunked_body = "5\naaaaaEND\n"
3203
excess_bytes = "excess bytes"
3204
decoder.accept_bytes(chunked_body + excess_bytes)
3205
self.assertTrue(decoder.finished_reading)
3206
self.assertEqual('aaaaa', decoder.read_next_chunk())
3207
self.assertEqual(excess_bytes, decoder.unused_data)
3209
1, decoder.next_read_size(),
3210
"next_read_size hint should be 1 when finished_reading.")
3212
def test_multidigit_length(self):
3213
"""Lengths in the chunk prefixes can have multiple digits."""
3214
decoder = protocol.ChunkedBodyDecoder()
3215
decoder.accept_bytes('chunked\n')
3217
chunk_prefix = hex(length) + '\n'
3218
chunk_bytes = 'z' * length
3220
decoder.accept_bytes(chunk_prefix + chunk_bytes + finish)
3221
self.assertTrue(decoder.finished_reading)
3222
self.assertEqual(chunk_bytes, decoder.read_next_chunk())
3224
def test_byte_at_a_time(self):
3225
"""A complete body fed to the decoder one byte at a time should not
3226
confuse the decoder. That is, it should give the same result as if the
3227
bytes had been received in one batch.
3229
This test is the same as test_one_chunk apart from the way accept_bytes
3232
decoder = protocol.ChunkedBodyDecoder()
3233
decoder.accept_bytes('chunked\n')
3234
chunk_length = 'f\n'
3235
chunk_content = '123456789abcdef'
3237
for byte in (chunk_length + chunk_content + finish):
3238
decoder.accept_bytes(byte)
3239
self.assertTrue(decoder.finished_reading)
3240
self.assertEqual(chunk_content, decoder.read_next_chunk())
3241
self.assertEqual('', decoder.unused_data)
3243
def test_read_pending_data_resets(self):
3244
"""read_pending_data does not return the same bytes twice."""
3245
decoder = protocol.ChunkedBodyDecoder()
3246
decoder.accept_bytes('chunked\n')
3247
chunk_one = '3\naaa'
3248
chunk_two = '3\nbbb'
3250
decoder.accept_bytes(chunk_one)
3251
self.assertEqual('aaa', decoder.read_next_chunk())
3252
decoder.accept_bytes(chunk_two)
3253
self.assertEqual('bbb', decoder.read_next_chunk())
3254
self.assertEqual(None, decoder.read_next_chunk())
3256
def test_decode_error(self):
3257
decoder = protocol.ChunkedBodyDecoder()
3258
decoder.accept_bytes('chunked\n')
3259
chunk_one = 'b\nfirst chunk'
3260
error_signal = 'ERR\n'
3261
error_chunks = '5\npart1' + '5\npart2'
3263
decoder.accept_bytes(chunk_one + error_signal + error_chunks + finish)
3264
self.assertTrue(decoder.finished_reading)
3265
self.assertEqual('first chunk', decoder.read_next_chunk())
3266
expected_failure = _mod_request.FailedSmartServerResponse(
3268
self.assertEqual(expected_failure, decoder.read_next_chunk())
3270
def test_bad_header(self):
3271
"""accept_bytes raises a SmartProtocolError if a chunked body does not
3272
start with the right header.
3274
decoder = protocol.ChunkedBodyDecoder()
3276
errors.SmartProtocolError, decoder.accept_bytes, 'bad header\n')
3279
class TestSuccessfulSmartServerResponse(tests.TestCase):
3281
def test_construct_no_body(self):
3282
response = _mod_request.SuccessfulSmartServerResponse(('foo', 'bar'))
3283
self.assertEqual(('foo', 'bar'), response.args)
3284
self.assertEqual(None, response.body)
3286
def test_construct_with_body(self):
3287
response = _mod_request.SuccessfulSmartServerResponse(('foo', 'bar'),
3289
self.assertEqual(('foo', 'bar'), response.args)
3290
self.assertEqual('bytes', response.body)
3291
# repr(response) doesn't trigger exceptions.
3294
def test_construct_with_body_stream(self):
3295
bytes_iterable = ['abc']
3296
response = _mod_request.SuccessfulSmartServerResponse(
3297
('foo', 'bar'), body_stream=bytes_iterable)
3298
self.assertEqual(('foo', 'bar'), response.args)
3299
self.assertEqual(bytes_iterable, response.body_stream)
3301
def test_construct_rejects_body_and_body_stream(self):
3302
"""'body' and 'body_stream' are mutually exclusive."""
3305
_mod_request.SuccessfulSmartServerResponse, (), 'body', ['stream'])
3307
def test_is_successful(self):
3308
"""is_successful should return True for SuccessfulSmartServerResponse."""
3309
response = _mod_request.SuccessfulSmartServerResponse(('error',))
3310
self.assertEqual(True, response.is_successful())
3313
class TestFailedSmartServerResponse(tests.TestCase):
3315
def test_construct(self):
3316
response = _mod_request.FailedSmartServerResponse(('foo', 'bar'))
3317
self.assertEqual(('foo', 'bar'), response.args)
3318
self.assertEqual(None, response.body)
3319
response = _mod_request.FailedSmartServerResponse(('foo', 'bar'), 'bytes')
3320
self.assertEqual(('foo', 'bar'), response.args)
3321
self.assertEqual('bytes', response.body)
3322
# repr(response) doesn't trigger exceptions.
3325
def test_is_successful(self):
3326
"""is_successful should return False for FailedSmartServerResponse."""
3327
response = _mod_request.FailedSmartServerResponse(('error',))
3328
self.assertEqual(False, response.is_successful())
3331
class FakeHTTPMedium(object):
3333
self.written_request = None
3334
self._current_request = None
3335
def send_http_smart_request(self, bytes):
3336
self.written_request = bytes
3340
class HTTPTunnellingSmokeTest(tests.TestCase):
3343
super(HTTPTunnellingSmokeTest, self).setUp()
3344
# We use the VFS layer as part of HTTP tunnelling tests.
3345
self._captureVar('BZR_NO_SMART_VFS', None)
3347
def test_smart_http_medium_request_accept_bytes(self):
3348
medium = FakeHTTPMedium()
3349
request = SmartClientHTTPMediumRequest(medium)
3350
request.accept_bytes('abc')
3351
request.accept_bytes('def')
3352
self.assertEqual(None, medium.written_request)
3353
request.finished_writing()
3354
self.assertEqual('abcdef', medium.written_request)
3357
class RemoteHTTPTransportTestCase(tests.TestCase):
3359
def test_remote_path_after_clone_child(self):
3360
# If a user enters "bzr+http://host/foo", we want to sent all smart
3361
# requests for child URLs of that to the original URL. i.e., we want to
3362
# POST to "bzr+http://host/foo/.bzr/smart" and never something like
3363
# "bzr+http://host/foo/.bzr/branch/.bzr/smart". So, a cloned
3364
# RemoteHTTPTransport remembers the initial URL, and adjusts the relpaths
3365
# it sends in smart requests accordingly.
3366
base_transport = remote.RemoteHTTPTransport('bzr+http://host/path')
3367
new_transport = base_transport.clone('child_dir')
3368
self.assertEqual(base_transport._http_transport,
3369
new_transport._http_transport)
3370
self.assertEqual('child_dir/foo', new_transport._remote_path('foo'))
3373
new_transport._client.remote_path_from_transport(new_transport))
3375
def test_remote_path_unnormal_base(self):
3376
# If the transport's base isn't normalised, the _remote_path should
3377
# still be calculated correctly.
3378
base_transport = remote.RemoteHTTPTransport('bzr+http://host/%7Ea/b')
3379
self.assertEqual('c', base_transport._remote_path('c'))
3381
def test_clone_unnormal_base(self):
3382
# If the transport's base isn't normalised, cloned transports should
3383
# still work correctly.
3384
base_transport = remote.RemoteHTTPTransport('bzr+http://host/%7Ea/b')
3385
new_transport = base_transport.clone('c')
3386
self.assertEqual('bzr+http://host/%7Ea/b/c/', new_transport.base)
3389
new_transport._client.remote_path_from_transport(new_transport))
3392
# TODO: Client feature that does get_bundle and then installs that into a
3393
# branch; this can be used in place of the regular pull/fetch operation when
3394
# coming from a smart server.
3396
# TODO: Eventually, want to do a 'branch' command by fetching the whole
3397
# history as one big bundle. How?
3399
# The branch command does 'br_from.sprout', which tries to preserve the same
3400
# format. We don't necessarily even want that.
3402
# It might be simpler to handle cmd_pull first, which does a simpler fetch()
3403
# operation from one branch into another. It already has some code for
3404
# pulling from a bundle, which it does by trying to see if the destination is
3405
# a bundle file. So it seems the logic for pull ought to be:
3407
# - if it's a smart server, get a bundle from there and install that
3408
# - if it's a bundle, install that
3409
# - if it's a branch, pull from there
3411
# Getting a bundle from a smart server is a bit different from reading a
3412
# bundle from a URL:
3414
# - we can reasonably remember the URL we last read from
3415
# - you can specify a revision number to pull, and we need to pass it across
3416
# to the server as a limit on what will be requested
3418
# TODO: Given a URL, determine whether it is a smart server or not (or perhaps
3419
# otherwise whether it's a bundle?) Should this be a property or method of
3420
# the transport? For the ssh protocol, we always know it's a smart server.
3421
# For http, we potentially need to probe. But if we're explicitly given
3422
# bzr+http:// then we can skip that for now.