/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_smart_transport.py

  • Committer: Aaron Bentley
  • Date: 2008-10-16 21:37:21 UTC
  • mfrom: (0.12.63 shelf-manager)
  • mto: This revision was merged to the branch mainline in revision 3823.
  • Revision ID: aaron@aaronbentley.com-20081016213721-4evccj16q9mb05uf
Merge with shelf-manager

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2006, 2007 Canonical Ltd
 
2
#
 
3
# This program is free software; you can redistribute it and/or modify
 
4
# it under the terms of the GNU General Public License as published by
 
5
# the Free Software Foundation; either version 2 of the License, or
 
6
# (at your option) any later version.
 
7
#
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU General Public License for more details.
 
12
#
 
13
# You should have received a copy of the GNU General Public License
 
14
# along with this program; if not, write to the Free Software
 
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
16
 
 
17
"""Tests for smart transport"""
 
18
 
 
19
# all of this deals with byte strings so this is safe
 
20
from cStringIO import StringIO
 
21
import os
 
22
import socket
 
23
import threading
 
24
 
 
25
import bzrlib
 
26
from bzrlib import (
 
27
        bzrdir,
 
28
        errors,
 
29
        osutils,
 
30
        tests,
 
31
        urlutils,
 
32
        )
 
33
from bzrlib.smart import (
 
34
        client,
 
35
        medium,
 
36
        message,
 
37
        protocol,
 
38
        request as _mod_request,
 
39
        server,
 
40
        vfs,
 
41
)
 
42
from bzrlib.tests.test_smart import TestCaseWithSmartMedium
 
43
from bzrlib.transport import (
 
44
        get_transport,
 
45
        local,
 
46
        memory,
 
47
        remote,
 
48
        )
 
49
from bzrlib.transport.http import SmartClientHTTPMediumRequest
 
50
 
 
51
 
 
52
class StringIOSSHVendor(object):
 
53
    """A SSH vendor that uses StringIO to buffer writes and answer reads."""
 
54
 
 
55
    def __init__(self, read_from, write_to):
 
56
        self.read_from = read_from
 
57
        self.write_to = write_to
 
58
        self.calls = []
 
59
 
 
60
    def connect_ssh(self, username, password, host, port, command):
 
61
        self.calls.append(('connect_ssh', username, password, host, port,
 
62
            command))
 
63
        return StringIOSSHConnection(self)
 
64
 
 
65
 
 
66
class StringIOSSHConnection(object):
 
67
    """A SSH connection that uses StringIO to buffer writes and answer reads."""
 
68
 
 
69
    def __init__(self, vendor):
 
70
        self.vendor = vendor
 
71
    
 
72
    def close(self):
 
73
        self.vendor.calls.append(('close', ))
 
74
        
 
75
    def get_filelike_channels(self):
 
76
        return self.vendor.read_from, self.vendor.write_to
 
77
 
 
78
 
 
79
class _InvalidHostnameFeature(tests.Feature):
 
80
    """Does 'non_existent.invalid' fail to resolve?
 
81
    
 
82
    RFC 2606 states that .invalid is reserved for invalid domain names, and
 
83
    also underscores are not a valid character in domain names.  Despite this,
 
84
    it's possible a badly misconfigured name server might decide to always
 
85
    return an address for any name, so this feature allows us to distinguish a
 
86
    broken system from a broken test.
 
87
    """
 
88
 
 
89
    def _probe(self):
 
90
        try:
 
91
            socket.gethostbyname('non_existent.invalid')
 
92
        except socket.gaierror:
 
93
            # The host name failed to resolve.  Good.
 
94
            return True
 
95
        else:
 
96
            return False
 
97
 
 
98
    def feature_name(self):
 
99
        return 'invalid hostname'
 
100
 
 
101
InvalidHostnameFeature = _InvalidHostnameFeature()
 
102
 
 
103
 
 
104
class SmartClientMediumTests(tests.TestCase):
 
105
    """Tests for SmartClientMedium.
 
106
 
 
107
    We should create a test scenario for this: we need a server module that
 
108
    construct the test-servers (like make_loopsocket_and_medium), and the list
 
109
    of SmartClientMedium classes to test.
 
110
    """
 
111
 
 
112
    def make_loopsocket_and_medium(self):
 
113
        """Create a loopback socket for testing, and a medium aimed at it."""
 
114
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
115
        sock.bind(('127.0.0.1', 0))
 
116
        sock.listen(1)
 
117
        port = sock.getsockname()[1]
 
118
        client_medium = medium.SmartTCPClientMedium('127.0.0.1', port, 'base')
 
119
        return sock, client_medium
 
120
 
 
121
    def receive_bytes_on_server(self, sock, bytes):
 
122
        """Accept a connection on sock and read 3 bytes.
 
123
 
 
124
        The bytes are appended to the list bytes.
 
125
 
 
126
        :return: a Thread which is running to do the accept and recv.
 
127
        """
 
128
        def _receive_bytes_on_server():
 
129
            connection, address = sock.accept()
 
130
            bytes.append(osutils.recv_all(connection, 3))
 
131
            connection.close()
 
132
        t = threading.Thread(target=_receive_bytes_on_server)
 
133
        t.start()
 
134
        return t
 
135
    
 
136
    def test_construct_smart_simple_pipes_client_medium(self):
 
137
        # the SimplePipes client medium takes two pipes:
 
138
        # readable pipe, writeable pipe.
 
139
        # Constructing one should just save these and do nothing.
 
140
        # We test this by passing in None.
 
141
        client_medium = medium.SmartSimplePipesClientMedium(None, None, None)
 
142
        
 
143
    def test_simple_pipes_client_request_type(self):
 
144
        # SimplePipesClient should use SmartClientStreamMediumRequest's.
 
145
        client_medium = medium.SmartSimplePipesClientMedium(None, None, None)
 
146
        request = client_medium.get_request()
 
147
        self.assertIsInstance(request, medium.SmartClientStreamMediumRequest)
 
148
 
 
149
    def test_simple_pipes_client_get_concurrent_requests(self):
 
150
        # the simple_pipes client does not support pipelined requests:
 
151
        # but it does support serial requests: we construct one after 
 
152
        # another is finished. This is a smoke test testing the integration
 
153
        # of the SmartClientStreamMediumRequest and the SmartClientStreamMedium
 
154
        # classes - as the sibling classes share this logic, they do not have
 
155
        # explicit tests for this.
 
156
        output = StringIO()
 
157
        client_medium = medium.SmartSimplePipesClientMedium(
 
158
            None, output, 'base')
 
159
        request = client_medium.get_request()
 
160
        request.finished_writing()
 
161
        request.finished_reading()
 
162
        request2 = client_medium.get_request()
 
163
        request2.finished_writing()
 
164
        request2.finished_reading()
 
165
 
 
166
    def test_simple_pipes_client__accept_bytes_writes_to_writable(self):
 
167
        # accept_bytes writes to the writeable pipe.
 
168
        output = StringIO()
 
169
        client_medium = medium.SmartSimplePipesClientMedium(
 
170
            None, output, 'base')
 
171
        client_medium._accept_bytes('abc')
 
172
        self.assertEqual('abc', output.getvalue())
 
173
    
 
174
    def test_simple_pipes_client_disconnect_does_nothing(self):
 
175
        # calling disconnect does nothing.
 
176
        input = StringIO()
 
177
        output = StringIO()
 
178
        client_medium = medium.SmartSimplePipesClientMedium(
 
179
            input, output, 'base')
 
180
        # send some bytes to ensure disconnecting after activity still does not
 
181
        # close.
 
182
        client_medium._accept_bytes('abc')
 
183
        client_medium.disconnect()
 
184
        self.assertFalse(input.closed)
 
185
        self.assertFalse(output.closed)
 
186
 
 
187
    def test_simple_pipes_client_accept_bytes_after_disconnect(self):
 
188
        # calling disconnect on the client does not alter the pipe that
 
189
        # accept_bytes writes to.
 
190
        input = StringIO()
 
191
        output = StringIO()
 
192
        client_medium = medium.SmartSimplePipesClientMedium(
 
193
            input, output, 'base')
 
194
        client_medium._accept_bytes('abc')
 
195
        client_medium.disconnect()
 
196
        client_medium._accept_bytes('abc')
 
197
        self.assertFalse(input.closed)
 
198
        self.assertFalse(output.closed)
 
199
        self.assertEqual('abcabc', output.getvalue())
 
200
    
 
201
    def test_simple_pipes_client_ignores_disconnect_when_not_connected(self):
 
202
        # Doing a disconnect on a new (and thus unconnected) SimplePipes medium
 
203
        # does nothing.
 
204
        client_medium = medium.SmartSimplePipesClientMedium(None, None, 'base')
 
205
        client_medium.disconnect()
 
206
 
 
207
    def test_simple_pipes_client_can_always_read(self):
 
208
        # SmartSimplePipesClientMedium is never disconnected, so read_bytes
 
209
        # always tries to read from the underlying pipe.
 
210
        input = StringIO('abcdef')
 
211
        client_medium = medium.SmartSimplePipesClientMedium(input, None, 'base')
 
212
        self.assertEqual('abc', client_medium.read_bytes(3))
 
213
        client_medium.disconnect()
 
214
        self.assertEqual('def', client_medium.read_bytes(3))
 
215
        
 
216
    def test_simple_pipes_client_supports__flush(self):
 
217
        # invoking _flush on a SimplePipesClient should flush the output 
 
218
        # pipe. We test this by creating an output pipe that records
 
219
        # flush calls made to it.
 
220
        from StringIO import StringIO # get regular StringIO
 
221
        input = StringIO()
 
222
        output = StringIO()
 
223
        flush_calls = []
 
224
        def logging_flush(): flush_calls.append('flush')
 
225
        output.flush = logging_flush
 
226
        client_medium = medium.SmartSimplePipesClientMedium(
 
227
            input, output, 'base')
 
228
        # this call is here to ensure we only flush once, not on every
 
229
        # _accept_bytes call.
 
230
        client_medium._accept_bytes('abc')
 
231
        client_medium._flush()
 
232
        client_medium.disconnect()
 
233
        self.assertEqual(['flush'], flush_calls)
 
234
 
 
235
    def test_construct_smart_ssh_client_medium(self):
 
236
        # the SSH client medium takes:
 
237
        # host, port, username, password, vendor
 
238
        # Constructing one should just save these and do nothing.
 
239
        # we test this by creating a empty bound socket and constructing
 
240
        # a medium.
 
241
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
242
        sock.bind(('127.0.0.1', 0))
 
243
        unopened_port = sock.getsockname()[1]
 
244
        # having vendor be invalid means that if it tries to connect via the
 
245
        # vendor it will blow up.
 
246
        client_medium = medium.SmartSSHClientMedium('127.0.0.1', unopened_port,
 
247
            username=None, password=None, base='base', vendor="not a vendor",
 
248
            bzr_remote_path='bzr')
 
249
        sock.close()
 
250
 
 
251
    def test_ssh_client_connects_on_first_use(self):
 
252
        # The only thing that initiates a connection from the medium is giving
 
253
        # it bytes.
 
254
        output = StringIO()
 
255
        vendor = StringIOSSHVendor(StringIO(), output)
 
256
        client_medium = medium.SmartSSHClientMedium(
 
257
            'a hostname', 'a port', 'a username', 'a password', 'base', vendor,
 
258
            'bzr')
 
259
        client_medium._accept_bytes('abc')
 
260
        self.assertEqual('abc', output.getvalue())
 
261
        self.assertEqual([('connect_ssh', 'a username', 'a password',
 
262
            'a hostname', 'a port',
 
263
            ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes'])],
 
264
            vendor.calls)
 
265
    
 
266
    def test_ssh_client_changes_command_when_BZR_REMOTE_PATH_is_set(self):
 
267
        # The only thing that initiates a connection from the medium is giving
 
268
        # it bytes.
 
269
        output = StringIO()
 
270
        vendor = StringIOSSHVendor(StringIO(), output)
 
271
        orig_bzr_remote_path = os.environ.get('BZR_REMOTE_PATH')
 
272
        def cleanup_environ():
 
273
            osutils.set_or_unset_env('BZR_REMOTE_PATH', orig_bzr_remote_path)
 
274
        self.addCleanup(cleanup_environ)
 
275
        os.environ['BZR_REMOTE_PATH'] = 'fugly'
 
276
        client_medium = self.callDeprecated(
 
277
            ['bzr_remote_path is required as of bzr 0.92'],
 
278
            medium.SmartSSHClientMedium, 'a hostname', 'a port', 'a username',
 
279
            'a password', 'base', vendor)
 
280
        client_medium._accept_bytes('abc')
 
281
        self.assertEqual('abc', output.getvalue())
 
282
        self.assertEqual([('connect_ssh', 'a username', 'a password',
 
283
            'a hostname', 'a port',
 
284
            ['fugly', 'serve', '--inet', '--directory=/', '--allow-writes'])],
 
285
            vendor.calls)
 
286
    
 
287
    def test_ssh_client_changes_command_when_bzr_remote_path_passed(self):
 
288
        # The only thing that initiates a connection from the medium is giving
 
289
        # it bytes.
 
290
        output = StringIO()
 
291
        vendor = StringIOSSHVendor(StringIO(), output)
 
292
        client_medium = medium.SmartSSHClientMedium('a hostname', 'a port',
 
293
            'a username', 'a password', 'base', vendor, bzr_remote_path='fugly')
 
294
        client_medium._accept_bytes('abc')
 
295
        self.assertEqual('abc', output.getvalue())
 
296
        self.assertEqual([('connect_ssh', 'a username', 'a password',
 
297
            'a hostname', 'a port',
 
298
            ['fugly', 'serve', '--inet', '--directory=/', '--allow-writes'])],
 
299
            vendor.calls)
 
300
 
 
301
    def test_ssh_client_disconnect_does_so(self):
 
302
        # calling disconnect should disconnect both the read_from and write_to
 
303
        # file-like object it from the ssh connection.
 
304
        input = StringIO()
 
305
        output = StringIO()
 
306
        vendor = StringIOSSHVendor(input, output)
 
307
        client_medium = medium.SmartSSHClientMedium(
 
308
            'a hostname', base='base', vendor=vendor, bzr_remote_path='bzr')
 
309
        client_medium._accept_bytes('abc')
 
310
        client_medium.disconnect()
 
311
        self.assertTrue(input.closed)
 
312
        self.assertTrue(output.closed)
 
313
        self.assertEqual([
 
314
            ('connect_ssh', None, None, 'a hostname', None,
 
315
            ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
 
316
            ('close', ),
 
317
            ],
 
318
            vendor.calls)
 
319
 
 
320
    def test_ssh_client_disconnect_allows_reconnection(self):
 
321
        # calling disconnect on the client terminates the connection, but should
 
322
        # not prevent additional connections occuring.
 
323
        # we test this by initiating a second connection after doing a
 
324
        # disconnect.
 
325
        input = StringIO()
 
326
        output = StringIO()
 
327
        vendor = StringIOSSHVendor(input, output)
 
328
        client_medium = medium.SmartSSHClientMedium(
 
329
            'a hostname', base='base', vendor=vendor, bzr_remote_path='bzr')
 
330
        client_medium._accept_bytes('abc')
 
331
        client_medium.disconnect()
 
332
        # the disconnect has closed output, so we need a new output for the
 
333
        # new connection to write to.
 
334
        input2 = StringIO()
 
335
        output2 = StringIO()
 
336
        vendor.read_from = input2
 
337
        vendor.write_to = output2
 
338
        client_medium._accept_bytes('abc')
 
339
        client_medium.disconnect()
 
340
        self.assertTrue(input.closed)
 
341
        self.assertTrue(output.closed)
 
342
        self.assertTrue(input2.closed)
 
343
        self.assertTrue(output2.closed)
 
344
        self.assertEqual([
 
345
            ('connect_ssh', None, None, 'a hostname', None,
 
346
            ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
 
347
            ('close', ),
 
348
            ('connect_ssh', None, None, 'a hostname', None,
 
349
            ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
 
350
            ('close', ),
 
351
            ],
 
352
            vendor.calls)
 
353
    
 
354
    def test_ssh_client_ignores_disconnect_when_not_connected(self):
 
355
        # Doing a disconnect on a new (and thus unconnected) SSH medium
 
356
        # does not fail.  It's ok to disconnect an unconnected medium.
 
357
        client_medium = medium.SmartSSHClientMedium(
 
358
            None, base='base', bzr_remote_path='bzr')
 
359
        client_medium.disconnect()
 
360
 
 
361
    def test_ssh_client_raises_on_read_when_not_connected(self):
 
362
        # Doing a read on a new (and thus unconnected) SSH medium raises
 
363
        # MediumNotConnected.
 
364
        client_medium = medium.SmartSSHClientMedium(
 
365
            None, base='base', bzr_remote_path='bzr')
 
366
        self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes,
 
367
                          0)
 
368
        self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes,
 
369
                          1)
 
370
 
 
371
    def test_ssh_client_supports__flush(self):
 
372
        # invoking _flush on a SSHClientMedium should flush the output 
 
373
        # pipe. We test this by creating an output pipe that records
 
374
        # flush calls made to it.
 
375
        from StringIO import StringIO # get regular StringIO
 
376
        input = StringIO()
 
377
        output = StringIO()
 
378
        flush_calls = []
 
379
        def logging_flush(): flush_calls.append('flush')
 
380
        output.flush = logging_flush
 
381
        vendor = StringIOSSHVendor(input, output)
 
382
        client_medium = medium.SmartSSHClientMedium(
 
383
            'a hostname', base='base', vendor=vendor, bzr_remote_path='bzr')
 
384
        # this call is here to ensure we only flush once, not on every
 
385
        # _accept_bytes call.
 
386
        client_medium._accept_bytes('abc')
 
387
        client_medium._flush()
 
388
        client_medium.disconnect()
 
389
        self.assertEqual(['flush'], flush_calls)
 
390
        
 
391
    def test_construct_smart_tcp_client_medium(self):
 
392
        # the TCP client medium takes a host and a port.  Constructing it won't
 
393
        # connect to anything.
 
394
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
395
        sock.bind(('127.0.0.1', 0))
 
396
        unopened_port = sock.getsockname()[1]
 
397
        client_medium = medium.SmartTCPClientMedium(
 
398
            '127.0.0.1', unopened_port, 'base')
 
399
        sock.close()
 
400
 
 
401
    def test_tcp_client_connects_on_first_use(self):
 
402
        # The only thing that initiates a connection from the medium is giving
 
403
        # it bytes.
 
404
        sock, medium = self.make_loopsocket_and_medium()
 
405
        bytes = []
 
406
        t = self.receive_bytes_on_server(sock, bytes)
 
407
        medium.accept_bytes('abc')
 
408
        t.join()
 
409
        sock.close()
 
410
        self.assertEqual(['abc'], bytes)
 
411
    
 
412
    def test_tcp_client_disconnect_does_so(self):
 
413
        # calling disconnect on the client terminates the connection.
 
414
        # we test this by forcing a short read during a socket.MSG_WAITALL
 
415
        # call: write 2 bytes, try to read 3, and then the client disconnects.
 
416
        sock, medium = self.make_loopsocket_and_medium()
 
417
        bytes = []
 
418
        t = self.receive_bytes_on_server(sock, bytes)
 
419
        medium.accept_bytes('ab')
 
420
        medium.disconnect()
 
421
        t.join()
 
422
        sock.close()
 
423
        self.assertEqual(['ab'], bytes)
 
424
        # now disconnect again: this should not do anything, if disconnection
 
425
        # really did disconnect.
 
426
        medium.disconnect()
 
427
 
 
428
    
 
429
    def test_tcp_client_ignores_disconnect_when_not_connected(self):
 
430
        # Doing a disconnect on a new (and thus unconnected) TCP medium
 
431
        # does not fail.  It's ok to disconnect an unconnected medium.
 
432
        client_medium = medium.SmartTCPClientMedium(None, None, None)
 
433
        client_medium.disconnect()
 
434
 
 
435
    def test_tcp_client_raises_on_read_when_not_connected(self):
 
436
        # Doing a read on a new (and thus unconnected) TCP medium raises
 
437
        # MediumNotConnected.
 
438
        client_medium = medium.SmartTCPClientMedium(None, None, None)
 
439
        self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes, 0)
 
440
        self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes, 1)
 
441
 
 
442
    def test_tcp_client_supports__flush(self):
 
443
        # invoking _flush on a TCPClientMedium should do something useful.
 
444
        # RBC 20060922 not sure how to test/tell in this case.
 
445
        sock, medium = self.make_loopsocket_and_medium()
 
446
        bytes = []
 
447
        t = self.receive_bytes_on_server(sock, bytes)
 
448
        # try with nothing buffered
 
449
        medium._flush()
 
450
        medium._accept_bytes('ab')
 
451
        # and with something sent.
 
452
        medium._flush()
 
453
        medium.disconnect()
 
454
        t.join()
 
455
        sock.close()
 
456
        self.assertEqual(['ab'], bytes)
 
457
        # now disconnect again : this should not do anything, if disconnection
 
458
        # really did disconnect.
 
459
        medium.disconnect()
 
460
 
 
461
    def test_tcp_client_host_unknown_connection_error(self):
 
462
        self.requireFeature(InvalidHostnameFeature)
 
463
        client_medium = medium.SmartTCPClientMedium(
 
464
            'non_existent.invalid', 4155, 'base')
 
465
        self.assertRaises(
 
466
            errors.ConnectionError, client_medium._ensure_connection)
 
467
 
 
468
 
 
469
class TestSmartClientStreamMediumRequest(tests.TestCase):
 
470
    """Tests the for SmartClientStreamMediumRequest.
 
471
    
 
472
    SmartClientStreamMediumRequest is a helper for the three stream based 
 
473
    mediums: TCP, SSH, SimplePipes, so we only test it once, and then test that
 
474
    those three mediums implement the interface it expects.
 
475
    """
 
476
 
 
477
    def test_accept_bytes_after_finished_writing_errors(self):
 
478
        # calling accept_bytes after calling finished_writing raises 
 
479
        # WritingCompleted to prevent bad assumptions on stream environments
 
480
        # breaking the needs of message-based environments.
 
481
        output = StringIO()
 
482
        client_medium = medium.SmartSimplePipesClientMedium(
 
483
            None, output, 'base')
 
484
        request = medium.SmartClientStreamMediumRequest(client_medium)
 
485
        request.finished_writing()
 
486
        self.assertRaises(errors.WritingCompleted, request.accept_bytes, None)
 
487
 
 
488
    def test_accept_bytes(self):
 
489
        # accept bytes should invoke _accept_bytes on the stream medium.
 
490
        # we test this by using the SimplePipes medium - the most trivial one
 
491
        # and checking that the pipes get the data.
 
492
        input = StringIO()
 
493
        output = StringIO()
 
494
        client_medium = medium.SmartSimplePipesClientMedium(
 
495
            input, output, 'base')
 
496
        request = medium.SmartClientStreamMediumRequest(client_medium)
 
497
        request.accept_bytes('123')
 
498
        request.finished_writing()
 
499
        request.finished_reading()
 
500
        self.assertEqual('', input.getvalue())
 
501
        self.assertEqual('123', output.getvalue())
 
502
 
 
503
    def test_construct_sets_stream_request(self):
 
504
        # constructing a SmartClientStreamMediumRequest on a StreamMedium sets
 
505
        # the current request to the new SmartClientStreamMediumRequest
 
506
        output = StringIO()
 
507
        client_medium = medium.SmartSimplePipesClientMedium(
 
508
            None, output, 'base')
 
509
        request = medium.SmartClientStreamMediumRequest(client_medium)
 
510
        self.assertIs(client_medium._current_request, request)
 
511
 
 
512
    def test_construct_while_another_request_active_throws(self):
 
513
        # constructing a SmartClientStreamMediumRequest on a StreamMedium with
 
514
        # a non-None _current_request raises TooManyConcurrentRequests.
 
515
        output = StringIO()
 
516
        client_medium = medium.SmartSimplePipesClientMedium(
 
517
            None, output, 'base')
 
518
        client_medium._current_request = "a"
 
519
        self.assertRaises(errors.TooManyConcurrentRequests,
 
520
            medium.SmartClientStreamMediumRequest, client_medium)
 
521
 
 
522
    def test_finished_read_clears_current_request(self):
 
523
        # calling finished_reading clears the current request from the requests
 
524
        # medium
 
525
        output = StringIO()
 
526
        client_medium = medium.SmartSimplePipesClientMedium(
 
527
            None, output, 'base')
 
528
        request = medium.SmartClientStreamMediumRequest(client_medium)
 
529
        request.finished_writing()
 
530
        request.finished_reading()
 
531
        self.assertEqual(None, client_medium._current_request)
 
532
 
 
533
    def test_finished_read_before_finished_write_errors(self):
 
534
        # calling finished_reading before calling finished_writing triggers a
 
535
        # WritingNotComplete error.
 
536
        client_medium = medium.SmartSimplePipesClientMedium(
 
537
            None, None, 'base')
 
538
        request = medium.SmartClientStreamMediumRequest(client_medium)
 
539
        self.assertRaises(errors.WritingNotComplete, request.finished_reading)
 
540
        
 
541
    def test_read_bytes(self):
 
542
        # read bytes should invoke _read_bytes on the stream medium.
 
543
        # we test this by using the SimplePipes medium - the most trivial one
 
544
        # and checking that the data is supplied. Its possible that a 
 
545
        # faulty implementation could poke at the pipe variables them selves,
 
546
        # but we trust that this will be caught as it will break the integration
 
547
        # smoke tests.
 
548
        input = StringIO('321')
 
549
        output = StringIO()
 
550
        client_medium = medium.SmartSimplePipesClientMedium(
 
551
            input, output, 'base')
 
552
        request = medium.SmartClientStreamMediumRequest(client_medium)
 
553
        request.finished_writing()
 
554
        self.assertEqual('321', request.read_bytes(3))
 
555
        request.finished_reading()
 
556
        self.assertEqual('', input.read())
 
557
        self.assertEqual('', output.getvalue())
 
558
 
 
559
    def test_read_bytes_before_finished_write_errors(self):
 
560
        # calling read_bytes before calling finished_writing triggers a
 
561
        # WritingNotComplete error because the Smart protocol is designed to be
 
562
        # compatible with strict message based protocols like HTTP where the
 
563
        # request cannot be submitted until the writing has completed.
 
564
        client_medium = medium.SmartSimplePipesClientMedium(None, None, 'base')
 
565
        request = medium.SmartClientStreamMediumRequest(client_medium)
 
566
        self.assertRaises(errors.WritingNotComplete, request.read_bytes, None)
 
567
 
 
568
    def test_read_bytes_after_finished_reading_errors(self):
 
569
        # calling read_bytes after calling finished_reading raises 
 
570
        # ReadingCompleted to prevent bad assumptions on stream environments
 
571
        # breaking the needs of message-based environments.
 
572
        output = StringIO()
 
573
        client_medium = medium.SmartSimplePipesClientMedium(
 
574
            None, output, 'base')
 
575
        request = medium.SmartClientStreamMediumRequest(client_medium)
 
576
        request.finished_writing()
 
577
        request.finished_reading()
 
578
        self.assertRaises(errors.ReadingCompleted, request.read_bytes, None)
 
579
 
 
580
 
 
581
class RemoteTransportTests(TestCaseWithSmartMedium):
 
582
 
 
583
    def test_plausible_url(self):
 
584
        self.assert_(self.get_url().startswith('bzr://'))
 
585
 
 
586
    def test_probe_transport(self):
 
587
        t = self.get_transport()
 
588
        self.assertIsInstance(t, remote.RemoteTransport)
 
589
 
 
590
    def test_get_medium_from_transport(self):
 
591
        """Remote transport has a medium always, which it can return."""
 
592
        t = self.get_transport()
 
593
        client_medium = t.get_smart_medium()
 
594
        self.assertIsInstance(client_medium, medium.SmartClientMedium)
 
595
 
 
596
 
 
597
class ErrorRaisingProtocol(object):
 
598
 
 
599
    def __init__(self, exception):
 
600
        self.exception = exception
 
601
 
 
602
    def next_read_size(self):
 
603
        raise self.exception
 
604
 
 
605
 
 
606
class SampleRequest(object):
 
607
    
 
608
    def __init__(self, expected_bytes):
 
609
        self.accepted_bytes = ''
 
610
        self._finished_reading = False
 
611
        self.expected_bytes = expected_bytes
 
612
        self.unused_data = ''
 
613
 
 
614
    def accept_bytes(self, bytes):
 
615
        self.accepted_bytes += bytes
 
616
        if self.accepted_bytes.startswith(self.expected_bytes):
 
617
            self._finished_reading = True
 
618
            self.unused_data = self.accepted_bytes[len(self.expected_bytes):]
 
619
 
 
620
    def next_read_size(self):
 
621
        if self._finished_reading:
 
622
            return 0
 
623
        else:
 
624
            return 1
 
625
 
 
626
 
 
627
class TestSmartServerStreamMedium(tests.TestCase):
 
628
 
 
629
    def setUp(self):
 
630
        super(TestSmartServerStreamMedium, self).setUp()
 
631
        self._captureVar('BZR_NO_SMART_VFS', None)
 
632
 
 
633
    def portable_socket_pair(self):
 
634
        """Return a pair of TCP sockets connected to each other.
 
635
        
 
636
        Unlike socket.socketpair, this should work on Windows.
 
637
        """
 
638
        listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
639
        listen_sock.bind(('127.0.0.1', 0))
 
640
        listen_sock.listen(1)
 
641
        client_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
642
        client_sock.connect(listen_sock.getsockname())
 
643
        server_sock, addr = listen_sock.accept()
 
644
        listen_sock.close()
 
645
        return server_sock, client_sock
 
646
    
 
647
    def test_smart_query_version(self):
 
648
        """Feed a canned query version to a server"""
 
649
        # wire-to-wire, using the whole stack
 
650
        to_server = StringIO('hello\n')
 
651
        from_server = StringIO()
 
652
        transport = local.LocalTransport(urlutils.local_path_to_url('/'))
 
653
        server = medium.SmartServerPipeStreamMedium(
 
654
            to_server, from_server, transport)
 
655
        smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
 
656
                from_server.write)
 
657
        server._serve_one_request(smart_protocol)
 
658
        self.assertEqual('ok\0012\n',
 
659
                         from_server.getvalue())
 
660
 
 
661
    def test_response_to_canned_get(self):
 
662
        transport = memory.MemoryTransport('memory:///')
 
663
        transport.put_bytes('testfile', 'contents\nof\nfile\n')
 
664
        to_server = StringIO('get\001./testfile\n')
 
665
        from_server = StringIO()
 
666
        server = medium.SmartServerPipeStreamMedium(
 
667
            to_server, from_server, transport)
 
668
        smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
 
669
                from_server.write)
 
670
        server._serve_one_request(smart_protocol)
 
671
        self.assertEqual('ok\n'
 
672
                         '17\n'
 
673
                         'contents\nof\nfile\n'
 
674
                         'done\n',
 
675
                         from_server.getvalue())
 
676
 
 
677
    def test_response_to_canned_get_of_utf8(self):
 
678
        # wire-to-wire, using the whole stack, with a UTF-8 filename.
 
679
        transport = memory.MemoryTransport('memory:///')
 
680
        utf8_filename = u'testfile\N{INTERROBANG}'.encode('utf-8')
 
681
        transport.put_bytes(utf8_filename, 'contents\nof\nfile\n')
 
682
        to_server = StringIO('get\001' + utf8_filename + '\n')
 
683
        from_server = StringIO()
 
684
        server = medium.SmartServerPipeStreamMedium(
 
685
            to_server, from_server, transport)
 
686
        smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
 
687
                from_server.write)
 
688
        server._serve_one_request(smart_protocol)
 
689
        self.assertEqual('ok\n'
 
690
                         '17\n'
 
691
                         'contents\nof\nfile\n'
 
692
                         'done\n',
 
693
                         from_server.getvalue())
 
694
 
 
695
    def test_pipe_like_stream_with_bulk_data(self):
 
696
        sample_request_bytes = 'command\n9\nbulk datadone\n'
 
697
        to_server = StringIO(sample_request_bytes)
 
698
        from_server = StringIO()
 
699
        server = medium.SmartServerPipeStreamMedium(
 
700
            to_server, from_server, None)
 
701
        sample_protocol = SampleRequest(expected_bytes=sample_request_bytes)
 
702
        server._serve_one_request(sample_protocol)
 
703
        self.assertEqual('', from_server.getvalue())
 
704
        self.assertEqual(sample_request_bytes, sample_protocol.accepted_bytes)
 
705
        self.assertFalse(server.finished)
 
706
 
 
707
    def test_socket_stream_with_bulk_data(self):
 
708
        sample_request_bytes = 'command\n9\nbulk datadone\n'
 
709
        server_sock, client_sock = self.portable_socket_pair()
 
710
        server = medium.SmartServerSocketStreamMedium(
 
711
            server_sock, None)
 
712
        sample_protocol = SampleRequest(expected_bytes=sample_request_bytes)
 
713
        client_sock.sendall(sample_request_bytes)
 
714
        server._serve_one_request(sample_protocol)
 
715
        server_sock.close()
 
716
        self.assertEqual('', client_sock.recv(1))
 
717
        self.assertEqual(sample_request_bytes, sample_protocol.accepted_bytes)
 
718
        self.assertFalse(server.finished)
 
719
 
 
720
    def test_pipe_like_stream_shutdown_detection(self):
 
721
        to_server = StringIO('')
 
722
        from_server = StringIO()
 
723
        server = medium.SmartServerPipeStreamMedium(to_server, from_server, None)
 
724
        server._serve_one_request(SampleRequest('x'))
 
725
        self.assertTrue(server.finished)
 
726
        
 
727
    def test_socket_stream_shutdown_detection(self):
 
728
        server_sock, client_sock = self.portable_socket_pair()
 
729
        client_sock.close()
 
730
        server = medium.SmartServerSocketStreamMedium(
 
731
            server_sock, None)
 
732
        server._serve_one_request(SampleRequest('x'))
 
733
        self.assertTrue(server.finished)
 
734
        
 
735
    def test_socket_stream_incomplete_request(self):
 
736
        """The medium should still construct the right protocol version even if
 
737
        the initial read only reads part of the request.
 
738
 
 
739
        Specifically, it should correctly read the protocol version line even
 
740
        if the partial read doesn't end in a newline.  An older, naive
 
741
        implementation of _get_line in the server used to have a bug in that
 
742
        case.
 
743
        """
 
744
        incomplete_request_bytes = protocol.REQUEST_VERSION_TWO + 'hel'
 
745
        rest_of_request_bytes = 'lo\n'
 
746
        expected_response = (
 
747
            protocol.RESPONSE_VERSION_TWO + 'success\nok\x012\n')
 
748
        server_sock, client_sock = self.portable_socket_pair()
 
749
        server = medium.SmartServerSocketStreamMedium(
 
750
            server_sock, None)
 
751
        client_sock.sendall(incomplete_request_bytes)
 
752
        server_protocol = server._build_protocol()
 
753
        client_sock.sendall(rest_of_request_bytes)
 
754
        server._serve_one_request(server_protocol)
 
755
        server_sock.close()
 
756
        self.assertEqual(expected_response, client_sock.recv(50),
 
757
                         "Not a version 2 response to 'hello' request.")
 
758
        self.assertEqual('', client_sock.recv(1))
 
759
 
 
760
    def test_pipe_stream_incomplete_request(self):
 
761
        """The medium should still construct the right protocol version even if
 
762
        the initial read only reads part of the request.
 
763
 
 
764
        Specifically, it should correctly read the protocol version line even
 
765
        if the partial read doesn't end in a newline.  An older, naive
 
766
        implementation of _get_line in the server used to have a bug in that
 
767
        case.
 
768
        """
 
769
        incomplete_request_bytes = protocol.REQUEST_VERSION_TWO + 'hel'
 
770
        rest_of_request_bytes = 'lo\n'
 
771
        expected_response = (
 
772
            protocol.RESPONSE_VERSION_TWO + 'success\nok\x012\n')
 
773
        # Make a pair of pipes, to and from the server
 
774
        to_server, to_server_w = os.pipe()
 
775
        from_server_r, from_server = os.pipe()
 
776
        to_server = os.fdopen(to_server, 'r', 0)
 
777
        to_server_w = os.fdopen(to_server_w, 'w', 0)
 
778
        from_server_r = os.fdopen(from_server_r, 'r', 0)
 
779
        from_server = os.fdopen(from_server, 'w', 0)
 
780
        server = medium.SmartServerPipeStreamMedium(
 
781
            to_server, from_server, None)
 
782
        # Like test_socket_stream_incomplete_request, write an incomplete
 
783
        # request (that does not end in '\n') and build a protocol from it.
 
784
        to_server_w.write(incomplete_request_bytes)
 
785
        server_protocol = server._build_protocol()
 
786
        # Send the rest of the request, and finish serving it.
 
787
        to_server_w.write(rest_of_request_bytes)
 
788
        server._serve_one_request(server_protocol)
 
789
        to_server_w.close()
 
790
        from_server.close()
 
791
        self.assertEqual(expected_response, from_server_r.read(),
 
792
                         "Not a version 2 response to 'hello' request.")
 
793
        self.assertEqual('', from_server_r.read(1))
 
794
        from_server_r.close()
 
795
        to_server.close()
 
796
 
 
797
    def test_pipe_like_stream_with_two_requests(self):
 
798
        # If two requests are read in one go, then two calls to
 
799
        # _serve_one_request should still process both of them as if they had
 
800
        # been received seperately.
 
801
        sample_request_bytes = 'command\n'
 
802
        to_server = StringIO(sample_request_bytes * 2)
 
803
        from_server = StringIO()
 
804
        server = medium.SmartServerPipeStreamMedium(
 
805
            to_server, from_server, None)
 
806
        first_protocol = SampleRequest(expected_bytes=sample_request_bytes)
 
807
        server._serve_one_request(first_protocol)
 
808
        self.assertEqual(0, first_protocol.next_read_size())
 
809
        self.assertEqual('', from_server.getvalue())
 
810
        self.assertFalse(server.finished)
 
811
        # Make a new protocol, call _serve_one_request with it to collect the
 
812
        # second request.
 
813
        second_protocol = SampleRequest(expected_bytes=sample_request_bytes)
 
814
        server._serve_one_request(second_protocol)
 
815
        self.assertEqual('', from_server.getvalue())
 
816
        self.assertEqual(sample_request_bytes, second_protocol.accepted_bytes)
 
817
        self.assertFalse(server.finished)
 
818
        
 
819
    def test_socket_stream_with_two_requests(self):
 
820
        # If two requests are read in one go, then two calls to
 
821
        # _serve_one_request should still process both of them as if they had
 
822
        # been received seperately.
 
823
        sample_request_bytes = 'command\n'
 
824
        server_sock, client_sock = self.portable_socket_pair()
 
825
        server = medium.SmartServerSocketStreamMedium(
 
826
            server_sock, None)
 
827
        first_protocol = SampleRequest(expected_bytes=sample_request_bytes)
 
828
        # Put two whole requests on the wire.
 
829
        client_sock.sendall(sample_request_bytes * 2)
 
830
        server._serve_one_request(first_protocol)
 
831
        self.assertEqual(0, first_protocol.next_read_size())
 
832
        self.assertFalse(server.finished)
 
833
        # Make a new protocol, call _serve_one_request with it to collect the
 
834
        # second request.
 
835
        second_protocol = SampleRequest(expected_bytes=sample_request_bytes)
 
836
        stream_still_open = server._serve_one_request(second_protocol)
 
837
        self.assertEqual(sample_request_bytes, second_protocol.accepted_bytes)
 
838
        self.assertFalse(server.finished)
 
839
        server_sock.close()
 
840
        self.assertEqual('', client_sock.recv(1))
 
841
 
 
842
    def test_pipe_like_stream_error_handling(self):
 
843
        # Use plain python StringIO so we can monkey-patch the close method to
 
844
        # not discard the contents.
 
845
        from StringIO import StringIO
 
846
        to_server = StringIO('')
 
847
        from_server = StringIO()
 
848
        self.closed = False
 
849
        def close():
 
850
            self.closed = True
 
851
        from_server.close = close
 
852
        server = medium.SmartServerPipeStreamMedium(
 
853
            to_server, from_server, None)
 
854
        fake_protocol = ErrorRaisingProtocol(Exception('boom'))
 
855
        server._serve_one_request(fake_protocol)
 
856
        self.assertEqual('', from_server.getvalue())
 
857
        self.assertTrue(self.closed)
 
858
        self.assertTrue(server.finished)
 
859
        
 
860
    def test_socket_stream_error_handling(self):
 
861
        server_sock, client_sock = self.portable_socket_pair()
 
862
        server = medium.SmartServerSocketStreamMedium(
 
863
            server_sock, None)
 
864
        fake_protocol = ErrorRaisingProtocol(Exception('boom'))
 
865
        server._serve_one_request(fake_protocol)
 
866
        # recv should not block, because the other end of the socket has been
 
867
        # closed.
 
868
        self.assertEqual('', client_sock.recv(1))
 
869
        self.assertTrue(server.finished)
 
870
        
 
871
    def test_pipe_like_stream_keyboard_interrupt_handling(self):
 
872
        to_server = StringIO('')
 
873
        from_server = StringIO()
 
874
        server = medium.SmartServerPipeStreamMedium(
 
875
            to_server, from_server, None)
 
876
        fake_protocol = ErrorRaisingProtocol(KeyboardInterrupt('boom'))
 
877
        self.assertRaises(
 
878
            KeyboardInterrupt, server._serve_one_request, fake_protocol)
 
879
        self.assertEqual('', from_server.getvalue())
 
880
 
 
881
    def test_socket_stream_keyboard_interrupt_handling(self):
 
882
        server_sock, client_sock = self.portable_socket_pair()
 
883
        server = medium.SmartServerSocketStreamMedium(
 
884
            server_sock, None)
 
885
        fake_protocol = ErrorRaisingProtocol(KeyboardInterrupt('boom'))
 
886
        self.assertRaises(
 
887
            KeyboardInterrupt, server._serve_one_request, fake_protocol)
 
888
        server_sock.close()
 
889
        self.assertEqual('', client_sock.recv(1))
 
890
 
 
891
    def build_protocol_pipe_like(self, bytes):
 
892
        to_server = StringIO(bytes)
 
893
        from_server = StringIO()
 
894
        server = medium.SmartServerPipeStreamMedium(
 
895
            to_server, from_server, None)
 
896
        return server._build_protocol()
 
897
 
 
898
    def build_protocol_socket(self, bytes):
 
899
        server_sock, client_sock = self.portable_socket_pair()
 
900
        server = medium.SmartServerSocketStreamMedium(
 
901
            server_sock, None)
 
902
        client_sock.sendall(bytes)
 
903
        client_sock.close()
 
904
        return server._build_protocol()
 
905
 
 
906
    def assertProtocolOne(self, server_protocol):
 
907
        # Use assertIs because assertIsInstance will wrongly pass
 
908
        # SmartServerRequestProtocolTwo (because it subclasses
 
909
        # SmartServerRequestProtocolOne).
 
910
        self.assertIs(
 
911
            type(server_protocol), protocol.SmartServerRequestProtocolOne)
 
912
 
 
913
    def assertProtocolTwo(self, server_protocol):
 
914
        self.assertIsInstance(
 
915
            server_protocol, protocol.SmartServerRequestProtocolTwo)
 
916
 
 
917
    def test_pipe_like_build_protocol_empty_bytes(self):
 
918
        # Any empty request (i.e. no bytes) is detected as protocol version one.
 
919
        server_protocol = self.build_protocol_pipe_like('')
 
920
        self.assertProtocolOne(server_protocol)
 
921
        
 
922
    def test_socket_like_build_protocol_empty_bytes(self):
 
923
        # Any empty request (i.e. no bytes) is detected as protocol version one.
 
924
        server_protocol = self.build_protocol_socket('')
 
925
        self.assertProtocolOne(server_protocol)
 
926
 
 
927
    def test_pipe_like_build_protocol_non_two(self):
 
928
        # A request that doesn't start with "bzr request 2\n" is version one.
 
929
        server_protocol = self.build_protocol_pipe_like('abc\n')
 
930
        self.assertProtocolOne(server_protocol)
 
931
 
 
932
    def test_socket_build_protocol_non_two(self):
 
933
        # A request that doesn't start with "bzr request 2\n" is version one.
 
934
        server_protocol = self.build_protocol_socket('abc\n')
 
935
        self.assertProtocolOne(server_protocol)
 
936
 
 
937
    def test_pipe_like_build_protocol_two(self):
 
938
        # A request that starts with "bzr request 2\n" is version two.
 
939
        server_protocol = self.build_protocol_pipe_like('bzr request 2\n')
 
940
        self.assertProtocolTwo(server_protocol)
 
941
 
 
942
    def test_socket_build_protocol_two(self):
 
943
        # A request that starts with "bzr request 2\n" is version two.
 
944
        server_protocol = self.build_protocol_socket('bzr request 2\n')
 
945
        self.assertProtocolTwo(server_protocol)
 
946
 
 
947
 
 
948
class TestGetProtocolFactoryForBytes(tests.TestCase):
 
949
    """_get_protocol_factory_for_bytes identifies the protocol factory a server
 
950
    should use to decode a given request.  Any bytes not part of the version
 
951
    marker string (and thus part of the actual request) are returned alongside
 
952
    the protocol factory.
 
953
    """
 
954
 
 
955
    def test_version_three(self):
 
956
        result = medium._get_protocol_factory_for_bytes(
 
957
            'bzr message 3 (bzr 1.6)\nextra bytes')
 
958
        protocol_factory, remainder = result
 
959
        self.assertEqual(
 
960
            protocol.build_server_protocol_three, protocol_factory)
 
961
        self.assertEqual('extra bytes', remainder)
 
962
        
 
963
    def test_version_two(self):
 
964
        result = medium._get_protocol_factory_for_bytes(
 
965
            'bzr request 2\nextra bytes')
 
966
        protocol_factory, remainder = result
 
967
        self.assertEqual(
 
968
            protocol.SmartServerRequestProtocolTwo, protocol_factory)
 
969
        self.assertEqual('extra bytes', remainder)
 
970
        
 
971
    def test_version_one(self):
 
972
        """Version one requests have no version markers."""
 
973
        result = medium._get_protocol_factory_for_bytes('anything\n')
 
974
        protocol_factory, remainder = result
 
975
        self.assertEqual(
 
976
            protocol.SmartServerRequestProtocolOne, protocol_factory)
 
977
        self.assertEqual('anything\n', remainder)
 
978
        
 
979
 
 
980
class TestSmartTCPServer(tests.TestCase):
 
981
 
 
982
    def test_get_error_unexpected(self):
 
983
        """Error reported by server with no specific representation"""
 
984
        self._captureVar('BZR_NO_SMART_VFS', None)
 
985
        class FlakyTransport(object):
 
986
            base = 'a_url'
 
987
            def external_url(self):
 
988
                return self.base
 
989
            def get_bytes(self, path):
 
990
                raise Exception("some random exception from inside server")
 
991
        smart_server = server.SmartTCPServer(backing_transport=FlakyTransport())
 
992
        smart_server.start_background_thread('-' + self.id())
 
993
        try:
 
994
            transport = remote.RemoteTCPTransport(smart_server.get_url())
 
995
            try:
 
996
                transport.get('something')
 
997
            except errors.TransportError, e:
 
998
                self.assertContainsRe(str(e), 'some random exception')
 
999
            else:
 
1000
                self.fail("get did not raise expected error")
 
1001
            transport.disconnect()
 
1002
        finally:
 
1003
            smart_server.stop_background_thread()
 
1004
 
 
1005
 
 
1006
class SmartTCPTests(tests.TestCase):
 
1007
    """Tests for connection/end to end behaviour using the TCP server.
 
1008
 
 
1009
    All of these tests are run with a server running on another thread serving
 
1010
    a MemoryTransport, and a connection to it already open.
 
1011
 
 
1012
    the server is obtained by calling self.setUpServer(readonly=False).
 
1013
    """
 
1014
 
 
1015
    def setUpServer(self, readonly=False, backing_transport=None):
 
1016
        """Setup the server.
 
1017
 
 
1018
        :param readonly: Create a readonly server.
 
1019
        """
 
1020
        if not backing_transport:
 
1021
            self.backing_transport = memory.MemoryTransport()
 
1022
        else:
 
1023
            self.backing_transport = backing_transport
 
1024
        if readonly:
 
1025
            self.real_backing_transport = self.backing_transport
 
1026
            self.backing_transport = get_transport("readonly+" + self.backing_transport.abspath('.'))
 
1027
        self.server = server.SmartTCPServer(self.backing_transport)
 
1028
        self.server.start_background_thread('-' + self.id())
 
1029
        self.transport = remote.RemoteTCPTransport(self.server.get_url())
 
1030
        self.addCleanup(self.tearDownServer)
 
1031
 
 
1032
    def tearDownServer(self):
 
1033
        if getattr(self, 'transport', None):
 
1034
            self.transport.disconnect()
 
1035
            del self.transport
 
1036
        if getattr(self, 'server', None):
 
1037
            self.server.stop_background_thread()
 
1038
            del self.server
 
1039
 
 
1040
 
 
1041
class TestServerSocketUsage(SmartTCPTests):
 
1042
 
 
1043
    def test_server_setup_teardown(self):
 
1044
        """It should be safe to teardown the server with no requests."""
 
1045
        self.setUpServer()
 
1046
        server = self.server
 
1047
        transport = remote.RemoteTCPTransport(self.server.get_url())
 
1048
        self.tearDownServer()
 
1049
        self.assertRaises(errors.ConnectionError, transport.has, '.')
 
1050
 
 
1051
    def test_server_closes_listening_sock_on_shutdown_after_request(self):
 
1052
        """The server should close its listening socket when it's stopped."""
 
1053
        self.setUpServer()
 
1054
        server = self.server
 
1055
        self.transport.has('.')
 
1056
        self.tearDownServer()
 
1057
        # if the listening socket has closed, we should get a BADFD error
 
1058
        # when connecting, rather than a hang.
 
1059
        transport = remote.RemoteTCPTransport(server.get_url())
 
1060
        self.assertRaises(errors.ConnectionError, transport.has, '.')
 
1061
 
 
1062
 
 
1063
class WritableEndToEndTests(SmartTCPTests):
 
1064
    """Client to server tests that require a writable transport."""
 
1065
 
 
1066
    def setUp(self):
 
1067
        super(WritableEndToEndTests, self).setUp()
 
1068
        self.setUpServer()
 
1069
 
 
1070
    def test_start_tcp_server(self):
 
1071
        url = self.server.get_url()
 
1072
        self.assertContainsRe(url, r'^bzr://127\.0\.0\.1:[0-9]{2,}/')
 
1073
 
 
1074
    def test_smart_transport_has(self):
 
1075
        """Checking for file existence over smart."""
 
1076
        self._captureVar('BZR_NO_SMART_VFS', None)
 
1077
        self.backing_transport.put_bytes("foo", "contents of foo\n")
 
1078
        self.assertTrue(self.transport.has("foo"))
 
1079
        self.assertFalse(self.transport.has("non-foo"))
 
1080
 
 
1081
    def test_smart_transport_get(self):
 
1082
        """Read back a file over smart."""
 
1083
        self._captureVar('BZR_NO_SMART_VFS', None)
 
1084
        self.backing_transport.put_bytes("foo", "contents\nof\nfoo\n")
 
1085
        fp = self.transport.get("foo")
 
1086
        self.assertEqual('contents\nof\nfoo\n', fp.read())
 
1087
 
 
1088
    def test_get_error_enoent(self):
 
1089
        """Error reported from server getting nonexistent file."""
 
1090
        # The path in a raised NoSuchFile exception should be the precise path
 
1091
        # asked for by the client. This gives meaningful and unsurprising errors
 
1092
        # for users.
 
1093
        self._captureVar('BZR_NO_SMART_VFS', None)
 
1094
        try:
 
1095
            self.transport.get('not%20a%20file')
 
1096
        except errors.NoSuchFile, e:
 
1097
            self.assertEqual('not%20a%20file', e.path)
 
1098
        else:
 
1099
            self.fail("get did not raise expected error")
 
1100
 
 
1101
    def test_simple_clone_conn(self):
 
1102
        """Test that cloning reuses the same connection."""
 
1103
        # we create a real connection not a loopback one, but it will use the
 
1104
        # same server and pipes
 
1105
        conn2 = self.transport.clone('.')
 
1106
        self.assertIs(self.transport.get_smart_medium(),
 
1107
                      conn2.get_smart_medium())
 
1108
 
 
1109
    def test__remote_path(self):
 
1110
        self.assertEquals('/foo/bar',
 
1111
                          self.transport._remote_path('foo/bar'))
 
1112
 
 
1113
    def test_clone_changes_base(self):
 
1114
        """Cloning transport produces one with a new base location"""
 
1115
        conn2 = self.transport.clone('subdir')
 
1116
        self.assertEquals(self.transport.base + 'subdir/',
 
1117
                          conn2.base)
 
1118
 
 
1119
    def test_open_dir(self):
 
1120
        """Test changing directory"""
 
1121
        self._captureVar('BZR_NO_SMART_VFS', None)
 
1122
        transport = self.transport
 
1123
        self.backing_transport.mkdir('toffee')
 
1124
        self.backing_transport.mkdir('toffee/apple')
 
1125
        self.assertEquals('/toffee', transport._remote_path('toffee'))
 
1126
        toffee_trans = transport.clone('toffee')
 
1127
        # Check that each transport has only the contents of its directory
 
1128
        # directly visible. If state was being held in the wrong object, it's
 
1129
        # conceivable that cloning a transport would alter the state of the
 
1130
        # cloned-from transport.
 
1131
        self.assertTrue(transport.has('toffee'))
 
1132
        self.assertFalse(toffee_trans.has('toffee'))
 
1133
        self.assertFalse(transport.has('apple'))
 
1134
        self.assertTrue(toffee_trans.has('apple'))
 
1135
 
 
1136
    def test_open_bzrdir(self):
 
1137
        """Open an existing bzrdir over smart transport"""
 
1138
        transport = self.transport
 
1139
        t = self.backing_transport
 
1140
        bzrdir.BzrDirFormat.get_default_format().initialize_on_transport(t)
 
1141
        result_dir = bzrdir.BzrDir.open_containing_from_transport(transport)
 
1142
 
 
1143
 
 
1144
class ReadOnlyEndToEndTests(SmartTCPTests):
 
1145
    """Tests from the client to the server using a readonly backing transport."""
 
1146
 
 
1147
    def test_mkdir_error_readonly(self):
 
1148
        """TransportNotPossible should be preserved from the backing transport."""
 
1149
        self._captureVar('BZR_NO_SMART_VFS', None)
 
1150
        self.setUpServer(readonly=True)
 
1151
        self.assertRaises(errors.TransportNotPossible, self.transport.mkdir,
 
1152
            'foo')
 
1153
 
 
1154
 
 
1155
class TestServerHooks(SmartTCPTests):
 
1156
 
 
1157
    def capture_server_call(self, backing_urls, public_url):
 
1158
        """Record a server_started|stopped hook firing."""
 
1159
        self.hook_calls.append((backing_urls, public_url))
 
1160
 
 
1161
    def test_server_started_hook_memory(self):
 
1162
        """The server_started hook fires when the server is started."""
 
1163
        self.hook_calls = []
 
1164
        server.SmartTCPServer.hooks.install_named_hook('server_started',
 
1165
            self.capture_server_call, None)
 
1166
        self.setUpServer()
 
1167
        # at this point, the server will be starting a thread up.
 
1168
        # there is no indicator at the moment, so bodge it by doing a request.
 
1169
        self.transport.has('.')
 
1170
        # The default test server uses MemoryTransport and that has no external
 
1171
        # url:
 
1172
        self.assertEqual([([self.backing_transport.base], self.transport.base)],
 
1173
            self.hook_calls)
 
1174
 
 
1175
    def test_server_started_hook_file(self):
 
1176
        """The server_started hook fires when the server is started."""
 
1177
        self.hook_calls = []
 
1178
        server.SmartTCPServer.hooks.install_named_hook('server_started',
 
1179
            self.capture_server_call, None)
 
1180
        self.setUpServer(backing_transport=get_transport("."))
 
1181
        # at this point, the server will be starting a thread up.
 
1182
        # there is no indicator at the moment, so bodge it by doing a request.
 
1183
        self.transport.has('.')
 
1184
        # The default test server uses MemoryTransport and that has no external
 
1185
        # url:
 
1186
        self.assertEqual([([
 
1187
            self.backing_transport.base, self.backing_transport.external_url()],
 
1188
             self.transport.base)],
 
1189
            self.hook_calls)
 
1190
 
 
1191
    def test_server_stopped_hook_simple_memory(self):
 
1192
        """The server_stopped hook fires when the server is stopped."""
 
1193
        self.hook_calls = []
 
1194
        server.SmartTCPServer.hooks.install_named_hook('server_stopped',
 
1195
            self.capture_server_call, None)
 
1196
        self.setUpServer()
 
1197
        result = [([self.backing_transport.base], self.transport.base)]
 
1198
        # check the stopping message isn't emitted up front.
 
1199
        self.assertEqual([], self.hook_calls)
 
1200
        # nor after a single message
 
1201
        self.transport.has('.')
 
1202
        self.assertEqual([], self.hook_calls)
 
1203
        # clean up the server
 
1204
        self.tearDownServer()
 
1205
        # now it should have fired.
 
1206
        self.assertEqual(result, self.hook_calls)
 
1207
 
 
1208
    def test_server_stopped_hook_simple_file(self):
 
1209
        """The server_stopped hook fires when the server is stopped."""
 
1210
        self.hook_calls = []
 
1211
        server.SmartTCPServer.hooks.install_named_hook('server_stopped',
 
1212
            self.capture_server_call, None)
 
1213
        self.setUpServer(backing_transport=get_transport("."))
 
1214
        result = [(
 
1215
            [self.backing_transport.base, self.backing_transport.external_url()]
 
1216
            , self.transport.base)]
 
1217
        # check the stopping message isn't emitted up front.
 
1218
        self.assertEqual([], self.hook_calls)
 
1219
        # nor after a single message
 
1220
        self.transport.has('.')
 
1221
        self.assertEqual([], self.hook_calls)
 
1222
        # clean up the server
 
1223
        self.tearDownServer()
 
1224
        # now it should have fired.
 
1225
        self.assertEqual(result, self.hook_calls)
 
1226
 
 
1227
# TODO: test that when the server suffers an exception that it calls the
 
1228
# server-stopped hook.
 
1229
 
 
1230
 
 
1231
class SmartServerCommandTests(tests.TestCaseWithTransport):
 
1232
    """Tests that call directly into the command objects, bypassing the network
 
1233
    and the request dispatching.
 
1234
 
 
1235
    Note: these tests are rudimentary versions of the command object tests in
 
1236
    test_smart.py.
 
1237
    """
 
1238
        
 
1239
    def test_hello(self):
 
1240
        cmd = _mod_request.HelloRequest(None, '/')
 
1241
        response = cmd.execute()
 
1242
        self.assertEqual(('ok', '2'), response.args)
 
1243
        self.assertEqual(None, response.body)
 
1244
        
 
1245
    def test_get_bundle(self):
 
1246
        from bzrlib.bundle import serializer
 
1247
        wt = self.make_branch_and_tree('.')
 
1248
        self.build_tree_contents([('hello', 'hello world')])
 
1249
        wt.add('hello')
 
1250
        rev_id = wt.commit('add hello')
 
1251
        
 
1252
        cmd = _mod_request.GetBundleRequest(self.get_transport(), '/')
 
1253
        response = cmd.execute('.', rev_id)
 
1254
        bundle = serializer.read_bundle(StringIO(response.body))
 
1255
        self.assertEqual((), response.args)
 
1256
 
 
1257
 
 
1258
class SmartServerRequestHandlerTests(tests.TestCaseWithTransport):
 
1259
    """Test that call directly into the handler logic, bypassing the network."""
 
1260
 
 
1261
    def setUp(self):
 
1262
        super(SmartServerRequestHandlerTests, self).setUp()
 
1263
        self._captureVar('BZR_NO_SMART_VFS', None)
 
1264
 
 
1265
    def build_handler(self, transport):
 
1266
        """Returns a handler for the commands in protocol version one."""
 
1267
        return _mod_request.SmartServerRequestHandler(
 
1268
            transport, _mod_request.request_handlers, '/')
 
1269
 
 
1270
    def test_construct_request_handler(self):
 
1271
        """Constructing a request handler should be easy and set defaults."""
 
1272
        handler = _mod_request.SmartServerRequestHandler(None, commands=None,
 
1273
                root_client_path='/')
 
1274
        self.assertFalse(handler.finished_reading)
 
1275
 
 
1276
    def test_hello(self):
 
1277
        handler = self.build_handler(None)
 
1278
        handler.dispatch_command('hello', ())
 
1279
        self.assertEqual(('ok', '2'), handler.response.args)
 
1280
        self.assertEqual(None, handler.response.body)
 
1281
        
 
1282
    def test_disable_vfs_handler_classes_via_environment(self):
 
1283
        # VFS handler classes will raise an error from "execute" if
 
1284
        # BZR_NO_SMART_VFS is set.
 
1285
        handler = vfs.HasRequest(None, '/')
 
1286
        # set environment variable after construction to make sure it's
 
1287
        # examined.
 
1288
        # Note that we can safely clobber BZR_NO_SMART_VFS here, because setUp
 
1289
        # has called _captureVar, so it will be restored to the right state
 
1290
        # afterwards.
 
1291
        os.environ['BZR_NO_SMART_VFS'] = ''
 
1292
        self.assertRaises(errors.DisabledMethod, handler.execute)
 
1293
 
 
1294
    def test_readonly_exception_becomes_transport_not_possible(self):
 
1295
        """The response for a read-only error is ('ReadOnlyError')."""
 
1296
        handler = self.build_handler(self.get_readonly_transport())
 
1297
        # send a mkdir for foo, with no explicit mode - should fail.
 
1298
        handler.dispatch_command('mkdir', ('foo', ''))
 
1299
        # and the failure should be an explicit ReadOnlyError
 
1300
        self.assertEqual(("ReadOnlyError", ), handler.response.args)
 
1301
        # XXX: TODO: test that other TransportNotPossible errors are
 
1302
        # presented as TransportNotPossible - not possible to do that
 
1303
        # until I figure out how to trigger that relatively cleanly via
 
1304
        # the api. RBC 20060918
 
1305
 
 
1306
    def test_hello_has_finished_body_on_dispatch(self):
 
1307
        """The 'hello' command should set finished_reading."""
 
1308
        handler = self.build_handler(None)
 
1309
        handler.dispatch_command('hello', ())
 
1310
        self.assertTrue(handler.finished_reading)
 
1311
        self.assertNotEqual(None, handler.response)
 
1312
 
 
1313
    def test_put_bytes_non_atomic(self):
 
1314
        """'put_...' should set finished_reading after reading the bytes."""
 
1315
        handler = self.build_handler(self.get_transport())
 
1316
        handler.dispatch_command('put_non_atomic', ('a-file', '', 'F', ''))
 
1317
        self.assertFalse(handler.finished_reading)
 
1318
        handler.accept_body('1234')
 
1319
        self.assertFalse(handler.finished_reading)
 
1320
        handler.accept_body('5678')
 
1321
        handler.end_of_body()
 
1322
        self.assertTrue(handler.finished_reading)
 
1323
        self.assertEqual(('ok', ), handler.response.args)
 
1324
        self.assertEqual(None, handler.response.body)
 
1325
        
 
1326
    def test_readv_accept_body(self):
 
1327
        """'readv' should set finished_reading after reading offsets."""
 
1328
        self.build_tree(['a-file'])
 
1329
        handler = self.build_handler(self.get_readonly_transport())
 
1330
        handler.dispatch_command('readv', ('a-file', ))
 
1331
        self.assertFalse(handler.finished_reading)
 
1332
        handler.accept_body('2,')
 
1333
        self.assertFalse(handler.finished_reading)
 
1334
        handler.accept_body('3')
 
1335
        handler.end_of_body()
 
1336
        self.assertTrue(handler.finished_reading)
 
1337
        self.assertEqual(('readv', ), handler.response.args)
 
1338
        # co - nte - nt of a-file is the file contents we are extracting from.
 
1339
        self.assertEqual('nte', handler.response.body)
 
1340
 
 
1341
    def test_readv_short_read_response_contents(self):
 
1342
        """'readv' when a short read occurs sets the response appropriately."""
 
1343
        self.build_tree(['a-file'])
 
1344
        handler = self.build_handler(self.get_readonly_transport())
 
1345
        handler.dispatch_command('readv', ('a-file', ))
 
1346
        # read beyond the end of the file.
 
1347
        handler.accept_body('100,1')
 
1348
        handler.end_of_body()
 
1349
        self.assertTrue(handler.finished_reading)
 
1350
        self.assertEqual(('ShortReadvError', './a-file', '100', '1', '0'),
 
1351
            handler.response.args)
 
1352
        self.assertEqual(None, handler.response.body)
 
1353
 
 
1354
 
 
1355
class RemoteTransportRegistration(tests.TestCase):
 
1356
 
 
1357
    def test_registration(self):
 
1358
        t = get_transport('bzr+ssh://example.com/path')
 
1359
        self.assertIsInstance(t, remote.RemoteSSHTransport)
 
1360
        self.assertEqual('example.com', t._host)
 
1361
 
 
1362
    def test_bzr_https(self):
 
1363
        # https://bugs.launchpad.net/bzr/+bug/128456
 
1364
        t = get_transport('bzr+https://example.com/path')
 
1365
        self.assertIsInstance(t, remote.RemoteHTTPTransport)
 
1366
        self.assertStartsWith(
 
1367
            t._http_transport.base,
 
1368
            'https://')
 
1369
 
 
1370
 
 
1371
class TestRemoteTransport(tests.TestCase):
 
1372
        
 
1373
    def test_use_connection_factory(self):
 
1374
        # We want to be able to pass a client as a parameter to RemoteTransport.
 
1375
        input = StringIO('ok\n3\nbardone\n')
 
1376
        output = StringIO()
 
1377
        client_medium = medium.SmartSimplePipesClientMedium(
 
1378
            input, output, 'base')
 
1379
        transport = remote.RemoteTransport(
 
1380
            'bzr://localhost/', medium=client_medium)
 
1381
        # Disable version detection.
 
1382
        client_medium._protocol_version = 1
 
1383
 
 
1384
        # We want to make sure the client is used when the first remote
 
1385
        # method is called.  No data should have been sent, or read.
 
1386
        self.assertEqual(0, input.tell())
 
1387
        self.assertEqual('', output.getvalue())
 
1388
 
 
1389
        # Now call a method that should result in one request: as the
 
1390
        # transport makes its own protocol instances, we check on the wire.
 
1391
        # XXX: TODO: give the transport a protocol factory, which can make
 
1392
        # an instrumented protocol for us.
 
1393
        self.assertEqual('bar', transport.get_bytes('foo'))
 
1394
        # only the needed data should have been sent/received.
 
1395
        self.assertEqual(13, input.tell())
 
1396
        self.assertEqual('get\x01/foo\n', output.getvalue())
 
1397
 
 
1398
    def test__translate_error_readonly(self):
 
1399
        """Sending a ReadOnlyError to _translate_error raises TransportNotPossible."""
 
1400
        client_medium = medium.SmartSimplePipesClientMedium(None, None, 'base')
 
1401
        transport = remote.RemoteTransport(
 
1402
            'bzr://localhost/', medium=client_medium)
 
1403
        self.assertRaises(errors.TransportNotPossible,
 
1404
            transport._translate_error, ("ReadOnlyError", ))
 
1405
 
 
1406
 
 
1407
class TestSmartProtocol(tests.TestCase):
 
1408
    """Base class for smart protocol tests.
 
1409
 
 
1410
    Each test case gets a smart_server and smart_client created during setUp().
 
1411
 
 
1412
    It is planned that the client can be called with self.call_client() giving
 
1413
    it an expected server response, which will be fed into it when it tries to
 
1414
    read. Likewise, self.call_server will call a servers method with a canned
 
1415
    serialised client request. Output done by the client or server for these
 
1416
    calls will be captured to self.to_server and self.to_client. Each element
 
1417
    in the list is a write call from the client or server respectively.
 
1418
 
 
1419
    Subclasses can override client_protocol_class and server_protocol_class.
 
1420
    """
 
1421
 
 
1422
    request_encoder = None
 
1423
    response_decoder = None
 
1424
    server_protocol_class = None
 
1425
    client_protocol_class = None
 
1426
 
 
1427
    def make_client_protocol_and_output(self, input_bytes=None):
 
1428
        """
 
1429
        :returns: a Request
 
1430
        """
 
1431
        # This is very similar to
 
1432
        # bzrlib.smart.client._SmartClient._build_client_protocol
 
1433
        # XXX: make this use _SmartClient!
 
1434
        if input_bytes is None:
 
1435
            input = StringIO()
 
1436
        else:
 
1437
            input = StringIO(input_bytes)
 
1438
        output = StringIO()
 
1439
        client_medium = medium.SmartSimplePipesClientMedium(
 
1440
            input, output, 'base')
 
1441
        request = client_medium.get_request()
 
1442
        if self.client_protocol_class is not None:
 
1443
            client_protocol = self.client_protocol_class(request)
 
1444
            return client_protocol, client_protocol, output
 
1445
        else:
 
1446
            self.assertNotEqual(None, self.request_encoder)
 
1447
            self.assertNotEqual(None, self.response_decoder)
 
1448
            requester = self.request_encoder(request)
 
1449
            response_handler = message.ConventionalResponseHandler()
 
1450
            response_protocol = self.response_decoder(
 
1451
                response_handler, expect_version_marker=True)
 
1452
            response_handler.setProtoAndMediumRequest(
 
1453
                response_protocol, request)
 
1454
            return requester, response_handler, output
 
1455
 
 
1456
    def make_client_protocol(self, input_bytes=None):
 
1457
        result = self.make_client_protocol_and_output(input_bytes=input_bytes)
 
1458
        requester, response_handler, output = result
 
1459
        return requester, response_handler
 
1460
 
 
1461
    def make_server_protocol(self):
 
1462
        out_stream = StringIO()
 
1463
        smart_protocol = self.server_protocol_class(None, out_stream.write)
 
1464
        return smart_protocol, out_stream
 
1465
 
 
1466
    def setUp(self):
 
1467
        super(TestSmartProtocol, self).setUp()
 
1468
        self.response_marker = getattr(
 
1469
            self.client_protocol_class, 'response_marker', None)
 
1470
        self.request_marker = getattr(
 
1471
            self.client_protocol_class, 'request_marker', None)
 
1472
 
 
1473
    def assertOffsetSerialisation(self, expected_offsets, expected_serialised,
 
1474
        requester):
 
1475
        """Check that smart (de)serialises offsets as expected.
 
1476
        
 
1477
        We check both serialisation and deserialisation at the same time
 
1478
        to ensure that the round tripping cannot skew: both directions should
 
1479
        be as expected.
 
1480
        
 
1481
        :param expected_offsets: a readv offset list.
 
1482
        :param expected_seralised: an expected serial form of the offsets.
 
1483
        """
 
1484
        # XXX: '_deserialise_offsets' should be a method of the
 
1485
        # SmartServerRequestProtocol in future.
 
1486
        readv_cmd = vfs.ReadvRequest(None, '/')
 
1487
        offsets = readv_cmd._deserialise_offsets(expected_serialised)
 
1488
        self.assertEqual(expected_offsets, offsets)
 
1489
        serialised = requester._serialise_offsets(offsets)
 
1490
        self.assertEqual(expected_serialised, serialised)
 
1491
 
 
1492
    def build_protocol_waiting_for_body(self):
 
1493
        smart_protocol, out_stream = self.make_server_protocol()
 
1494
        smart_protocol._has_dispatched = True
 
1495
        smart_protocol.request = _mod_request.SmartServerRequestHandler(
 
1496
            None, _mod_request.request_handlers, '/')
 
1497
        class FakeCommand(object):
 
1498
            def do_body(cmd, body_bytes):
 
1499
                self.end_received = True
 
1500
                self.assertEqual('abcdefg', body_bytes)
 
1501
                return _mod_request.SuccessfulSmartServerResponse(('ok', ))
 
1502
        smart_protocol.request._command = FakeCommand()
 
1503
        # Call accept_bytes to make sure that internal state like _body_decoder
 
1504
        # is initialised.  This test should probably be given a clearer
 
1505
        # interface to work with that will not cause this inconsistency.
 
1506
        #   -- Andrew Bennetts, 2006-09-28
 
1507
        smart_protocol.accept_bytes('')
 
1508
        return smart_protocol
 
1509
 
 
1510
    def assertServerToClientEncoding(self, expected_bytes, expected_tuple,
 
1511
            input_tuples):
 
1512
        """Assert that each input_tuple serialises as expected_bytes, and the
 
1513
        bytes deserialise as expected_tuple.
 
1514
        """
 
1515
        # check the encoding of the server for all input_tuples matches
 
1516
        # expected bytes
 
1517
        for input_tuple in input_tuples:
 
1518
            server_protocol, server_output = self.make_server_protocol()
 
1519
            server_protocol._send_response(
 
1520
                _mod_request.SuccessfulSmartServerResponse(input_tuple))
 
1521
            self.assertEqual(expected_bytes, server_output.getvalue())
 
1522
        # check the decoding of the client smart_protocol from expected_bytes:
 
1523
        requester, response_handler = self.make_client_protocol(expected_bytes)
 
1524
        requester.call('foo')
 
1525
        self.assertEqual(expected_tuple, response_handler.read_response_tuple())
 
1526
 
 
1527
 
 
1528
class CommonSmartProtocolTestMixin(object):
 
1529
 
 
1530
    def test_connection_closed_reporting(self):
 
1531
        requester, response_handler = self.make_client_protocol()
 
1532
        requester.call('hello')
 
1533
        ex = self.assertRaises(errors.ConnectionReset,
 
1534
            response_handler.read_response_tuple)
 
1535
        self.assertEqual("Connection closed: "
 
1536
            "please check connectivity and permissions "
 
1537
            "(and try -Dhpss if further diagnosis is required)", str(ex))
 
1538
 
 
1539
    def test_server_offset_serialisation(self):
 
1540
        """The Smart protocol serialises offsets as a comma and \n string.
 
1541
 
 
1542
        We check a number of boundary cases are as expected: empty, one offset,
 
1543
        one with the order of reads not increasing (an out of order read), and
 
1544
        one that should coalesce.
 
1545
        """
 
1546
        requester, response_handler = self.make_client_protocol()
 
1547
        self.assertOffsetSerialisation([], '', requester)
 
1548
        self.assertOffsetSerialisation([(1,2)], '1,2', requester)
 
1549
        self.assertOffsetSerialisation([(10,40), (0,5)], '10,40\n0,5',
 
1550
            requester)
 
1551
        self.assertOffsetSerialisation([(1,2), (3,4), (100, 200)],
 
1552
            '1,2\n3,4\n100,200', requester)
 
1553
 
 
1554
 
 
1555
class TestVersionOneFeaturesInProtocolOne(
 
1556
    TestSmartProtocol, CommonSmartProtocolTestMixin):
 
1557
    """Tests for version one smart protocol features as implemeted by version
 
1558
    one."""
 
1559
 
 
1560
    client_protocol_class = protocol.SmartClientRequestProtocolOne
 
1561
    server_protocol_class = protocol.SmartServerRequestProtocolOne
 
1562
 
 
1563
    def test_construct_version_one_server_protocol(self):
 
1564
        smart_protocol = protocol.SmartServerRequestProtocolOne(None, None)
 
1565
        self.assertEqual('', smart_protocol.unused_data)
 
1566
        self.assertEqual('', smart_protocol.in_buffer)
 
1567
        self.assertFalse(smart_protocol._has_dispatched)
 
1568
        self.assertEqual(1, smart_protocol.next_read_size())
 
1569
 
 
1570
    def test_construct_version_one_client_protocol(self):
 
1571
        # we can construct a client protocol from a client medium request
 
1572
        output = StringIO()
 
1573
        client_medium = medium.SmartSimplePipesClientMedium(
 
1574
            None, output, 'base')
 
1575
        request = client_medium.get_request()
 
1576
        client_protocol = protocol.SmartClientRequestProtocolOne(request)
 
1577
 
 
1578
    def test_accept_bytes_of_bad_request_to_protocol(self):
 
1579
        out_stream = StringIO()
 
1580
        smart_protocol = protocol.SmartServerRequestProtocolOne(
 
1581
            None, out_stream.write)
 
1582
        smart_protocol.accept_bytes('abc')
 
1583
        self.assertEqual('abc', smart_protocol.in_buffer)
 
1584
        smart_protocol.accept_bytes('\n')
 
1585
        self.assertEqual(
 
1586
            "error\x01Generic bzr smart protocol error: bad request 'abc'\n",
 
1587
            out_stream.getvalue())
 
1588
        self.assertTrue(smart_protocol._has_dispatched)
 
1589
        self.assertEqual(0, smart_protocol.next_read_size())
 
1590
 
 
1591
    def test_accept_body_bytes_to_protocol(self):
 
1592
        protocol = self.build_protocol_waiting_for_body()
 
1593
        self.assertEqual(6, protocol.next_read_size())
 
1594
        protocol.accept_bytes('7\nabc')
 
1595
        self.assertEqual(9, protocol.next_read_size())
 
1596
        protocol.accept_bytes('defgd')
 
1597
        protocol.accept_bytes('one\n')
 
1598
        self.assertEqual(0, protocol.next_read_size())
 
1599
        self.assertTrue(self.end_received)
 
1600
 
 
1601
    def test_accept_request_and_body_all_at_once(self):
 
1602
        self._captureVar('BZR_NO_SMART_VFS', None)
 
1603
        mem_transport = memory.MemoryTransport()
 
1604
        mem_transport.put_bytes('foo', 'abcdefghij')
 
1605
        out_stream = StringIO()
 
1606
        smart_protocol = protocol.SmartServerRequestProtocolOne(mem_transport,
 
1607
                out_stream.write)
 
1608
        smart_protocol.accept_bytes('readv\x01foo\n3\n3,3done\n')
 
1609
        self.assertEqual(0, smart_protocol.next_read_size())
 
1610
        self.assertEqual('readv\n3\ndefdone\n', out_stream.getvalue())
 
1611
        self.assertEqual('', smart_protocol.unused_data)
 
1612
        self.assertEqual('', smart_protocol.in_buffer)
 
1613
 
 
1614
    def test_accept_excess_bytes_are_preserved(self):
 
1615
        out_stream = StringIO()
 
1616
        smart_protocol = protocol.SmartServerRequestProtocolOne(
 
1617
            None, out_stream.write)
 
1618
        smart_protocol.accept_bytes('hello\nhello\n')
 
1619
        self.assertEqual("ok\x012\n", out_stream.getvalue())
 
1620
        self.assertEqual("hello\n", smart_protocol.unused_data)
 
1621
        self.assertEqual("", smart_protocol.in_buffer)
 
1622
 
 
1623
    def test_accept_excess_bytes_after_body(self):
 
1624
        protocol = self.build_protocol_waiting_for_body()
 
1625
        protocol.accept_bytes('7\nabcdefgdone\nX')
 
1626
        self.assertTrue(self.end_received)
 
1627
        self.assertEqual("X", protocol.unused_data)
 
1628
        self.assertEqual("", protocol.in_buffer)
 
1629
        protocol.accept_bytes('Y')
 
1630
        self.assertEqual("XY", protocol.unused_data)
 
1631
        self.assertEqual("", protocol.in_buffer)
 
1632
 
 
1633
    def test_accept_excess_bytes_after_dispatch(self):
 
1634
        out_stream = StringIO()
 
1635
        smart_protocol = protocol.SmartServerRequestProtocolOne(
 
1636
            None, out_stream.write)
 
1637
        smart_protocol.accept_bytes('hello\n')
 
1638
        self.assertEqual("ok\x012\n", out_stream.getvalue())
 
1639
        smart_protocol.accept_bytes('hel')
 
1640
        self.assertEqual("hel", smart_protocol.unused_data)
 
1641
        smart_protocol.accept_bytes('lo\n')
 
1642
        self.assertEqual("hello\n", smart_protocol.unused_data)
 
1643
        self.assertEqual("", smart_protocol.in_buffer)
 
1644
 
 
1645
    def test__send_response_sets_finished_reading(self):
 
1646
        smart_protocol = protocol.SmartServerRequestProtocolOne(
 
1647
            None, lambda x: None)
 
1648
        self.assertEqual(1, smart_protocol.next_read_size())
 
1649
        smart_protocol._send_response(
 
1650
            _mod_request.SuccessfulSmartServerResponse(('x',)))
 
1651
        self.assertEqual(0, smart_protocol.next_read_size())
 
1652
 
 
1653
    def test__send_response_errors_with_base_response(self):
 
1654
        """Ensure that only the Successful/Failed subclasses are used."""
 
1655
        smart_protocol = protocol.SmartServerRequestProtocolOne(
 
1656
            None, lambda x: None)
 
1657
        self.assertRaises(AttributeError, smart_protocol._send_response,
 
1658
            _mod_request.SmartServerResponse(('x',)))
 
1659
 
 
1660
    def test_query_version(self):
 
1661
        """query_version on a SmartClientProtocolOne should return a number.
 
1662
        
 
1663
        The protocol provides the query_version because the domain level clients
 
1664
        may all need to be able to probe for capabilities.
 
1665
        """
 
1666
        # What we really want to test here is that SmartClientProtocolOne calls
 
1667
        # accept_bytes(tuple_based_encoding_of_hello) and reads and parses the
 
1668
        # response of tuple-encoded (ok, 1).  Also, seperately we should test
 
1669
        # the error if the response is a non-understood version.
 
1670
        input = StringIO('ok\x012\n')
 
1671
        output = StringIO()
 
1672
        client_medium = medium.SmartSimplePipesClientMedium(
 
1673
            input, output, 'base')
 
1674
        request = client_medium.get_request()
 
1675
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
 
1676
        self.assertEqual(2, smart_protocol.query_version())
 
1677
 
 
1678
    def test_client_call_empty_response(self):
 
1679
        # protocol.call() can get back an empty tuple as a response. This occurs
 
1680
        # when the parsed line is an empty line, and results in a tuple with
 
1681
        # one element - an empty string.
 
1682
        self.assertServerToClientEncoding('\n', ('', ), [(), ('', )])
 
1683
 
 
1684
    def test_client_call_three_element_response(self):
 
1685
        # protocol.call() can get back tuples of other lengths. A three element
 
1686
        # tuple should be unpacked as three strings.
 
1687
        self.assertServerToClientEncoding('a\x01b\x0134\n', ('a', 'b', '34'),
 
1688
            [('a', 'b', '34')])
 
1689
 
 
1690
    def test_client_call_with_body_bytes_uploads(self):
 
1691
        # protocol.call_with_body_bytes should length-prefix the bytes onto the
 
1692
        # wire.
 
1693
        expected_bytes = "foo\n7\nabcdefgdone\n"
 
1694
        input = StringIO("\n")
 
1695
        output = StringIO()
 
1696
        client_medium = medium.SmartSimplePipesClientMedium(
 
1697
            input, output, 'base')
 
1698
        request = client_medium.get_request()
 
1699
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
 
1700
        smart_protocol.call_with_body_bytes(('foo', ), "abcdefg")
 
1701
        self.assertEqual(expected_bytes, output.getvalue())
 
1702
 
 
1703
    def test_client_call_with_body_readv_array(self):
 
1704
        # protocol.call_with_upload should encode the readv array and then
 
1705
        # length-prefix the bytes onto the wire.
 
1706
        expected_bytes = "foo\n7\n1,2\n5,6done\n"
 
1707
        input = StringIO("\n")
 
1708
        output = StringIO()
 
1709
        client_medium = medium.SmartSimplePipesClientMedium(
 
1710
            input, output, 'base')
 
1711
        request = client_medium.get_request()
 
1712
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
 
1713
        smart_protocol.call_with_body_readv_array(('foo', ), [(1,2),(5,6)])
 
1714
        self.assertEqual(expected_bytes, output.getvalue())
 
1715
 
 
1716
    def _test_client_read_response_tuple_raises_UnknownSmartMethod(self,
 
1717
            server_bytes):
 
1718
        input = StringIO(server_bytes)
 
1719
        output = StringIO()
 
1720
        client_medium = medium.SmartSimplePipesClientMedium(
 
1721
            input, output, 'base')
 
1722
        request = client_medium.get_request()
 
1723
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
 
1724
        smart_protocol.call('foo')
 
1725
        self.assertRaises(
 
1726
            errors.UnknownSmartMethod, smart_protocol.read_response_tuple)
 
1727
        # The request has been finished.  There is no body to read, and
 
1728
        # attempts to read one will fail.
 
1729
        self.assertRaises(
 
1730
            errors.ReadingCompleted, smart_protocol.read_body_bytes)
 
1731
 
 
1732
    def test_client_read_response_tuple_raises_UnknownSmartMethod(self):
 
1733
        """read_response_tuple raises UnknownSmartMethod if the response says
 
1734
        the server did not recognise the request.
 
1735
        """
 
1736
        server_bytes = (
 
1737
            "error\x01Generic bzr smart protocol error: bad request 'foo'\n")
 
1738
        self._test_client_read_response_tuple_raises_UnknownSmartMethod(
 
1739
            server_bytes)
 
1740
 
 
1741
    def test_client_read_response_tuple_raises_UnknownSmartMethod_0_11(self):
 
1742
        """read_response_tuple also raises UnknownSmartMethod if the response
 
1743
        from a bzr 0.11 says the server did not recognise the request.
 
1744
 
 
1745
        (bzr 0.11 sends a slightly different error message to later versions.)
 
1746
        """
 
1747
        server_bytes = (
 
1748
            "error\x01Generic bzr smart protocol error: bad request u'foo'\n")
 
1749
        self._test_client_read_response_tuple_raises_UnknownSmartMethod(
 
1750
            server_bytes)
 
1751
 
 
1752
    def test_client_read_body_bytes_all(self):
 
1753
        # read_body_bytes should decode the body bytes from the wire into
 
1754
        # a response.
 
1755
        expected_bytes = "1234567"
 
1756
        server_bytes = "ok\n7\n1234567done\n"
 
1757
        input = StringIO(server_bytes)
 
1758
        output = StringIO()
 
1759
        client_medium = medium.SmartSimplePipesClientMedium(
 
1760
            input, output, 'base')
 
1761
        request = client_medium.get_request()
 
1762
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
 
1763
        smart_protocol.call('foo')
 
1764
        smart_protocol.read_response_tuple(True)
 
1765
        self.assertEqual(expected_bytes, smart_protocol.read_body_bytes())
 
1766
 
 
1767
    def test_client_read_body_bytes_incremental(self):
 
1768
        # test reading a few bytes at a time from the body
 
1769
        # XXX: possibly we should test dribbling the bytes into the stringio
 
1770
        # to make the state machine work harder: however, as we use the
 
1771
        # LengthPrefixedBodyDecoder that is already well tested - we can skip
 
1772
        # that.
 
1773
        expected_bytes = "1234567"
 
1774
        server_bytes = "ok\n7\n1234567done\n"
 
1775
        input = StringIO(server_bytes)
 
1776
        output = StringIO()
 
1777
        client_medium = medium.SmartSimplePipesClientMedium(
 
1778
            input, output, 'base')
 
1779
        request = client_medium.get_request()
 
1780
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
 
1781
        smart_protocol.call('foo')
 
1782
        smart_protocol.read_response_tuple(True)
 
1783
        self.assertEqual(expected_bytes[0:2], smart_protocol.read_body_bytes(2))
 
1784
        self.assertEqual(expected_bytes[2:4], smart_protocol.read_body_bytes(2))
 
1785
        self.assertEqual(expected_bytes[4:6], smart_protocol.read_body_bytes(2))
 
1786
        self.assertEqual(expected_bytes[6], smart_protocol.read_body_bytes())
 
1787
 
 
1788
    def test_client_cancel_read_body_does_not_eat_body_bytes(self):
 
1789
        # cancelling the expected body needs to finish the request, but not
 
1790
        # read any more bytes.
 
1791
        expected_bytes = "1234567"
 
1792
        server_bytes = "ok\n7\n1234567done\n"
 
1793
        input = StringIO(server_bytes)
 
1794
        output = StringIO()
 
1795
        client_medium = medium.SmartSimplePipesClientMedium(
 
1796
            input, output, 'base')
 
1797
        request = client_medium.get_request()
 
1798
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
 
1799
        smart_protocol.call('foo')
 
1800
        smart_protocol.read_response_tuple(True)
 
1801
        smart_protocol.cancel_read_body()
 
1802
        self.assertEqual(3, input.tell())
 
1803
        self.assertRaises(
 
1804
            errors.ReadingCompleted, smart_protocol.read_body_bytes)
 
1805
 
 
1806
    def test_client_read_body_bytes_interrupted_connection(self):
 
1807
        server_bytes = "ok\n999\nincomplete body"
 
1808
        input = StringIO(server_bytes)
 
1809
        output = StringIO()
 
1810
        client_medium = medium.SmartSimplePipesClientMedium(
 
1811
            input, output, 'base')
 
1812
        request = client_medium.get_request()
 
1813
        smart_protocol = self.client_protocol_class(request)
 
1814
        smart_protocol.call('foo')
 
1815
        smart_protocol.read_response_tuple(True)
 
1816
        self.assertRaises(
 
1817
            errors.ConnectionReset, smart_protocol.read_body_bytes)
 
1818
 
 
1819
 
 
1820
class TestVersionOneFeaturesInProtocolTwo(
 
1821
    TestSmartProtocol, CommonSmartProtocolTestMixin):
 
1822
    """Tests for version one smart protocol features as implemeted by version
 
1823
    two.
 
1824
    """
 
1825
 
 
1826
    client_protocol_class = protocol.SmartClientRequestProtocolTwo
 
1827
    server_protocol_class = protocol.SmartServerRequestProtocolTwo
 
1828
 
 
1829
    def test_construct_version_two_server_protocol(self):
 
1830
        smart_protocol = protocol.SmartServerRequestProtocolTwo(None, None)
 
1831
        self.assertEqual('', smart_protocol.unused_data)
 
1832
        self.assertEqual('', smart_protocol.in_buffer)
 
1833
        self.assertFalse(smart_protocol._has_dispatched)
 
1834
        self.assertEqual(1, smart_protocol.next_read_size())
 
1835
 
 
1836
    def test_construct_version_two_client_protocol(self):
 
1837
        # we can construct a client protocol from a client medium request
 
1838
        output = StringIO()
 
1839
        client_medium = medium.SmartSimplePipesClientMedium(
 
1840
            None, output, 'base')
 
1841
        request = client_medium.get_request()
 
1842
        client_protocol = protocol.SmartClientRequestProtocolTwo(request)
 
1843
 
 
1844
    def test_accept_bytes_of_bad_request_to_protocol(self):
 
1845
        out_stream = StringIO()
 
1846
        smart_protocol = self.server_protocol_class(None, out_stream.write)
 
1847
        smart_protocol.accept_bytes('abc')
 
1848
        self.assertEqual('abc', smart_protocol.in_buffer)
 
1849
        smart_protocol.accept_bytes('\n')
 
1850
        self.assertEqual(
 
1851
            self.response_marker +
 
1852
            "failed\nerror\x01Generic bzr smart protocol error: bad request 'abc'\n",
 
1853
            out_stream.getvalue())
 
1854
        self.assertTrue(smart_protocol._has_dispatched)
 
1855
        self.assertEqual(0, smart_protocol.next_read_size())
 
1856
 
 
1857
    def test_accept_body_bytes_to_protocol(self):
 
1858
        protocol = self.build_protocol_waiting_for_body()
 
1859
        self.assertEqual(6, protocol.next_read_size())
 
1860
        protocol.accept_bytes('7\nabc')
 
1861
        self.assertEqual(9, protocol.next_read_size())
 
1862
        protocol.accept_bytes('defgd')
 
1863
        protocol.accept_bytes('one\n')
 
1864
        self.assertEqual(0, protocol.next_read_size())
 
1865
        self.assertTrue(self.end_received)
 
1866
 
 
1867
    def test_accept_request_and_body_all_at_once(self):
 
1868
        self._captureVar('BZR_NO_SMART_VFS', None)
 
1869
        mem_transport = memory.MemoryTransport()
 
1870
        mem_transport.put_bytes('foo', 'abcdefghij')
 
1871
        out_stream = StringIO()
 
1872
        smart_protocol = self.server_protocol_class(
 
1873
            mem_transport, out_stream.write)
 
1874
        smart_protocol.accept_bytes('readv\x01foo\n3\n3,3done\n')
 
1875
        self.assertEqual(0, smart_protocol.next_read_size())
 
1876
        self.assertEqual(self.response_marker +
 
1877
                         'success\nreadv\n3\ndefdone\n',
 
1878
                         out_stream.getvalue())
 
1879
        self.assertEqual('', smart_protocol.unused_data)
 
1880
        self.assertEqual('', smart_protocol.in_buffer)
 
1881
 
 
1882
    def test_accept_excess_bytes_are_preserved(self):
 
1883
        out_stream = StringIO()
 
1884
        smart_protocol = self.server_protocol_class(None, out_stream.write)
 
1885
        smart_protocol.accept_bytes('hello\nhello\n')
 
1886
        self.assertEqual(self.response_marker + "success\nok\x012\n",
 
1887
                         out_stream.getvalue())
 
1888
        self.assertEqual("hello\n", smart_protocol.unused_data)
 
1889
        self.assertEqual("", smart_protocol.in_buffer)
 
1890
 
 
1891
    def test_accept_excess_bytes_after_body(self):
 
1892
        # The excess bytes look like the start of another request.
 
1893
        server_protocol = self.build_protocol_waiting_for_body()
 
1894
        server_protocol.accept_bytes('7\nabcdefgdone\n' + self.response_marker)
 
1895
        self.assertTrue(self.end_received)
 
1896
        self.assertEqual(self.response_marker,
 
1897
                         server_protocol.unused_data)
 
1898
        self.assertEqual("", server_protocol.in_buffer)
 
1899
        server_protocol.accept_bytes('Y')
 
1900
        self.assertEqual(self.response_marker + "Y",
 
1901
                         server_protocol.unused_data)
 
1902
        self.assertEqual("", server_protocol.in_buffer)
 
1903
 
 
1904
    def test_accept_excess_bytes_after_dispatch(self):
 
1905
        out_stream = StringIO()
 
1906
        smart_protocol = self.server_protocol_class(None, out_stream.write)
 
1907
        smart_protocol.accept_bytes('hello\n')
 
1908
        self.assertEqual(self.response_marker + "success\nok\x012\n",
 
1909
                         out_stream.getvalue())
 
1910
        smart_protocol.accept_bytes(self.request_marker + 'hel')
 
1911
        self.assertEqual(self.request_marker + "hel",
 
1912
                         smart_protocol.unused_data)
 
1913
        smart_protocol.accept_bytes('lo\n')
 
1914
        self.assertEqual(self.request_marker + "hello\n",
 
1915
                         smart_protocol.unused_data)
 
1916
        self.assertEqual("", smart_protocol.in_buffer)
 
1917
 
 
1918
    def test__send_response_sets_finished_reading(self):
 
1919
        smart_protocol = self.server_protocol_class(None, lambda x: None)
 
1920
        self.assertEqual(1, smart_protocol.next_read_size())
 
1921
        smart_protocol._send_response(
 
1922
            _mod_request.SuccessfulSmartServerResponse(('x',)))
 
1923
        self.assertEqual(0, smart_protocol.next_read_size())
 
1924
 
 
1925
    def test__send_response_errors_with_base_response(self):
 
1926
        """Ensure that only the Successful/Failed subclasses are used."""
 
1927
        smart_protocol = self.server_protocol_class(None, lambda x: None)
 
1928
        self.assertRaises(AttributeError, smart_protocol._send_response,
 
1929
            _mod_request.SmartServerResponse(('x',)))
 
1930
 
 
1931
    def test_query_version(self):
 
1932
        """query_version on a SmartClientProtocolTwo should return a number.
 
1933
        
 
1934
        The protocol provides the query_version because the domain level clients
 
1935
        may all need to be able to probe for capabilities.
 
1936
        """
 
1937
        # What we really want to test here is that SmartClientProtocolTwo calls
 
1938
        # accept_bytes(tuple_based_encoding_of_hello) and reads and parses the
 
1939
        # response of tuple-encoded (ok, 1).  Also, seperately we should test
 
1940
        # the error if the response is a non-understood version.
 
1941
        input = StringIO(self.response_marker + 'success\nok\x012\n')
 
1942
        output = StringIO()
 
1943
        client_medium = medium.SmartSimplePipesClientMedium(
 
1944
            input, output, 'base')
 
1945
        request = client_medium.get_request()
 
1946
        smart_protocol = self.client_protocol_class(request)
 
1947
        self.assertEqual(2, smart_protocol.query_version())
 
1948
 
 
1949
    def test_client_call_empty_response(self):
 
1950
        # protocol.call() can get back an empty tuple as a response. This occurs
 
1951
        # when the parsed line is an empty line, and results in a tuple with
 
1952
        # one element - an empty string.
 
1953
        self.assertServerToClientEncoding(
 
1954
            self.response_marker + 'success\n\n', ('', ), [(), ('', )])
 
1955
 
 
1956
    def test_client_call_three_element_response(self):
 
1957
        # protocol.call() can get back tuples of other lengths. A three element
 
1958
        # tuple should be unpacked as three strings.
 
1959
        self.assertServerToClientEncoding(
 
1960
            self.response_marker + 'success\na\x01b\x0134\n',
 
1961
            ('a', 'b', '34'),
 
1962
            [('a', 'b', '34')])
 
1963
 
 
1964
    def test_client_call_with_body_bytes_uploads(self):
 
1965
        # protocol.call_with_body_bytes should length-prefix the bytes onto the
 
1966
        # wire.
 
1967
        expected_bytes = self.request_marker + "foo\n7\nabcdefgdone\n"
 
1968
        input = StringIO("\n")
 
1969
        output = StringIO()
 
1970
        client_medium = medium.SmartSimplePipesClientMedium(
 
1971
            input, output, 'base')
 
1972
        request = client_medium.get_request()
 
1973
        smart_protocol = self.client_protocol_class(request)
 
1974
        smart_protocol.call_with_body_bytes(('foo', ), "abcdefg")
 
1975
        self.assertEqual(expected_bytes, output.getvalue())
 
1976
 
 
1977
    def test_client_call_with_body_readv_array(self):
 
1978
        # protocol.call_with_upload should encode the readv array and then
 
1979
        # length-prefix the bytes onto the wire.
 
1980
        expected_bytes = self.request_marker + "foo\n7\n1,2\n5,6done\n"
 
1981
        input = StringIO("\n")
 
1982
        output = StringIO()
 
1983
        client_medium = medium.SmartSimplePipesClientMedium(
 
1984
            input, output, 'base')
 
1985
        request = client_medium.get_request()
 
1986
        smart_protocol = self.client_protocol_class(request)
 
1987
        smart_protocol.call_with_body_readv_array(('foo', ), [(1,2),(5,6)])
 
1988
        self.assertEqual(expected_bytes, output.getvalue())
 
1989
 
 
1990
    def test_client_read_body_bytes_all(self):
 
1991
        # read_body_bytes should decode the body bytes from the wire into
 
1992
        # a response.
 
1993
        expected_bytes = "1234567"
 
1994
        server_bytes = (self.response_marker +
 
1995
                        "success\nok\n7\n1234567done\n")
 
1996
        input = StringIO(server_bytes)
 
1997
        output = StringIO()
 
1998
        client_medium = medium.SmartSimplePipesClientMedium(
 
1999
            input, output, 'base')
 
2000
        request = client_medium.get_request()
 
2001
        smart_protocol = self.client_protocol_class(request)
 
2002
        smart_protocol.call('foo')
 
2003
        smart_protocol.read_response_tuple(True)
 
2004
        self.assertEqual(expected_bytes, smart_protocol.read_body_bytes())
 
2005
 
 
2006
    def test_client_read_body_bytes_incremental(self):
 
2007
        # test reading a few bytes at a time from the body
 
2008
        # XXX: possibly we should test dribbling the bytes into the stringio
 
2009
        # to make the state machine work harder: however, as we use the
 
2010
        # LengthPrefixedBodyDecoder that is already well tested - we can skip
 
2011
        # that.
 
2012
        expected_bytes = "1234567"
 
2013
        server_bytes = self.response_marker + "success\nok\n7\n1234567done\n"
 
2014
        input = StringIO(server_bytes)
 
2015
        output = StringIO()
 
2016
        client_medium = medium.SmartSimplePipesClientMedium(
 
2017
            input, output, 'base')
 
2018
        request = client_medium.get_request()
 
2019
        smart_protocol = self.client_protocol_class(request)
 
2020
        smart_protocol.call('foo')
 
2021
        smart_protocol.read_response_tuple(True)
 
2022
        self.assertEqual(expected_bytes[0:2], smart_protocol.read_body_bytes(2))
 
2023
        self.assertEqual(expected_bytes[2:4], smart_protocol.read_body_bytes(2))
 
2024
        self.assertEqual(expected_bytes[4:6], smart_protocol.read_body_bytes(2))
 
2025
        self.assertEqual(expected_bytes[6], smart_protocol.read_body_bytes())
 
2026
 
 
2027
    def test_client_cancel_read_body_does_not_eat_body_bytes(self):
 
2028
        # cancelling the expected body needs to finish the request, but not
 
2029
        # read any more bytes.
 
2030
        server_bytes = self.response_marker + "success\nok\n7\n1234567done\n"
 
2031
        input = StringIO(server_bytes)
 
2032
        output = StringIO()
 
2033
        client_medium = medium.SmartSimplePipesClientMedium(
 
2034
            input, output, 'base')
 
2035
        request = client_medium.get_request()
 
2036
        smart_protocol = self.client_protocol_class(request)
 
2037
        smart_protocol.call('foo')
 
2038
        smart_protocol.read_response_tuple(True)
 
2039
        smart_protocol.cancel_read_body()
 
2040
        self.assertEqual(len(self.response_marker + 'success\nok\n'),
 
2041
                         input.tell())
 
2042
        self.assertRaises(
 
2043
            errors.ReadingCompleted, smart_protocol.read_body_bytes)
 
2044
 
 
2045
    def test_client_read_body_bytes_interrupted_connection(self):
 
2046
        server_bytes = (self.response_marker +
 
2047
                        "success\nok\n999\nincomplete body")
 
2048
        input = StringIO(server_bytes)
 
2049
        output = StringIO()
 
2050
        client_medium = medium.SmartSimplePipesClientMedium(
 
2051
            input, output, 'base')
 
2052
        request = client_medium.get_request()
 
2053
        smart_protocol = self.client_protocol_class(request)
 
2054
        smart_protocol.call('foo')
 
2055
        smart_protocol.read_response_tuple(True)
 
2056
        self.assertRaises(
 
2057
            errors.ConnectionReset, smart_protocol.read_body_bytes)
 
2058
 
 
2059
 
 
2060
class TestSmartProtocolTwoSpecificsMixin(object):
 
2061
 
 
2062
    def assertBodyStreamSerialisation(self, expected_serialisation,
 
2063
                                      body_stream):
 
2064
        """Assert that body_stream is serialised as expected_serialisation."""
 
2065
        out_stream = StringIO()
 
2066
        protocol._send_stream(body_stream, out_stream.write)
 
2067
        self.assertEqual(expected_serialisation, out_stream.getvalue())
 
2068
 
 
2069
    def assertBodyStreamRoundTrips(self, body_stream):
 
2070
        """Assert that body_stream is the same after being serialised and
 
2071
        deserialised.
 
2072
        """
 
2073
        out_stream = StringIO()
 
2074
        protocol._send_stream(body_stream, out_stream.write)
 
2075
        decoder = protocol.ChunkedBodyDecoder()
 
2076
        decoder.accept_bytes(out_stream.getvalue())
 
2077
        decoded_stream = list(iter(decoder.read_next_chunk, None))
 
2078
        self.assertEqual(body_stream, decoded_stream)
 
2079
 
 
2080
    def test_body_stream_serialisation_empty(self):
 
2081
        """A body_stream with no bytes can be serialised."""
 
2082
        self.assertBodyStreamSerialisation('chunked\nEND\n', [])
 
2083
        self.assertBodyStreamRoundTrips([])
 
2084
 
 
2085
    def test_body_stream_serialisation(self):
 
2086
        stream = ['chunk one', 'chunk two', 'chunk three']
 
2087
        self.assertBodyStreamSerialisation(
 
2088
            'chunked\n' + '9\nchunk one' + '9\nchunk two' + 'b\nchunk three' +
 
2089
            'END\n',
 
2090
            stream)
 
2091
        self.assertBodyStreamRoundTrips(stream)
 
2092
 
 
2093
    def test_body_stream_with_empty_element_serialisation(self):
 
2094
        """A body stream can include ''.
 
2095
 
 
2096
        The empty string can be transmitted like any other string.
 
2097
        """
 
2098
        stream = ['', 'chunk']
 
2099
        self.assertBodyStreamSerialisation(
 
2100
            'chunked\n' + '0\n' + '5\nchunk' + 'END\n', stream)
 
2101
        self.assertBodyStreamRoundTrips(stream)
 
2102
 
 
2103
    def test_body_stream_error_serialistion(self):
 
2104
        stream = ['first chunk',
 
2105
                  _mod_request.FailedSmartServerResponse(
 
2106
                      ('FailureName', 'failure arg'))]
 
2107
        expected_bytes = (
 
2108
            'chunked\n' + 'b\nfirst chunk' +
 
2109
            'ERR\n' + 'b\nFailureName' + 'b\nfailure arg' +
 
2110
            'END\n')
 
2111
        self.assertBodyStreamSerialisation(expected_bytes, stream)
 
2112
        self.assertBodyStreamRoundTrips(stream)
 
2113
 
 
2114
    def test__send_response_includes_failure_marker(self):
 
2115
        """FailedSmartServerResponse have 'failed\n' after the version."""
 
2116
        out_stream = StringIO()
 
2117
        smart_protocol = protocol.SmartServerRequestProtocolTwo(
 
2118
            None, out_stream.write)
 
2119
        smart_protocol._send_response(
 
2120
            _mod_request.FailedSmartServerResponse(('x',)))
 
2121
        self.assertEqual(protocol.RESPONSE_VERSION_TWO + 'failed\nx\n',
 
2122
                         out_stream.getvalue())
 
2123
 
 
2124
    def test__send_response_includes_success_marker(self):
 
2125
        """SuccessfulSmartServerResponse have 'success\n' after the version."""
 
2126
        out_stream = StringIO()
 
2127
        smart_protocol = protocol.SmartServerRequestProtocolTwo(
 
2128
            None, out_stream.write)
 
2129
        smart_protocol._send_response(
 
2130
            _mod_request.SuccessfulSmartServerResponse(('x',)))
 
2131
        self.assertEqual(protocol.RESPONSE_VERSION_TWO + 'success\nx\n',
 
2132
                         out_stream.getvalue())
 
2133
 
 
2134
    def test__send_response_with_body_stream_sets_finished_reading(self):
 
2135
        smart_protocol = protocol.SmartServerRequestProtocolTwo(
 
2136
            None, lambda x: None)
 
2137
        self.assertEqual(1, smart_protocol.next_read_size())
 
2138
        smart_protocol._send_response(
 
2139
            _mod_request.SuccessfulSmartServerResponse(('x',), body_stream=[]))
 
2140
        self.assertEqual(0, smart_protocol.next_read_size())
 
2141
 
 
2142
    def test_streamed_body_bytes(self):
 
2143
        body_header = 'chunked\n'
 
2144
        two_body_chunks = "4\n1234" + "3\n567"
 
2145
        body_terminator = "END\n"
 
2146
        server_bytes = (protocol.RESPONSE_VERSION_TWO +
 
2147
                        "success\nok\n" + body_header + two_body_chunks +
 
2148
                        body_terminator)
 
2149
        input = StringIO(server_bytes)
 
2150
        output = StringIO()
 
2151
        client_medium = medium.SmartSimplePipesClientMedium(
 
2152
            input, output, 'base')
 
2153
        request = client_medium.get_request()
 
2154
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
 
2155
        smart_protocol.call('foo')
 
2156
        smart_protocol.read_response_tuple(True)
 
2157
        stream = smart_protocol.read_streamed_body()
 
2158
        self.assertEqual(['1234', '567'], list(stream))
 
2159
 
 
2160
    def test_read_streamed_body_error(self):
 
2161
        """When a stream is interrupted by an error..."""
 
2162
        body_header = 'chunked\n'
 
2163
        a_body_chunk = '4\naaaa'
 
2164
        err_signal = 'ERR\n'
 
2165
        err_chunks = 'a\nerror arg1' + '4\narg2'
 
2166
        finish = 'END\n'
 
2167
        body = body_header + a_body_chunk + err_signal + err_chunks + finish
 
2168
        server_bytes = (protocol.RESPONSE_VERSION_TWO +
 
2169
                        "success\nok\n" + body)
 
2170
        input = StringIO(server_bytes)
 
2171
        output = StringIO()
 
2172
        client_medium = medium.SmartSimplePipesClientMedium(
 
2173
            input, output, 'base')
 
2174
        smart_request = client_medium.get_request()
 
2175
        smart_protocol = protocol.SmartClientRequestProtocolTwo(smart_request)
 
2176
        smart_protocol.call('foo')
 
2177
        smart_protocol.read_response_tuple(True)
 
2178
        expected_chunks = [
 
2179
            'aaaa',
 
2180
            _mod_request.FailedSmartServerResponse(('error arg1', 'arg2'))]
 
2181
        stream = smart_protocol.read_streamed_body()
 
2182
        self.assertEqual(expected_chunks, list(stream))
 
2183
 
 
2184
    def test_streamed_body_bytes_interrupted_connection(self):
 
2185
        body_header = 'chunked\n'
 
2186
        incomplete_body_chunk = "9999\nincomplete chunk"
 
2187
        server_bytes = (protocol.RESPONSE_VERSION_TWO +
 
2188
                        "success\nok\n" + body_header + incomplete_body_chunk)
 
2189
        input = StringIO(server_bytes)
 
2190
        output = StringIO()
 
2191
        client_medium = medium.SmartSimplePipesClientMedium(
 
2192
            input, output, 'base')
 
2193
        request = client_medium.get_request()
 
2194
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
 
2195
        smart_protocol.call('foo')
 
2196
        smart_protocol.read_response_tuple(True)
 
2197
        stream = smart_protocol.read_streamed_body()
 
2198
        self.assertRaises(errors.ConnectionReset, stream.next)
 
2199
 
 
2200
    def test_client_read_response_tuple_sets_response_status(self):
 
2201
        server_bytes = protocol.RESPONSE_VERSION_TWO + "success\nok\n"
 
2202
        input = StringIO(server_bytes)
 
2203
        output = StringIO()
 
2204
        client_medium = medium.SmartSimplePipesClientMedium(
 
2205
            input, output, 'base')
 
2206
        request = client_medium.get_request()
 
2207
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
 
2208
        smart_protocol.call('foo')
 
2209
        smart_protocol.read_response_tuple(False)
 
2210
        self.assertEqual(True, smart_protocol.response_status)
 
2211
 
 
2212
    def test_client_read_response_tuple_raises_UnknownSmartMethod(self):
 
2213
        """read_response_tuple raises UnknownSmartMethod if the response says
 
2214
        the server did not recognise the request.
 
2215
        """
 
2216
        server_bytes = (
 
2217
            protocol.RESPONSE_VERSION_TWO +
 
2218
            "failed\n" +
 
2219
            "error\x01Generic bzr smart protocol error: bad request 'foo'\n")
 
2220
        input = StringIO(server_bytes)
 
2221
        output = StringIO()
 
2222
        client_medium = medium.SmartSimplePipesClientMedium(
 
2223
            input, output, 'base')
 
2224
        request = client_medium.get_request()
 
2225
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
 
2226
        smart_protocol.call('foo')
 
2227
        self.assertRaises(
 
2228
            errors.UnknownSmartMethod, smart_protocol.read_response_tuple)
 
2229
        # The request has been finished.  There is no body to read, and
 
2230
        # attempts to read one will fail.
 
2231
        self.assertRaises(
 
2232
            errors.ReadingCompleted, smart_protocol.read_body_bytes)
 
2233
 
 
2234
 
 
2235
class TestSmartProtocolTwoSpecifics(
 
2236
        TestSmartProtocol, TestSmartProtocolTwoSpecificsMixin):
 
2237
    """Tests for aspects of smart protocol version two that are unique to
 
2238
    version two.
 
2239
 
 
2240
    Thus tests involving body streams and success/failure markers belong here.
 
2241
    """
 
2242
 
 
2243
    client_protocol_class = protocol.SmartClientRequestProtocolTwo
 
2244
    server_protocol_class = protocol.SmartServerRequestProtocolTwo
 
2245
 
 
2246
 
 
2247
class TestVersionOneFeaturesInProtocolThree(
 
2248
    TestSmartProtocol, CommonSmartProtocolTestMixin):
 
2249
    """Tests for version one smart protocol features as implemented by version
 
2250
    three.
 
2251
    """
 
2252
 
 
2253
    request_encoder = protocol.ProtocolThreeRequester
 
2254
    response_decoder = protocol.ProtocolThreeDecoder
 
2255
    # build_server_protocol_three is a function, so we can't set it as a class
 
2256
    # attribute directly, because then Python will assume it is actually a
 
2257
    # method.  So we make server_protocol_class be a static method, rather than
 
2258
    # simply doing:
 
2259
    # "server_protocol_class = protocol.build_server_protocol_three".
 
2260
    server_protocol_class = staticmethod(protocol.build_server_protocol_three)
 
2261
 
 
2262
    def setUp(self):
 
2263
        super(TestVersionOneFeaturesInProtocolThree, self).setUp()
 
2264
        self.response_marker = protocol.MESSAGE_VERSION_THREE
 
2265
        self.request_marker = protocol.MESSAGE_VERSION_THREE
 
2266
 
 
2267
    def test_construct_version_three_server_protocol(self):
 
2268
        smart_protocol = protocol.ProtocolThreeDecoder(None)
 
2269
        self.assertEqual('', smart_protocol.unused_data)
 
2270
        self.assertEqual([], smart_protocol._in_buffer_list)
 
2271
        self.assertEqual(0, smart_protocol._in_buffer_len)
 
2272
        self.assertFalse(smart_protocol._has_dispatched)
 
2273
        # The protocol starts by expecting four bytes, a length prefix for the
 
2274
        # headers.
 
2275
        self.assertEqual(4, smart_protocol.next_read_size())
 
2276
 
 
2277
 
 
2278
class NoOpRequest(_mod_request.SmartServerRequest):
 
2279
 
 
2280
    def do(self):
 
2281
        return _mod_request.SuccessfulSmartServerResponse(())
 
2282
 
 
2283
dummy_registry = {'ARG': NoOpRequest}
 
2284
 
 
2285
 
 
2286
class LoggingMessageHandler(object):
 
2287
 
 
2288
    def __init__(self):
 
2289
        self.event_log = []
 
2290
 
 
2291
    def _log(self, *args):
 
2292
        self.event_log.append(args)
 
2293
 
 
2294
    def headers_received(self, headers):
 
2295
        self._log('headers', headers)
 
2296
 
 
2297
    def protocol_error(self, exception):
 
2298
        self._log('protocol_error', exception)
 
2299
 
 
2300
    def byte_part_received(self, byte):
 
2301
        self._log('byte', byte)
 
2302
 
 
2303
    def bytes_part_received(self, bytes):
 
2304
        self._log('bytes', bytes)
 
2305
 
 
2306
    def structure_part_received(self, structure):
 
2307
        self._log('structure', structure)
 
2308
 
 
2309
    def end_received(self):
 
2310
        self._log('end')
 
2311
 
 
2312
 
 
2313
class TestProtocolThree(TestSmartProtocol):
 
2314
    """Tests for v3 of the server-side protocol."""
 
2315
 
 
2316
    request_encoder = protocol.ProtocolThreeRequester
 
2317
    response_decoder = protocol.ProtocolThreeDecoder
 
2318
    server_protocol_class = protocol.ProtocolThreeDecoder
 
2319
 
 
2320
    def test_trivial_request(self):
 
2321
        """Smoke test for the simplest possible v3 request: empty headers, no
 
2322
        message parts.
 
2323
        """
 
2324
        output = StringIO()
 
2325
        headers = '\0\0\0\x02de'  # length-prefixed, bencoded empty dict
 
2326
        end = 'e'
 
2327
        request_bytes = headers + end
 
2328
        smart_protocol = self.server_protocol_class(LoggingMessageHandler())
 
2329
        smart_protocol.accept_bytes(request_bytes)
 
2330
        self.assertEqual(0, smart_protocol.next_read_size())
 
2331
        self.assertEqual('', smart_protocol.unused_data)
 
2332
 
 
2333
    def make_protocol_expecting_message_part(self):
 
2334
        headers = '\0\0\0\x02de'  # length-prefixed, bencoded empty dict
 
2335
        message_handler = LoggingMessageHandler()
 
2336
        smart_protocol = self.server_protocol_class(message_handler)
 
2337
        smart_protocol.accept_bytes(headers)
 
2338
        # Clear the event log
 
2339
        del message_handler.event_log[:]
 
2340
        return smart_protocol, message_handler.event_log
 
2341
 
 
2342
    def test_decode_one_byte(self):
 
2343
        """The protocol can decode a 'one byte' message part."""
 
2344
        smart_protocol, event_log = self.make_protocol_expecting_message_part()
 
2345
        smart_protocol.accept_bytes('ox')
 
2346
        self.assertEqual([('byte', 'x')], event_log)
 
2347
 
 
2348
    def test_decode_bytes(self):
 
2349
        """The protocol can decode a 'bytes' message part."""
 
2350
        smart_protocol, event_log = self.make_protocol_expecting_message_part()
 
2351
        smart_protocol.accept_bytes(
 
2352
            'b' # message part kind
 
2353
            '\0\0\0\x07' # length prefix
 
2354
            'payload' # payload
 
2355
            )
 
2356
        self.assertEqual([('bytes', 'payload')], event_log)
 
2357
 
 
2358
    def test_decode_structure(self):
 
2359
        """The protocol can decode a 'structure' message part."""
 
2360
        smart_protocol, event_log = self.make_protocol_expecting_message_part()
 
2361
        smart_protocol.accept_bytes(
 
2362
            's' # message part kind
 
2363
            '\0\0\0\x07' # length prefix
 
2364
            'l3:ARGe' # ['ARG']
 
2365
            )
 
2366
        self.assertEqual([('structure', ['ARG'])], event_log)
 
2367
 
 
2368
    def test_decode_multiple_bytes(self):
 
2369
        """The protocol can decode a multiple 'bytes' message parts."""
 
2370
        smart_protocol, event_log = self.make_protocol_expecting_message_part()
 
2371
        smart_protocol.accept_bytes(
 
2372
            'b' # message part kind
 
2373
            '\0\0\0\x05' # length prefix
 
2374
            'first' # payload
 
2375
            'b' # message part kind
 
2376
            '\0\0\0\x06'
 
2377
            'second'
 
2378
            )
 
2379
        self.assertEqual(
 
2380
            [('bytes', 'first'), ('bytes', 'second')], event_log)
 
2381
 
 
2382
 
 
2383
class TestConventionalResponseHandler(tests.TestCase):
 
2384
 
 
2385
    def make_response_handler(self, response_bytes):
 
2386
        from bzrlib.smart.message import ConventionalResponseHandler
 
2387
        response_handler = ConventionalResponseHandler()
 
2388
        protocol_decoder = protocol.ProtocolThreeDecoder(response_handler)
 
2389
        # put decoder in desired state (waiting for message parts)
 
2390
        protocol_decoder.state_accept = protocol_decoder._state_accept_expecting_message_part
 
2391
        output = StringIO()
 
2392
        client_medium = medium.SmartSimplePipesClientMedium(
 
2393
            StringIO(response_bytes), output, 'base')
 
2394
        medium_request = client_medium.get_request()
 
2395
        medium_request.finished_writing()
 
2396
        response_handler.setProtoAndMediumRequest(
 
2397
            protocol_decoder, medium_request)
 
2398
        return response_handler
 
2399
 
 
2400
    def test_body_stream_interrupted_by_error(self):
 
2401
        interrupted_body_stream = (
 
2402
            'oS' # successful response
 
2403
            's\0\0\0\x02le' # empty args
 
2404
            'b\0\0\0\x09chunk one' # first chunk
 
2405
            'b\0\0\0\x09chunk two' # second chunk
 
2406
            'oE' # error flag
 
2407
            's\0\0\0\x0el5:error3:abce' # bencoded error
 
2408
            'e' # message end
 
2409
            )
 
2410
        response_handler = self.make_response_handler(interrupted_body_stream)
 
2411
        stream = response_handler.read_streamed_body()
 
2412
        self.assertEqual('chunk one', stream.next())
 
2413
        self.assertEqual('chunk two', stream.next())
 
2414
        exc = self.assertRaises(errors.ErrorFromSmartServer, stream.next)
 
2415
        self.assertEqual(('error', 'abc'), exc.error_tuple)
 
2416
 
 
2417
    def test_body_stream_interrupted_by_connection_lost(self):
 
2418
        interrupted_body_stream = (
 
2419
            'oS' # successful response
 
2420
            's\0\0\0\x02le' # empty args
 
2421
            'b\0\0\xff\xffincomplete chunk')
 
2422
        response_handler = self.make_response_handler(interrupted_body_stream)
 
2423
        stream = response_handler.read_streamed_body()
 
2424
        self.assertRaises(errors.ConnectionReset, stream.next)
 
2425
 
 
2426
    def test_read_body_bytes_interrupted_by_connection_lost(self):
 
2427
        interrupted_body_stream = (
 
2428
            'oS' # successful response
 
2429
            's\0\0\0\x02le' # empty args
 
2430
            'b\0\0\xff\xffincomplete chunk')
 
2431
        response_handler = self.make_response_handler(interrupted_body_stream)
 
2432
        self.assertRaises(
 
2433
            errors.ConnectionReset, response_handler.read_body_bytes)
 
2434
 
 
2435
 
 
2436
class TestMessageHandlerErrors(tests.TestCase):
 
2437
    """Tests for v3 that unrecognised (but well-formed) requests/responses are
 
2438
    still fully read off the wire, so that subsequent requests/responses on the
 
2439
    same medium can be decoded.
 
2440
    """
 
2441
 
 
2442
    def test_non_conventional_request(self):
 
2443
        """ConventionalRequestHandler (the default message handler on the
 
2444
        server side) will reject an unconventional message, but still consume
 
2445
        all the bytes of that message and signal when it has done so.
 
2446
 
 
2447
        This is what allows a server to continue to accept requests after the
 
2448
        client sends a completely unrecognised request.
 
2449
        """
 
2450
        # Define an invalid request (but one that is a well-formed message).
 
2451
        # This particular invalid request not only lacks the mandatory
 
2452
        # verb+args tuple, it has a single-byte part, which is forbidden.  In
 
2453
        # fact it has that part twice, to trigger multiple errors.
 
2454
        invalid_request = (
 
2455
            protocol.MESSAGE_VERSION_THREE +  # protocol version marker
 
2456
            '\0\0\0\x02de' + # empty headers
 
2457
            'oX' + # a single byte part: 'X'.  ConventionalRequestHandler will
 
2458
                   # error at this part.
 
2459
            'oX' + # and again.
 
2460
            'e' # end of message
 
2461
            )
 
2462
 
 
2463
        to_server = StringIO(invalid_request)
 
2464
        from_server = StringIO()
 
2465
        transport = memory.MemoryTransport('memory:///')
 
2466
        server = medium.SmartServerPipeStreamMedium(
 
2467
            to_server, from_server, transport)
 
2468
        proto = server._build_protocol()
 
2469
        message_handler = proto.message_handler
 
2470
        server._serve_one_request(proto)
 
2471
        # All the bytes have been read from the medium...
 
2472
        self.assertEqual('', to_server.read())
 
2473
        # ...and the protocol decoder has consumed all the bytes, and has
 
2474
        # finished reading.
 
2475
        self.assertEqual('', proto.unused_data)
 
2476
        self.assertEqual(0, proto.next_read_size())
 
2477
 
 
2478
 
 
2479
class InstrumentedRequestHandler(object):
 
2480
    """Test Double of SmartServerRequestHandler."""
 
2481
 
 
2482
    def __init__(self):
 
2483
        self.calls = []
 
2484
 
 
2485
    def body_chunk_received(self, chunk_bytes):
 
2486
        self.calls.append(('body_chunk_received', chunk_bytes))
 
2487
 
 
2488
    def no_body_received(self):
 
2489
        self.calls.append(('no_body_received',))
 
2490
 
 
2491
    def prefixed_body_received(self, body_bytes):
 
2492
        self.calls.append(('prefixed_body_received', body_bytes))
 
2493
 
 
2494
    def end_received(self):
 
2495
        self.calls.append(('end_received',))
 
2496
 
 
2497
 
 
2498
class StubRequest(object):
 
2499
 
 
2500
    def finished_reading(self):
 
2501
        pass
 
2502
 
 
2503
 
 
2504
class TestClientDecodingProtocolThree(TestSmartProtocol):
 
2505
    """Tests for v3 of the client-side protocol decoding."""
 
2506
 
 
2507
    def make_logging_response_decoder(self):
 
2508
        """Make v3 response decoder using a test response handler."""
 
2509
        response_handler = LoggingMessageHandler()
 
2510
        decoder = protocol.ProtocolThreeDecoder(response_handler)
 
2511
        return decoder, response_handler
 
2512
 
 
2513
    def make_conventional_response_decoder(self):
 
2514
        """Make v3 response decoder using a conventional response handler."""
 
2515
        response_handler = message.ConventionalResponseHandler()
 
2516
        decoder = protocol.ProtocolThreeDecoder(response_handler)
 
2517
        response_handler.setProtoAndMediumRequest(decoder, StubRequest())
 
2518
        return decoder, response_handler
 
2519
 
 
2520
    def test_trivial_response_decoding(self):
 
2521
        """Smoke test for the simplest possible v3 response: empty headers,
 
2522
        status byte, empty args, no body.
 
2523
        """
 
2524
        headers = '\0\0\0\x02de'  # length-prefixed, bencoded empty dict
 
2525
        response_status = 'oS' # success
 
2526
        args = 's\0\0\0\x02le' # length-prefixed, bencoded empty list
 
2527
        end = 'e' # end marker
 
2528
        message_bytes = headers + response_status + args + end
 
2529
        decoder, response_handler = self.make_logging_response_decoder()
 
2530
        decoder.accept_bytes(message_bytes)
 
2531
        # The protocol decoder has finished, and consumed all bytes
 
2532
        self.assertEqual(0, decoder.next_read_size())
 
2533
        self.assertEqual('', decoder.unused_data)
 
2534
        # The message handler has been invoked with all the parts of the
 
2535
        # trivial response: empty headers, status byte, no args, end.
 
2536
        self.assertEqual(
 
2537
            [('headers', {}), ('byte', 'S'), ('structure', []), ('end',)],
 
2538
            response_handler.event_log)
 
2539
 
 
2540
    def test_incomplete_message(self):
 
2541
        """A decoder will keep signalling that it needs more bytes via
 
2542
        next_read_size() != 0 until it has seen a complete message, regardless
 
2543
        which state it is in.
 
2544
        """
 
2545
        # Define a simple response that uses all possible message parts.
 
2546
        headers = '\0\0\0\x02de'  # length-prefixed, bencoded empty dict
 
2547
        response_status = 'oS' # success
 
2548
        args = 's\0\0\0\x02le' # length-prefixed, bencoded empty list
 
2549
        body = 'b\0\0\0\x04BODY' # a body: 'BODY'
 
2550
        end = 'e' # end marker
 
2551
        simple_response = headers + response_status + args + body + end
 
2552
        # Feed the request to the decoder one byte at a time.
 
2553
        decoder, response_handler = self.make_logging_response_decoder()
 
2554
        for byte in simple_response:
 
2555
            self.assertNotEqual(0, decoder.next_read_size())
 
2556
            decoder.accept_bytes(byte)
 
2557
        # Now the response is complete
 
2558
        self.assertEqual(0, decoder.next_read_size())
 
2559
 
 
2560
    def test_read_response_tuple_raises_UnknownSmartMethod(self):
 
2561
        """read_response_tuple raises UnknownSmartMethod if the server replied
 
2562
        with 'UnknownMethod'.
 
2563
        """
 
2564
        headers = '\0\0\0\x02de'  # length-prefixed, bencoded empty dict
 
2565
        response_status = 'oE' # error flag
 
2566
        # args: ('UnknownMethod', 'method-name')
 
2567
        args = 's\0\0\0\x20l13:UnknownMethod11:method-namee'
 
2568
        end = 'e' # end marker
 
2569
        message_bytes = headers + response_status + args + end
 
2570
        decoder, response_handler = self.make_conventional_response_decoder()
 
2571
        decoder.accept_bytes(message_bytes)
 
2572
        error = self.assertRaises(
 
2573
            errors.UnknownSmartMethod, response_handler.read_response_tuple)
 
2574
        self.assertEqual('method-name', error.verb)
 
2575
 
 
2576
    def test_read_response_tuple_error(self):
 
2577
        """If the response has an error, it is raised as an exception."""
 
2578
        headers = '\0\0\0\x02de'  # length-prefixed, bencoded empty dict
 
2579
        response_status = 'oE' # error
 
2580
        args = 's\0\0\0\x1al9:first arg10:second arge' # two args
 
2581
        end = 'e' # end marker
 
2582
        message_bytes = headers + response_status + args + end
 
2583
        decoder, response_handler = self.make_conventional_response_decoder()
 
2584
        decoder.accept_bytes(message_bytes)
 
2585
        error = self.assertRaises(
 
2586
            errors.ErrorFromSmartServer, response_handler.read_response_tuple)
 
2587
        self.assertEqual(('first arg', 'second arg'), error.error_tuple)
 
2588
 
 
2589
 
 
2590
class TestClientEncodingProtocolThree(TestSmartProtocol):
 
2591
 
 
2592
    request_encoder = protocol.ProtocolThreeRequester
 
2593
    response_decoder = protocol.ProtocolThreeDecoder
 
2594
    server_protocol_class = protocol.ProtocolThreeDecoder
 
2595
 
 
2596
    def make_client_encoder_and_output(self):
 
2597
        result = self.make_client_protocol_and_output()
 
2598
        requester, response_handler, output = result
 
2599
        return requester, output
 
2600
 
 
2601
    def test_call_smoke_test(self):
 
2602
        """A smoke test for ProtocolThreeRequester.call.
 
2603
 
 
2604
        This test checks that a particular simple invocation of call emits the
 
2605
        correct bytes for that invocation.
 
2606
        """
 
2607
        requester, output = self.make_client_encoder_and_output()
 
2608
        requester.set_headers({'header name': 'header value'})
 
2609
        requester.call('one arg')
 
2610
        self.assertEquals(
 
2611
            'bzr message 3 (bzr 1.6)\n' # protocol version
 
2612
            '\x00\x00\x00\x1fd11:header name12:header valuee' # headers
 
2613
            's\x00\x00\x00\x0bl7:one arge' # args
 
2614
            'e', # end
 
2615
            output.getvalue())
 
2616
 
 
2617
    def test_call_with_body_bytes_smoke_test(self):
 
2618
        """A smoke test for ProtocolThreeRequester.call_with_body_bytes.
 
2619
 
 
2620
        This test checks that a particular simple invocation of
 
2621
        call_with_body_bytes emits the correct bytes for that invocation.
 
2622
        """
 
2623
        requester, output = self.make_client_encoder_and_output()
 
2624
        requester.set_headers({'header name': 'header value'})
 
2625
        requester.call_with_body_bytes(('one arg',), 'body bytes')
 
2626
        self.assertEquals(
 
2627
            'bzr message 3 (bzr 1.6)\n' # protocol version
 
2628
            '\x00\x00\x00\x1fd11:header name12:header valuee' # headers
 
2629
            's\x00\x00\x00\x0bl7:one arge' # args
 
2630
            'b' # there is a prefixed body
 
2631
            '\x00\x00\x00\nbody bytes' # the prefixed body
 
2632
            'e', # end
 
2633
            output.getvalue())
 
2634
 
 
2635
    def test_call_writes_just_once(self):
 
2636
        """A bodyless request is written to the medium all at once."""
 
2637
        medium_request = StubMediumRequest()
 
2638
        encoder = protocol.ProtocolThreeRequester(medium_request)
 
2639
        encoder.call('arg1', 'arg2', 'arg3')
 
2640
        self.assertEqual(
 
2641
            ['accept_bytes', 'finished_writing'], medium_request.calls)
 
2642
 
 
2643
    def test_call_with_body_bytes_writes_just_once(self):
 
2644
        """A request with body bytes is written to the medium all at once."""
 
2645
        medium_request = StubMediumRequest()
 
2646
        encoder = protocol.ProtocolThreeRequester(medium_request)
 
2647
        encoder.call_with_body_bytes(('arg', 'arg'), 'body bytes')
 
2648
        self.assertEqual(
 
2649
            ['accept_bytes', 'finished_writing'], medium_request.calls)
 
2650
 
 
2651
 
 
2652
class StubMediumRequest(object):
 
2653
    """A stub medium request that tracks the number of times accept_bytes is
 
2654
    called.
 
2655
    """
 
2656
 
 
2657
    def __init__(self):
 
2658
        self.calls = []
 
2659
        self._medium = 'dummy medium'
 
2660
 
 
2661
    def accept_bytes(self, bytes):
 
2662
        self.calls.append('accept_bytes')
 
2663
 
 
2664
    def finished_writing(self):
 
2665
        self.calls.append('finished_writing')
 
2666
 
 
2667
 
 
2668
class TestResponseEncodingProtocolThree(tests.TestCase):
 
2669
 
 
2670
    def make_response_encoder(self):
 
2671
        out_stream = StringIO()
 
2672
        response_encoder = protocol.ProtocolThreeResponder(out_stream.write)
 
2673
        return response_encoder, out_stream
 
2674
 
 
2675
    def test_send_error_unknown_method(self):
 
2676
        encoder, out_stream = self.make_response_encoder()
 
2677
        encoder.send_error(errors.UnknownSmartMethod('method name'))
 
2678
        # Use assertEndsWith so that we don't compare the header, which varies
 
2679
        # by bzrlib.__version__.
 
2680
        self.assertEndsWith(
 
2681
            out_stream.getvalue(),
 
2682
            # error status
 
2683
            'oE' +
 
2684
            # tuple: 'UnknownMethod', 'method name'
 
2685
            's\x00\x00\x00\x20l13:UnknownMethod11:method namee'
 
2686
            # end of message
 
2687
            'e')
 
2688
 
 
2689
 
 
2690
class TestResponseEncoderBufferingProtocolThree(tests.TestCase):
 
2691
    """Tests for buffering of responses.
 
2692
 
 
2693
    We want to avoid doing many small writes when one would do, to avoid
 
2694
    unnecessary network overhead.
 
2695
    """
 
2696
 
 
2697
    def setUp(self):
 
2698
        self.writes = []
 
2699
        self.responder = protocol.ProtocolThreeResponder(self.writes.append)
 
2700
 
 
2701
    def assertWriteCount(self, expected_count):
 
2702
        self.assertEqual(
 
2703
            expected_count, len(self.writes),
 
2704
            "Too many writes: %r" % (self.writes,))
 
2705
        
 
2706
    def test_send_error_writes_just_once(self):
 
2707
        """An error response is written to the medium all at once."""
 
2708
        self.responder.send_error(Exception('An exception string.'))
 
2709
        self.assertWriteCount(1)
 
2710
 
 
2711
    def test_send_response_writes_just_once(self):
 
2712
        """A normal response with no body is written to the medium all at once.
 
2713
        """
 
2714
        response = _mod_request.SuccessfulSmartServerResponse(('arg', 'arg'))
 
2715
        self.responder.send_response(response)
 
2716
        self.assertWriteCount(1)
 
2717
 
 
2718
    def test_send_response_with_body_writes_just_once(self):
 
2719
        """A normal response with a monolithic body is written to the medium
 
2720
        all at once.
 
2721
        """
 
2722
        response = _mod_request.SuccessfulSmartServerResponse(
 
2723
            ('arg', 'arg'), body='body bytes')
 
2724
        self.responder.send_response(response)
 
2725
        self.assertWriteCount(1)
 
2726
 
 
2727
    def test_send_response_with_body_stream_writes_once_per_chunk(self):
 
2728
        """A normal response with a stream body is written to the medium
 
2729
        writes to the medium once per chunk.
 
2730
        """
 
2731
        # Construct a response with stream with 2 chunks in it.
 
2732
        response = _mod_request.SuccessfulSmartServerResponse(
 
2733
            ('arg', 'arg'), body_stream=['chunk1', 'chunk2'])
 
2734
        self.responder.send_response(response)
 
2735
        # We will write 3 times: exactly once for each chunk, plus a final
 
2736
        # write to end the response.
 
2737
        self.assertWriteCount(3)
 
2738
 
 
2739
 
 
2740
class TestSmartClientUnicode(tests.TestCase):
 
2741
    """_SmartClient tests for unicode arguments.
 
2742
 
 
2743
    Unicode arguments to call_with_body_bytes are not correct (remote method
 
2744
    names, arguments, and bodies must all be expressed as byte strings), but
 
2745
    _SmartClient should gracefully reject them, rather than getting into a
 
2746
    broken state that prevents future correct calls from working.  That is, it
 
2747
    should be possible to issue more requests on the medium afterwards, rather
 
2748
    than allowing one bad call to call_with_body_bytes to cause later calls to
 
2749
    mysteriously fail with TooManyConcurrentRequests.
 
2750
    """
 
2751
 
 
2752
    def assertCallDoesNotBreakMedium(self, method, args, body):
 
2753
        """Call a medium with the given method, args and body, then assert that
 
2754
        the medium is left in a sane state, i.e. is capable of allowing further
 
2755
        requests.
 
2756
        """
 
2757
        input = StringIO("\n")
 
2758
        output = StringIO()
 
2759
        client_medium = medium.SmartSimplePipesClientMedium(
 
2760
            input, output, 'ignored base')
 
2761
        smart_client = client._SmartClient(client_medium)
 
2762
        self.assertRaises(TypeError,
 
2763
            smart_client.call_with_body_bytes, method, args, body)
 
2764
        self.assertEqual("", output.getvalue())
 
2765
        self.assertEqual(None, client_medium._current_request)
 
2766
 
 
2767
    def test_call_with_body_bytes_unicode_method(self):
 
2768
        self.assertCallDoesNotBreakMedium(u'method', ('args',), 'body')
 
2769
 
 
2770
    def test_call_with_body_bytes_unicode_args(self):
 
2771
        self.assertCallDoesNotBreakMedium('method', (u'args',), 'body')
 
2772
        self.assertCallDoesNotBreakMedium('method', ('arg1', u'arg2'), 'body')
 
2773
 
 
2774
    def test_call_with_body_bytes_unicode_body(self):
 
2775
        self.assertCallDoesNotBreakMedium('method', ('args',), u'body')
 
2776
 
 
2777
 
 
2778
class MockMedium(medium.SmartClientMedium):
 
2779
    """A mock medium that can be used to test _SmartClient.
 
2780
    
 
2781
    It can be given a series of requests to expect (and responses it should
 
2782
    return for them).  It can also be told when the client is expected to
 
2783
    disconnect a medium.  Expectations must be satisfied in the order they are
 
2784
    given, or else an AssertionError will be raised.
 
2785
 
 
2786
    Typical use looks like::
 
2787
 
 
2788
        medium = MockMedium()
 
2789
        medium.expect_request(...)
 
2790
        medium.expect_request(...)
 
2791
        medium.expect_request(...)
 
2792
    """
 
2793
 
 
2794
    def __init__(self):
 
2795
        super(MockMedium, self).__init__('dummy base')
 
2796
        self._mock_request = _MockMediumRequest(self)
 
2797
        self._expected_events = []
 
2798
        
 
2799
    def expect_request(self, request_bytes, response_bytes,
 
2800
                       allow_partial_read=False):
 
2801
        """Expect 'request_bytes' to be sent, and reply with 'response_bytes'.
 
2802
 
 
2803
        No assumption is made about how many times accept_bytes should be
 
2804
        called to send the request.  Similarly, no assumption is made about how
 
2805
        many times read_bytes/read_line are called by protocol code to read a
 
2806
        response.  e.g.::
 
2807
        
 
2808
            request.accept_bytes('ab')
 
2809
            request.accept_bytes('cd')
 
2810
            request.finished_writing()
 
2811
 
 
2812
        and::
 
2813
        
 
2814
            request.accept_bytes('abcd')
 
2815
            request.finished_writing()
 
2816
 
 
2817
        Will both satisfy ``medium.expect_request('abcd', ...)``.  Thus tests
 
2818
        using this should not break due to irrelevant changes in protocol
 
2819
        implementations.
 
2820
 
 
2821
        :param allow_partial_read: if True, no assertion is raised if a
 
2822
            response is not fully read.  Setting this is useful when the client
 
2823
            is expected to disconnect without needing to read the complete
 
2824
            response.  Default is False.
 
2825
        """
 
2826
        self._expected_events.append(('send request', request_bytes))
 
2827
        if allow_partial_read:
 
2828
            self._expected_events.append(
 
2829
                ('read response (partial)', response_bytes))
 
2830
        else:
 
2831
            self._expected_events.append(('read response', response_bytes))
 
2832
 
 
2833
    def expect_disconnect(self):
 
2834
        """Expect the client to call ``medium.disconnect()``."""
 
2835
        self._expected_events.append('disconnect')
 
2836
 
 
2837
    def _assertEvent(self, observed_event):
 
2838
        """Raise AssertionError unless observed_event matches the next expected
 
2839
        event.
 
2840
 
 
2841
        :seealso: expect_request
 
2842
        :seealso: expect_disconnect
 
2843
        """
 
2844
        try:
 
2845
            expected_event = self._expected_events.pop(0)
 
2846
        except IndexError:
 
2847
            raise AssertionError(
 
2848
                'Mock medium observed event %r, but no more events expected'
 
2849
                % (observed_event,))
 
2850
        if expected_event[0] == 'read response (partial)':
 
2851
            if observed_event[0] != 'read response':
 
2852
                raise AssertionError(
 
2853
                    'Mock medium observed event %r, but expected event %r'
 
2854
                    % (observed_event, expected_event))
 
2855
        elif observed_event != expected_event:
 
2856
            raise AssertionError(
 
2857
                'Mock medium observed event %r, but expected event %r'
 
2858
                % (observed_event, expected_event))
 
2859
        if self._expected_events:
 
2860
            next_event = self._expected_events[0]
 
2861
            if next_event[0].startswith('read response'):
 
2862
                self._mock_request._response = next_event[1]
 
2863
 
 
2864
    def get_request(self):
 
2865
        return self._mock_request
 
2866
 
 
2867
    def disconnect(self):
 
2868
        if self._mock_request._read_bytes:
 
2869
            self._assertEvent(('read response', self._mock_request._read_bytes))
 
2870
            self._mock_request._read_bytes = ''
 
2871
        self._assertEvent('disconnect')
 
2872
 
 
2873
 
 
2874
class _MockMediumRequest(object):
 
2875
    """A mock ClientMediumRequest used by MockMedium."""
 
2876
 
 
2877
    def __init__(self, mock_medium):
 
2878
        self._medium = mock_medium
 
2879
        self._written_bytes = ''
 
2880
        self._read_bytes = ''
 
2881
        self._response = None
 
2882
 
 
2883
    def accept_bytes(self, bytes):
 
2884
        self._written_bytes += bytes
 
2885
 
 
2886
    def finished_writing(self):
 
2887
        self._medium._assertEvent(('send request', self._written_bytes))
 
2888
        self._written_bytes = ''
 
2889
 
 
2890
    def finished_reading(self):
 
2891
        self._medium._assertEvent(('read response', self._read_bytes))
 
2892
        self._read_bytes = ''
 
2893
 
 
2894
    def read_bytes(self, size):
 
2895
        resp = self._response
 
2896
        bytes, resp = resp[:size], resp[size:]
 
2897
        self._response = resp
 
2898
        self._read_bytes += bytes
 
2899
        return bytes
 
2900
 
 
2901
    def read_line(self):
 
2902
        resp = self._response
 
2903
        try:
 
2904
            line, resp = resp.split('\n', 1)
 
2905
            line += '\n'
 
2906
        except ValueError:
 
2907
            line, resp = resp, ''
 
2908
        self._response = resp
 
2909
        self._read_bytes += line
 
2910
        return line
 
2911
 
 
2912
 
 
2913
class Test_SmartClientVersionDetection(tests.TestCase):
 
2914
    """Tests for _SmartClient's automatic protocol version detection.
 
2915
 
 
2916
    On the first remote call, _SmartClient will keep retrying the request with
 
2917
    different protocol versions until it finds one that works.
 
2918
    """
 
2919
 
 
2920
    def test_version_three_server(self):
 
2921
        """With a protocol 3 server, only one request is needed."""
 
2922
        medium = MockMedium()
 
2923
        smart_client = client._SmartClient(medium, headers={})
 
2924
        message_start = protocol.MESSAGE_VERSION_THREE + '\x00\x00\x00\x02de'
 
2925
        medium.expect_request(
 
2926
            message_start +
 
2927
            's\x00\x00\x00\x1el11:method-name5:arg 15:arg 2ee',
 
2928
            message_start + 's\0\0\0\x13l14:response valueee')
 
2929
        result = smart_client.call('method-name', 'arg 1', 'arg 2')
 
2930
        # The call succeeded without raising any exceptions from the mock
 
2931
        # medium, and the smart_client returns the response from the server.
 
2932
        self.assertEqual(('response value',), result)
 
2933
        self.assertEqual([], medium._expected_events)
 
2934
        # Also, the v3 works then the server should be assumed to support RPCs
 
2935
        # introduced in 1.6.
 
2936
        self.assertFalse(medium._is_remote_before((1, 6)))
 
2937
 
 
2938
    def test_version_two_server(self):
 
2939
        """If the server only speaks protocol 2, the client will first try
 
2940
        version 3, then fallback to protocol 2.
 
2941
 
 
2942
        Further, _SmartClient caches the detection, so future requests will all
 
2943
        use protocol 2 immediately.
 
2944
        """
 
2945
        medium = MockMedium()
 
2946
        smart_client = client._SmartClient(medium, headers={})
 
2947
        # First the client should send a v3 request, but the server will reply
 
2948
        # with a v2 error.
 
2949
        medium.expect_request(
 
2950
            'bzr message 3 (bzr 1.6)\n\x00\x00\x00\x02de' +
 
2951
            's\x00\x00\x00\x1el11:method-name5:arg 15:arg 2ee',
 
2952
            'bzr response 2\nfailed\n\n')
 
2953
        # So then the client should disconnect to reset the connection, because
 
2954
        # the client needs to assume the server cannot read any further
 
2955
        # requests off the original connection.
 
2956
        medium.expect_disconnect()
 
2957
        # The client should then retry the original request in v2
 
2958
        medium.expect_request(
 
2959
            'bzr request 2\nmethod-name\x01arg 1\x01arg 2\n',
 
2960
            'bzr response 2\nsuccess\nresponse value\n')
 
2961
        result = smart_client.call('method-name', 'arg 1', 'arg 2')
 
2962
        # The smart_client object will return the result of the successful
 
2963
        # query.
 
2964
        self.assertEqual(('response value',), result)
 
2965
 
 
2966
        # Now try another request, and this time the client will just use
 
2967
        # protocol 2.  (i.e. the autodetection won't be repeated)
 
2968
        medium.expect_request(
 
2969
            'bzr request 2\nanother-method\n',
 
2970
            'bzr response 2\nsuccess\nanother response\n')
 
2971
        result = smart_client.call('another-method')
 
2972
        self.assertEqual(('another response',), result)
 
2973
        self.assertEqual([], medium._expected_events)
 
2974
 
 
2975
        # Also, because v3 is not supported, the client medium should assume
 
2976
        # that RPCs introduced in 1.6 aren't supported either.
 
2977
        self.assertTrue(medium._is_remote_before((1, 6)))
 
2978
 
 
2979
    def test_unknown_version(self):
 
2980
        """If the server does not use any known (or at least supported)
 
2981
        protocol version, a SmartProtocolError is raised.
 
2982
        """
 
2983
        medium = MockMedium()
 
2984
        smart_client = client._SmartClient(medium, headers={})
 
2985
        unknown_protocol_bytes = 'Unknown protocol!'
 
2986
        # The client will try v3 and v2 before eventually giving up.
 
2987
        medium.expect_request(
 
2988
            'bzr message 3 (bzr 1.6)\n\x00\x00\x00\x02de' +
 
2989
            's\x00\x00\x00\x1el11:method-name5:arg 15:arg 2ee',
 
2990
            unknown_protocol_bytes)
 
2991
        medium.expect_disconnect()
 
2992
        medium.expect_request(
 
2993
            'bzr request 2\nmethod-name\x01arg 1\x01arg 2\n',
 
2994
            unknown_protocol_bytes)
 
2995
        medium.expect_disconnect()
 
2996
        self.assertRaises(
 
2997
            errors.SmartProtocolError,
 
2998
            smart_client.call, 'method-name', 'arg 1', 'arg 2')
 
2999
        self.assertEqual([], medium._expected_events)
 
3000
 
 
3001
    def test_first_response_is_error(self):
 
3002
        """If the server replies with an error, then the version detection
 
3003
        should be complete.
 
3004
        
 
3005
        This test is very similar to test_version_two_server, but catches a bug
 
3006
        we had in the case where the first reply was an error response.
 
3007
        """
 
3008
        medium = MockMedium()
 
3009
        smart_client = client._SmartClient(medium, headers={})
 
3010
        message_start = protocol.MESSAGE_VERSION_THREE + '\x00\x00\x00\x02de'
 
3011
        # Issue a request that gets an error reply in a non-default protocol
 
3012
        # version.
 
3013
        medium.expect_request(
 
3014
            message_start +
 
3015
            's\x00\x00\x00\x10l11:method-nameee',
 
3016
            'bzr response 2\nfailed\n\n')
 
3017
        medium.expect_disconnect()
 
3018
        medium.expect_request(
 
3019
            'bzr request 2\nmethod-name\n',
 
3020
            'bzr response 2\nfailed\nFooBarError\n')
 
3021
        err = self.assertRaises(
 
3022
            errors.ErrorFromSmartServer,
 
3023
            smart_client.call, 'method-name')
 
3024
        self.assertEqual(('FooBarError',), err.error_tuple)
 
3025
        # Now the medium should have remembered the protocol version, so
 
3026
        # subsequent requests will use the remembered version immediately.
 
3027
        medium.expect_request(
 
3028
            'bzr request 2\nmethod-name\n',
 
3029
            'bzr response 2\nsuccess\nresponse value\n')
 
3030
        result = smart_client.call('method-name')
 
3031
        self.assertEqual(('response value',), result)
 
3032
        self.assertEqual([], medium._expected_events)
 
3033
 
 
3034
 
 
3035
class Test_SmartClient(tests.TestCase):
 
3036
 
 
3037
    def test_call_default_headers(self):
 
3038
        """ProtocolThreeRequester.call by default sends a 'Software
 
3039
        version' header.
 
3040
        """
 
3041
        smart_client = client._SmartClient('dummy medium')
 
3042
        self.assertEqual(
 
3043
            bzrlib.__version__, smart_client._headers['Software version'])
 
3044
        # XXX: need a test that smart_client._headers is passed to the request
 
3045
        # encoder.
 
3046
 
 
3047
 
 
3048
class LengthPrefixedBodyDecoder(tests.TestCase):
 
3049
 
 
3050
    # XXX: TODO: make accept_reading_trailer invoke translate_response or 
 
3051
    # something similar to the ProtocolBase method.
 
3052
 
 
3053
    def test_construct(self):
 
3054
        decoder = protocol.LengthPrefixedBodyDecoder()
 
3055
        self.assertFalse(decoder.finished_reading)
 
3056
        self.assertEqual(6, decoder.next_read_size())
 
3057
        self.assertEqual('', decoder.read_pending_data())
 
3058
        self.assertEqual('', decoder.unused_data)
 
3059
 
 
3060
    def test_accept_bytes(self):
 
3061
        decoder = protocol.LengthPrefixedBodyDecoder()
 
3062
        decoder.accept_bytes('')
 
3063
        self.assertFalse(decoder.finished_reading)
 
3064
        self.assertEqual(6, decoder.next_read_size())
 
3065
        self.assertEqual('', decoder.read_pending_data())
 
3066
        self.assertEqual('', decoder.unused_data)
 
3067
        decoder.accept_bytes('7')
 
3068
        self.assertFalse(decoder.finished_reading)
 
3069
        self.assertEqual(6, decoder.next_read_size())
 
3070
        self.assertEqual('', decoder.read_pending_data())
 
3071
        self.assertEqual('', decoder.unused_data)
 
3072
        decoder.accept_bytes('\na')
 
3073
        self.assertFalse(decoder.finished_reading)
 
3074
        self.assertEqual(11, decoder.next_read_size())
 
3075
        self.assertEqual('a', decoder.read_pending_data())
 
3076
        self.assertEqual('', decoder.unused_data)
 
3077
        decoder.accept_bytes('bcdefgd')
 
3078
        self.assertFalse(decoder.finished_reading)
 
3079
        self.assertEqual(4, decoder.next_read_size())
 
3080
        self.assertEqual('bcdefg', decoder.read_pending_data())
 
3081
        self.assertEqual('', decoder.unused_data)
 
3082
        decoder.accept_bytes('one')
 
3083
        self.assertFalse(decoder.finished_reading)
 
3084
        self.assertEqual(1, decoder.next_read_size())
 
3085
        self.assertEqual('', decoder.read_pending_data())
 
3086
        self.assertEqual('', decoder.unused_data)
 
3087
        decoder.accept_bytes('\nblarg')
 
3088
        self.assertTrue(decoder.finished_reading)
 
3089
        self.assertEqual(1, decoder.next_read_size())
 
3090
        self.assertEqual('', decoder.read_pending_data())
 
3091
        self.assertEqual('blarg', decoder.unused_data)
 
3092
        
 
3093
    def test_accept_bytes_all_at_once_with_excess(self):
 
3094
        decoder = protocol.LengthPrefixedBodyDecoder()
 
3095
        decoder.accept_bytes('1\nadone\nunused')
 
3096
        self.assertTrue(decoder.finished_reading)
 
3097
        self.assertEqual(1, decoder.next_read_size())
 
3098
        self.assertEqual('a', decoder.read_pending_data())
 
3099
        self.assertEqual('unused', decoder.unused_data)
 
3100
 
 
3101
    def test_accept_bytes_exact_end_of_body(self):
 
3102
        decoder = protocol.LengthPrefixedBodyDecoder()
 
3103
        decoder.accept_bytes('1\na')
 
3104
        self.assertFalse(decoder.finished_reading)
 
3105
        self.assertEqual(5, decoder.next_read_size())
 
3106
        self.assertEqual('a', decoder.read_pending_data())
 
3107
        self.assertEqual('', decoder.unused_data)
 
3108
        decoder.accept_bytes('done\n')
 
3109
        self.assertTrue(decoder.finished_reading)
 
3110
        self.assertEqual(1, decoder.next_read_size())
 
3111
        self.assertEqual('', decoder.read_pending_data())
 
3112
        self.assertEqual('', decoder.unused_data)
 
3113
 
 
3114
 
 
3115
class TestChunkedBodyDecoder(tests.TestCase):
 
3116
    """Tests for ChunkedBodyDecoder.
 
3117
    
 
3118
    This is the body decoder used for protocol version two.
 
3119
    """
 
3120
 
 
3121
    def test_construct(self):
 
3122
        decoder = protocol.ChunkedBodyDecoder()
 
3123
        self.assertFalse(decoder.finished_reading)
 
3124
        self.assertEqual(8, decoder.next_read_size())
 
3125
        self.assertEqual(None, decoder.read_next_chunk())
 
3126
        self.assertEqual('', decoder.unused_data)
 
3127
 
 
3128
    def test_empty_content(self):
 
3129
        """'chunked\nEND\n' is the complete encoding of a zero-length body.
 
3130
        """
 
3131
        decoder = protocol.ChunkedBodyDecoder()
 
3132
        decoder.accept_bytes('chunked\n')
 
3133
        decoder.accept_bytes('END\n')
 
3134
        self.assertTrue(decoder.finished_reading)
 
3135
        self.assertEqual(None, decoder.read_next_chunk())
 
3136
        self.assertEqual('', decoder.unused_data)
 
3137
 
 
3138
    def test_one_chunk(self):
 
3139
        """A body in a single chunk is decoded correctly."""
 
3140
        decoder = protocol.ChunkedBodyDecoder()
 
3141
        decoder.accept_bytes('chunked\n')
 
3142
        chunk_length = 'f\n'
 
3143
        chunk_content = '123456789abcdef'
 
3144
        finish = 'END\n'
 
3145
        decoder.accept_bytes(chunk_length + chunk_content + finish)
 
3146
        self.assertTrue(decoder.finished_reading)
 
3147
        self.assertEqual(chunk_content, decoder.read_next_chunk())
 
3148
        self.assertEqual('', decoder.unused_data)
 
3149
        
 
3150
    def test_incomplete_chunk(self):
 
3151
        """When there are less bytes in the chunk than declared by the length,
 
3152
        then we haven't finished reading yet.
 
3153
        """
 
3154
        decoder = protocol.ChunkedBodyDecoder()
 
3155
        decoder.accept_bytes('chunked\n')
 
3156
        chunk_length = '8\n'
 
3157
        three_bytes = '123'
 
3158
        decoder.accept_bytes(chunk_length + three_bytes)
 
3159
        self.assertFalse(decoder.finished_reading)
 
3160
        self.assertEqual(
 
3161
            5 + 4, decoder.next_read_size(),
 
3162
            "The next_read_size hint should be the number of missing bytes in "
 
3163
            "this chunk plus 4 (the length of the end-of-body marker: "
 
3164
            "'END\\n')")
 
3165
        self.assertEqual(None, decoder.read_next_chunk())
 
3166
 
 
3167
    def test_incomplete_length(self):
 
3168
        """A chunk length hasn't been read until a newline byte has been read.
 
3169
        """
 
3170
        decoder = protocol.ChunkedBodyDecoder()
 
3171
        decoder.accept_bytes('chunked\n')
 
3172
        decoder.accept_bytes('9')
 
3173
        self.assertEqual(
 
3174
            1, decoder.next_read_size(),
 
3175
            "The next_read_size hint should be 1, because we don't know the "
 
3176
            "length yet.")
 
3177
        decoder.accept_bytes('\n')
 
3178
        self.assertEqual(
 
3179
            9 + 4, decoder.next_read_size(),
 
3180
            "The next_read_size hint should be the length of the chunk plus 4 "
 
3181
            "(the length of the end-of-body marker: 'END\\n')")
 
3182
        self.assertFalse(decoder.finished_reading)
 
3183
        self.assertEqual(None, decoder.read_next_chunk())
 
3184
 
 
3185
    def test_two_chunks(self):
 
3186
        """Content from multiple chunks is concatenated."""
 
3187
        decoder = protocol.ChunkedBodyDecoder()
 
3188
        decoder.accept_bytes('chunked\n')
 
3189
        chunk_one = '3\naaa'
 
3190
        chunk_two = '5\nbbbbb'
 
3191
        finish = 'END\n'
 
3192
        decoder.accept_bytes(chunk_one + chunk_two + finish)
 
3193
        self.assertTrue(decoder.finished_reading)
 
3194
        self.assertEqual('aaa', decoder.read_next_chunk())
 
3195
        self.assertEqual('bbbbb', decoder.read_next_chunk())
 
3196
        self.assertEqual(None, decoder.read_next_chunk())
 
3197
        self.assertEqual('', decoder.unused_data)
 
3198
 
 
3199
    def test_excess_bytes(self):
 
3200
        """Bytes after the chunked body are reported as unused bytes."""
 
3201
        decoder = protocol.ChunkedBodyDecoder()
 
3202
        decoder.accept_bytes('chunked\n')
 
3203
        chunked_body = "5\naaaaaEND\n"
 
3204
        excess_bytes = "excess bytes"
 
3205
        decoder.accept_bytes(chunked_body + excess_bytes)
 
3206
        self.assertTrue(decoder.finished_reading)
 
3207
        self.assertEqual('aaaaa', decoder.read_next_chunk())
 
3208
        self.assertEqual(excess_bytes, decoder.unused_data)
 
3209
        self.assertEqual(
 
3210
            1, decoder.next_read_size(),
 
3211
            "next_read_size hint should be 1 when finished_reading.")
 
3212
 
 
3213
    def test_multidigit_length(self):
 
3214
        """Lengths in the chunk prefixes can have multiple digits."""
 
3215
        decoder = protocol.ChunkedBodyDecoder()
 
3216
        decoder.accept_bytes('chunked\n')
 
3217
        length = 0x123
 
3218
        chunk_prefix = hex(length) + '\n'
 
3219
        chunk_bytes = 'z' * length
 
3220
        finish = 'END\n'
 
3221
        decoder.accept_bytes(chunk_prefix + chunk_bytes + finish)
 
3222
        self.assertTrue(decoder.finished_reading)
 
3223
        self.assertEqual(chunk_bytes, decoder.read_next_chunk())
 
3224
 
 
3225
    def test_byte_at_a_time(self):
 
3226
        """A complete body fed to the decoder one byte at a time should not
 
3227
        confuse the decoder.  That is, it should give the same result as if the
 
3228
        bytes had been received in one batch.
 
3229
 
 
3230
        This test is the same as test_one_chunk apart from the way accept_bytes
 
3231
        is called.
 
3232
        """
 
3233
        decoder = protocol.ChunkedBodyDecoder()
 
3234
        decoder.accept_bytes('chunked\n')
 
3235
        chunk_length = 'f\n'
 
3236
        chunk_content = '123456789abcdef'
 
3237
        finish = 'END\n'
 
3238
        for byte in (chunk_length + chunk_content + finish):
 
3239
            decoder.accept_bytes(byte)
 
3240
        self.assertTrue(decoder.finished_reading)
 
3241
        self.assertEqual(chunk_content, decoder.read_next_chunk())
 
3242
        self.assertEqual('', decoder.unused_data)
 
3243
 
 
3244
    def test_read_pending_data_resets(self):
 
3245
        """read_pending_data does not return the same bytes twice."""
 
3246
        decoder = protocol.ChunkedBodyDecoder()
 
3247
        decoder.accept_bytes('chunked\n')
 
3248
        chunk_one = '3\naaa'
 
3249
        chunk_two = '3\nbbb'
 
3250
        finish = 'END\n'
 
3251
        decoder.accept_bytes(chunk_one)
 
3252
        self.assertEqual('aaa', decoder.read_next_chunk())
 
3253
        decoder.accept_bytes(chunk_two)
 
3254
        self.assertEqual('bbb', decoder.read_next_chunk())
 
3255
        self.assertEqual(None, decoder.read_next_chunk())
 
3256
 
 
3257
    def test_decode_error(self):
 
3258
        decoder = protocol.ChunkedBodyDecoder()
 
3259
        decoder.accept_bytes('chunked\n')
 
3260
        chunk_one = 'b\nfirst chunk'
 
3261
        error_signal = 'ERR\n'
 
3262
        error_chunks = '5\npart1' + '5\npart2'
 
3263
        finish = 'END\n'
 
3264
        decoder.accept_bytes(chunk_one + error_signal + error_chunks + finish)
 
3265
        self.assertTrue(decoder.finished_reading)
 
3266
        self.assertEqual('first chunk', decoder.read_next_chunk())
 
3267
        expected_failure = _mod_request.FailedSmartServerResponse(
 
3268
            ('part1', 'part2'))
 
3269
        self.assertEqual(expected_failure, decoder.read_next_chunk())
 
3270
 
 
3271
    def test_bad_header(self):
 
3272
        """accept_bytes raises a SmartProtocolError if a chunked body does not
 
3273
        start with the right header.
 
3274
        """
 
3275
        decoder = protocol.ChunkedBodyDecoder()
 
3276
        self.assertRaises(
 
3277
            errors.SmartProtocolError, decoder.accept_bytes, 'bad header\n')
 
3278
 
 
3279
 
 
3280
class TestSuccessfulSmartServerResponse(tests.TestCase):
 
3281
 
 
3282
    def test_construct_no_body(self):
 
3283
        response = _mod_request.SuccessfulSmartServerResponse(('foo', 'bar'))
 
3284
        self.assertEqual(('foo', 'bar'), response.args)
 
3285
        self.assertEqual(None, response.body)
 
3286
 
 
3287
    def test_construct_with_body(self):
 
3288
        response = _mod_request.SuccessfulSmartServerResponse(('foo', 'bar'),
 
3289
                                                              'bytes')
 
3290
        self.assertEqual(('foo', 'bar'), response.args)
 
3291
        self.assertEqual('bytes', response.body)
 
3292
        # repr(response) doesn't trigger exceptions.
 
3293
        repr(response)
 
3294
 
 
3295
    def test_construct_with_body_stream(self):
 
3296
        bytes_iterable = ['abc']
 
3297
        response = _mod_request.SuccessfulSmartServerResponse(
 
3298
            ('foo', 'bar'), body_stream=bytes_iterable)
 
3299
        self.assertEqual(('foo', 'bar'), response.args)
 
3300
        self.assertEqual(bytes_iterable, response.body_stream)
 
3301
 
 
3302
    def test_construct_rejects_body_and_body_stream(self):
 
3303
        """'body' and 'body_stream' are mutually exclusive."""
 
3304
        self.assertRaises(
 
3305
            errors.BzrError,
 
3306
            _mod_request.SuccessfulSmartServerResponse, (), 'body', ['stream'])
 
3307
 
 
3308
    def test_is_successful(self):
 
3309
        """is_successful should return True for SuccessfulSmartServerResponse."""
 
3310
        response = _mod_request.SuccessfulSmartServerResponse(('error',))
 
3311
        self.assertEqual(True, response.is_successful())
 
3312
 
 
3313
 
 
3314
class TestFailedSmartServerResponse(tests.TestCase):
 
3315
 
 
3316
    def test_construct(self):
 
3317
        response = _mod_request.FailedSmartServerResponse(('foo', 'bar'))
 
3318
        self.assertEqual(('foo', 'bar'), response.args)
 
3319
        self.assertEqual(None, response.body)
 
3320
        response = _mod_request.FailedSmartServerResponse(('foo', 'bar'), 'bytes')
 
3321
        self.assertEqual(('foo', 'bar'), response.args)
 
3322
        self.assertEqual('bytes', response.body)
 
3323
        # repr(response) doesn't trigger exceptions.
 
3324
        repr(response)
 
3325
 
 
3326
    def test_is_successful(self):
 
3327
        """is_successful should return False for FailedSmartServerResponse."""
 
3328
        response = _mod_request.FailedSmartServerResponse(('error',))
 
3329
        self.assertEqual(False, response.is_successful())
 
3330
 
 
3331
 
 
3332
class FakeHTTPMedium(object):
 
3333
    def __init__(self):
 
3334
        self.written_request = None
 
3335
        self._current_request = None
 
3336
    def send_http_smart_request(self, bytes):
 
3337
        self.written_request = bytes
 
3338
        return None
 
3339
 
 
3340
 
 
3341
class HTTPTunnellingSmokeTest(tests.TestCase):
 
3342
 
 
3343
    def setUp(self):
 
3344
        super(HTTPTunnellingSmokeTest, self).setUp()
 
3345
        # We use the VFS layer as part of HTTP tunnelling tests.
 
3346
        self._captureVar('BZR_NO_SMART_VFS', None)
 
3347
 
 
3348
    def test_smart_http_medium_request_accept_bytes(self):
 
3349
        medium = FakeHTTPMedium()
 
3350
        request = SmartClientHTTPMediumRequest(medium)
 
3351
        request.accept_bytes('abc')
 
3352
        request.accept_bytes('def')
 
3353
        self.assertEqual(None, medium.written_request)
 
3354
        request.finished_writing()
 
3355
        self.assertEqual('abcdef', medium.written_request)
 
3356
 
 
3357
 
 
3358
class RemoteHTTPTransportTestCase(tests.TestCase):
 
3359
 
 
3360
    def test_remote_path_after_clone_child(self):
 
3361
        # If a user enters "bzr+http://host/foo", we want to sent all smart
 
3362
        # requests for child URLs of that to the original URL.  i.e., we want to
 
3363
        # POST to "bzr+http://host/foo/.bzr/smart" and never something like
 
3364
        # "bzr+http://host/foo/.bzr/branch/.bzr/smart".  So, a cloned
 
3365
        # RemoteHTTPTransport remembers the initial URL, and adjusts the relpaths
 
3366
        # it sends in smart requests accordingly.
 
3367
        base_transport = remote.RemoteHTTPTransport('bzr+http://host/path')
 
3368
        new_transport = base_transport.clone('child_dir')
 
3369
        self.assertEqual(base_transport._http_transport,
 
3370
                         new_transport._http_transport)
 
3371
        self.assertEqual('child_dir/foo', new_transport._remote_path('foo'))
 
3372
        self.assertEqual(
 
3373
            'child_dir/',
 
3374
            new_transport._client.remote_path_from_transport(new_transport))
 
3375
 
 
3376
    def test_remote_path_unnormal_base(self):
 
3377
        # If the transport's base isn't normalised, the _remote_path should
 
3378
        # still be calculated correctly.
 
3379
        base_transport = remote.RemoteHTTPTransport('bzr+http://host/%7Ea/b')
 
3380
        self.assertEqual('c', base_transport._remote_path('c'))
 
3381
 
 
3382
    def test_clone_unnormal_base(self):
 
3383
        # If the transport's base isn't normalised, cloned transports should
 
3384
        # still work correctly.
 
3385
        base_transport = remote.RemoteHTTPTransport('bzr+http://host/%7Ea/b')
 
3386
        new_transport = base_transport.clone('c')
 
3387
        self.assertEqual('bzr+http://host/%7Ea/b/c/', new_transport.base)
 
3388
        self.assertEqual(
 
3389
            'c/',
 
3390
            new_transport._client.remote_path_from_transport(new_transport))
 
3391
 
 
3392
        
 
3393
# TODO: Client feature that does get_bundle and then installs that into a
 
3394
# branch; this can be used in place of the regular pull/fetch operation when
 
3395
# coming from a smart server.
 
3396
#
 
3397
# TODO: Eventually, want to do a 'branch' command by fetching the whole
 
3398
# history as one big bundle.  How?  
 
3399
#
 
3400
# The branch command does 'br_from.sprout', which tries to preserve the same
 
3401
# format.  We don't necessarily even want that.  
 
3402
#
 
3403
# It might be simpler to handle cmd_pull first, which does a simpler fetch()
 
3404
# operation from one branch into another.  It already has some code for
 
3405
# pulling from a bundle, which it does by trying to see if the destination is
 
3406
# a bundle file.  So it seems the logic for pull ought to be:
 
3407
 
3408
#  - if it's a smart server, get a bundle from there and install that
 
3409
#  - if it's a bundle, install that
 
3410
#  - if it's a branch, pull from there
 
3411
#
 
3412
# Getting a bundle from a smart server is a bit different from reading a
 
3413
# bundle from a URL:
 
3414
#
 
3415
#  - we can reasonably remember the URL we last read from 
 
3416
#  - you can specify a revision number to pull, and we need to pass it across
 
3417
#    to the server as a limit on what will be requested
 
3418
#
 
3419
# TODO: Given a URL, determine whether it is a smart server or not (or perhaps
 
3420
# otherwise whether it's a bundle?)  Should this be a property or method of
 
3421
# the transport?  For the ssh protocol, we always know it's a smart server.
 
3422
# For http, we potentially need to probe.  But if we're explicitly given
 
3423
# bzr+http:// then we can skip that for now.