/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_smart_transport.py

  • Committer: Martin Pool
  • Date: 2009-03-13 07:54:48 UTC
  • mfrom: (4144 +trunk)
  • mto: This revision was merged to the branch mainline in revision 4189.
  • Revision ID: mbp@sourcefrog.net-20090313075448-jlz1t7baz7gzipqn
merge trunk

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006, 2007 Canonical Ltd
 
1
# Copyright (C) 2006, 2007, 2008, 2009 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
68
68
 
69
69
    def __init__(self, vendor):
70
70
        self.vendor = vendor
71
 
    
 
71
 
72
72
    def close(self):
73
73
        self.vendor.calls.append(('close', ))
74
 
        
 
74
 
75
75
    def get_filelike_channels(self):
76
76
        return self.vendor.read_from, self.vendor.write_to
77
77
 
78
78
 
79
79
class _InvalidHostnameFeature(tests.Feature):
80
80
    """Does 'non_existent.invalid' fail to resolve?
81
 
    
 
81
 
82
82
    RFC 2606 states that .invalid is reserved for invalid domain names, and
83
83
    also underscores are not a valid character in domain names.  Despite this,
84
84
    it's possible a badly misconfigured name server might decide to always
132
132
        t = threading.Thread(target=_receive_bytes_on_server)
133
133
        t.start()
134
134
        return t
135
 
    
 
135
 
136
136
    def test_construct_smart_simple_pipes_client_medium(self):
137
137
        # the SimplePipes client medium takes two pipes:
138
138
        # readable pipe, writeable pipe.
139
139
        # Constructing one should just save these and do nothing.
140
140
        # We test this by passing in None.
141
141
        client_medium = medium.SmartSimplePipesClientMedium(None, None, None)
142
 
        
 
142
 
143
143
    def test_simple_pipes_client_request_type(self):
144
144
        # SimplePipesClient should use SmartClientStreamMediumRequest's.
145
145
        client_medium = medium.SmartSimplePipesClientMedium(None, None, None)
148
148
 
149
149
    def test_simple_pipes_client_get_concurrent_requests(self):
150
150
        # the simple_pipes client does not support pipelined requests:
151
 
        # but it does support serial requests: we construct one after 
 
151
        # but it does support serial requests: we construct one after
152
152
        # another is finished. This is a smoke test testing the integration
153
153
        # of the SmartClientStreamMediumRequest and the SmartClientStreamMedium
154
154
        # classes - as the sibling classes share this logic, they do not have
170
170
            None, output, 'base')
171
171
        client_medium._accept_bytes('abc')
172
172
        self.assertEqual('abc', output.getvalue())
173
 
    
 
173
 
174
174
    def test_simple_pipes_client_disconnect_does_nothing(self):
175
175
        # calling disconnect does nothing.
176
176
        input = StringIO()
197
197
        self.assertFalse(input.closed)
198
198
        self.assertFalse(output.closed)
199
199
        self.assertEqual('abcabc', output.getvalue())
200
 
    
 
200
 
201
201
    def test_simple_pipes_client_ignores_disconnect_when_not_connected(self):
202
202
        # Doing a disconnect on a new (and thus unconnected) SimplePipes medium
203
203
        # does nothing.
212
212
        self.assertEqual('abc', client_medium.read_bytes(3))
213
213
        client_medium.disconnect()
214
214
        self.assertEqual('def', client_medium.read_bytes(3))
215
 
        
 
215
 
216
216
    def test_simple_pipes_client_supports__flush(self):
217
 
        # invoking _flush on a SimplePipesClient should flush the output 
 
217
        # invoking _flush on a SimplePipesClient should flush the output
218
218
        # pipe. We test this by creating an output pipe that records
219
219
        # flush calls made to it.
220
220
        from StringIO import StringIO # get regular StringIO
262
262
            'a hostname', 'a port',
263
263
            ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes'])],
264
264
            vendor.calls)
265
 
    
266
 
    def test_ssh_client_changes_command_when_BZR_REMOTE_PATH_is_set(self):
267
 
        # The only thing that initiates a connection from the medium is giving
268
 
        # it bytes.
269
 
        output = StringIO()
270
 
        vendor = StringIOSSHVendor(StringIO(), output)
271
 
        orig_bzr_remote_path = os.environ.get('BZR_REMOTE_PATH')
272
 
        def cleanup_environ():
273
 
            osutils.set_or_unset_env('BZR_REMOTE_PATH', orig_bzr_remote_path)
274
 
        self.addCleanup(cleanup_environ)
275
 
        os.environ['BZR_REMOTE_PATH'] = 'fugly'
276
 
        client_medium = self.callDeprecated(
277
 
            ['bzr_remote_path is required as of bzr 0.92'],
278
 
            medium.SmartSSHClientMedium, 'a hostname', 'a port', 'a username',
279
 
            'a password', 'base', vendor)
280
 
        client_medium._accept_bytes('abc')
281
 
        self.assertEqual('abc', output.getvalue())
282
 
        self.assertEqual([('connect_ssh', 'a username', 'a password',
283
 
            'a hostname', 'a port',
284
 
            ['fugly', 'serve', '--inet', '--directory=/', '--allow-writes'])],
285
 
            vendor.calls)
286
 
    
 
265
 
287
266
    def test_ssh_client_changes_command_when_bzr_remote_path_passed(self):
288
267
        # The only thing that initiates a connection from the medium is giving
289
268
        # it bytes.
350
329
            ('close', ),
351
330
            ],
352
331
            vendor.calls)
353
 
    
 
332
 
354
333
    def test_ssh_client_ignores_disconnect_when_not_connected(self):
355
334
        # Doing a disconnect on a new (and thus unconnected) SSH medium
356
335
        # does not fail.  It's ok to disconnect an unconnected medium.
369
348
                          1)
370
349
 
371
350
    def test_ssh_client_supports__flush(self):
372
 
        # invoking _flush on a SSHClientMedium should flush the output 
 
351
        # invoking _flush on a SSHClientMedium should flush the output
373
352
        # pipe. We test this by creating an output pipe that records
374
353
        # flush calls made to it.
375
354
        from StringIO import StringIO # get regular StringIO
387
366
        client_medium._flush()
388
367
        client_medium.disconnect()
389
368
        self.assertEqual(['flush'], flush_calls)
390
 
        
 
369
 
391
370
    def test_construct_smart_tcp_client_medium(self):
392
371
        # the TCP client medium takes a host and a port.  Constructing it won't
393
372
        # connect to anything.
408
387
        t.join()
409
388
        sock.close()
410
389
        self.assertEqual(['abc'], bytes)
411
 
    
 
390
 
412
391
    def test_tcp_client_disconnect_does_so(self):
413
392
        # calling disconnect on the client terminates the connection.
414
393
        # we test this by forcing a short read during a socket.MSG_WAITALL
425
404
        # really did disconnect.
426
405
        medium.disconnect()
427
406
 
428
 
    
 
407
 
429
408
    def test_tcp_client_ignores_disconnect_when_not_connected(self):
430
409
        # Doing a disconnect on a new (and thus unconnected) TCP medium
431
410
        # does not fail.  It's ok to disconnect an unconnected medium.
468
447
 
469
448
class TestSmartClientStreamMediumRequest(tests.TestCase):
470
449
    """Tests the for SmartClientStreamMediumRequest.
471
 
    
472
 
    SmartClientStreamMediumRequest is a helper for the three stream based 
 
450
 
 
451
    SmartClientStreamMediumRequest is a helper for the three stream based
473
452
    mediums: TCP, SSH, SimplePipes, so we only test it once, and then test that
474
453
    those three mediums implement the interface it expects.
475
454
    """
476
455
 
477
456
    def test_accept_bytes_after_finished_writing_errors(self):
478
 
        # calling accept_bytes after calling finished_writing raises 
 
457
        # calling accept_bytes after calling finished_writing raises
479
458
        # WritingCompleted to prevent bad assumptions on stream environments
480
459
        # breaking the needs of message-based environments.
481
460
        output = StringIO()
537
516
            None, None, 'base')
538
517
        request = medium.SmartClientStreamMediumRequest(client_medium)
539
518
        self.assertRaises(errors.WritingNotComplete, request.finished_reading)
540
 
        
 
519
 
541
520
    def test_read_bytes(self):
542
521
        # read bytes should invoke _read_bytes on the stream medium.
543
522
        # we test this by using the SimplePipes medium - the most trivial one
544
 
        # and checking that the data is supplied. Its possible that a 
 
523
        # and checking that the data is supplied. Its possible that a
545
524
        # faulty implementation could poke at the pipe variables them selves,
546
525
        # but we trust that this will be caught as it will break the integration
547
526
        # smoke tests.
566
545
        self.assertRaises(errors.WritingNotComplete, request.read_bytes, None)
567
546
 
568
547
    def test_read_bytes_after_finished_reading_errors(self):
569
 
        # calling read_bytes after calling finished_reading raises 
 
548
        # calling read_bytes after calling finished_reading raises
570
549
        # ReadingCompleted to prevent bad assumptions on stream environments
571
550
        # breaking the needs of message-based environments.
572
551
        output = StringIO()
604
583
 
605
584
 
606
585
class SampleRequest(object):
607
 
    
 
586
 
608
587
    def __init__(self, expected_bytes):
609
588
        self.accepted_bytes = ''
610
589
        self._finished_reading = False
632
611
 
633
612
    def portable_socket_pair(self):
634
613
        """Return a pair of TCP sockets connected to each other.
635
 
        
 
614
 
636
615
        Unlike socket.socketpair, this should work on Windows.
637
616
        """
638
617
        listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
643
622
        server_sock, addr = listen_sock.accept()
644
623
        listen_sock.close()
645
624
        return server_sock, client_sock
646
 
    
 
625
 
647
626
    def test_smart_query_version(self):
648
627
        """Feed a canned query version to a server"""
649
628
        # wire-to-wire, using the whole stack
723
702
        server = medium.SmartServerPipeStreamMedium(to_server, from_server, None)
724
703
        server._serve_one_request(SampleRequest('x'))
725
704
        self.assertTrue(server.finished)
726
 
        
 
705
 
727
706
    def test_socket_stream_shutdown_detection(self):
728
707
        server_sock, client_sock = self.portable_socket_pair()
729
708
        client_sock.close()
731
710
            server_sock, None)
732
711
        server._serve_one_request(SampleRequest('x'))
733
712
        self.assertTrue(server.finished)
734
 
        
 
713
 
735
714
    def test_socket_stream_incomplete_request(self):
736
715
        """The medium should still construct the right protocol version even if
737
716
        the initial read only reads part of the request.
815
794
        self.assertEqual('', from_server.getvalue())
816
795
        self.assertEqual(sample_request_bytes, second_protocol.accepted_bytes)
817
796
        self.assertFalse(server.finished)
818
 
        
 
797
 
819
798
    def test_socket_stream_with_two_requests(self):
820
799
        # If two requests are read in one go, then two calls to
821
800
        # _serve_one_request should still process both of them as if they had
856
835
        self.assertEqual('', from_server.getvalue())
857
836
        self.assertTrue(self.closed)
858
837
        self.assertTrue(server.finished)
859
 
        
 
838
 
860
839
    def test_socket_stream_error_handling(self):
861
840
        server_sock, client_sock = self.portable_socket_pair()
862
841
        server = medium.SmartServerSocketStreamMedium(
867
846
        # closed.
868
847
        self.assertEqual('', client_sock.recv(1))
869
848
        self.assertTrue(server.finished)
870
 
        
 
849
 
871
850
    def test_pipe_like_stream_keyboard_interrupt_handling(self):
872
851
        to_server = StringIO('')
873
852
        from_server = StringIO()
918
897
        # Any empty request (i.e. no bytes) is detected as protocol version one.
919
898
        server_protocol = self.build_protocol_pipe_like('')
920
899
        self.assertProtocolOne(server_protocol)
921
 
        
 
900
 
922
901
    def test_socket_like_build_protocol_empty_bytes(self):
923
902
        # Any empty request (i.e. no bytes) is detected as protocol version one.
924
903
        server_protocol = self.build_protocol_socket('')
959
938
        self.assertEqual(
960
939
            protocol.build_server_protocol_three, protocol_factory)
961
940
        self.assertEqual('extra bytes', remainder)
962
 
        
 
941
 
963
942
    def test_version_two(self):
964
943
        result = medium._get_protocol_factory_for_bytes(
965
944
            'bzr request 2\nextra bytes')
967
946
        self.assertEqual(
968
947
            protocol.SmartServerRequestProtocolTwo, protocol_factory)
969
948
        self.assertEqual('extra bytes', remainder)
970
 
        
 
949
 
971
950
    def test_version_one(self):
972
951
        """Version one requests have no version markers."""
973
952
        result = medium._get_protocol_factory_for_bytes('anything\n')
975
954
        self.assertEqual(
976
955
            protocol.SmartServerRequestProtocolOne, protocol_factory)
977
956
        self.assertEqual('anything\n', remainder)
978
 
        
 
957
 
979
958
 
980
959
class TestSmartTCPServer(tests.TestCase):
981
960
 
1229
1208
    Note: these tests are rudimentary versions of the command object tests in
1230
1209
    test_smart.py.
1231
1210
    """
1232
 
        
 
1211
 
1233
1212
    def test_hello(self):
1234
1213
        cmd = _mod_request.HelloRequest(None, '/')
1235
1214
        response = cmd.execute()
1236
1215
        self.assertEqual(('ok', '2'), response.args)
1237
1216
        self.assertEqual(None, response.body)
1238
 
        
 
1217
 
1239
1218
    def test_get_bundle(self):
1240
1219
        from bzrlib.bundle import serializer
1241
1220
        wt = self.make_branch_and_tree('.')
1242
1221
        self.build_tree_contents([('hello', 'hello world')])
1243
1222
        wt.add('hello')
1244
1223
        rev_id = wt.commit('add hello')
1245
 
        
 
1224
 
1246
1225
        cmd = _mod_request.GetBundleRequest(self.get_transport(), '/')
1247
1226
        response = cmd.execute('.', rev_id)
1248
1227
        bundle = serializer.read_bundle(StringIO(response.body))
1272
1251
        handler.dispatch_command('hello', ())
1273
1252
        self.assertEqual(('ok', '2'), handler.response.args)
1274
1253
        self.assertEqual(None, handler.response.body)
1275
 
        
 
1254
 
1276
1255
    def test_disable_vfs_handler_classes_via_environment(self):
1277
1256
        # VFS handler classes will raise an error from "execute" if
1278
1257
        # BZR_NO_SMART_VFS is set.
1316
1295
        self.assertTrue(handler.finished_reading)
1317
1296
        self.assertEqual(('ok', ), handler.response.args)
1318
1297
        self.assertEqual(None, handler.response.body)
1319
 
        
 
1298
 
1320
1299
    def test_readv_accept_body(self):
1321
1300
        """'readv' should set finished_reading after reading offsets."""
1322
1301
        self.build_tree(['a-file'])
1363
1342
 
1364
1343
 
1365
1344
class TestRemoteTransport(tests.TestCase):
1366
 
        
 
1345
 
1367
1346
    def test_use_connection_factory(self):
1368
1347
        # We want to be able to pass a client as a parameter to RemoteTransport.
1369
1348
        input = StringIO('ok\n3\nbardone\n')
1468
1447
    def assertOffsetSerialisation(self, expected_offsets, expected_serialised,
1469
1448
        requester):
1470
1449
        """Check that smart (de)serialises offsets as expected.
1471
 
        
 
1450
 
1472
1451
        We check both serialisation and deserialisation at the same time
1473
1452
        to ensure that the round tripping cannot skew: both directions should
1474
1453
        be as expected.
1475
 
        
 
1454
 
1476
1455
        :param expected_offsets: a readv offset list.
1477
1456
        :param expected_seralised: an expected serial form of the offsets.
1478
1457
        """
1489
1468
        smart_protocol._has_dispatched = True
1490
1469
        smart_protocol.request = _mod_request.SmartServerRequestHandler(
1491
1470
            None, _mod_request.request_handlers, '/')
1492
 
        class FakeCommand(object):
1493
 
            def do_body(cmd, body_bytes):
 
1471
        class FakeCommand(_mod_request.SmartServerRequest):
 
1472
            def do_body(self_cmd, body_bytes):
1494
1473
                self.end_received = True
1495
1474
                self.assertEqual('abcdefg', body_bytes)
1496
1475
                return _mod_request.SuccessfulSmartServerResponse(('ok', ))
1497
 
        smart_protocol.request._command = FakeCommand()
 
1476
        smart_protocol.request._command = FakeCommand(None)
1498
1477
        # Call accept_bytes to make sure that internal state like _body_decoder
1499
1478
        # is initialised.  This test should probably be given a clearer
1500
1479
        # interface to work with that will not cause this inconsistency.
1528
1507
        ex = self.assertRaises(errors.ConnectionReset,
1529
1508
            response_handler.read_response_tuple)
1530
1509
        self.assertEqual("Connection closed: "
1531
 
            "please check connectivity and permissions "
1532
 
            "(and try -Dhpss if further diagnosis is required)", str(ex))
 
1510
            "please check connectivity and permissions ",
 
1511
            str(ex))
1533
1512
 
1534
1513
    def test_server_offset_serialisation(self):
1535
1514
        """The Smart protocol serialises offsets as a comma and \n string.
1654
1633
 
1655
1634
    def test_query_version(self):
1656
1635
        """query_version on a SmartClientProtocolOne should return a number.
1657
 
        
 
1636
 
1658
1637
        The protocol provides the query_version because the domain level clients
1659
1638
        may all need to be able to probe for capabilities.
1660
1639
        """
1925
1904
 
1926
1905
    def test_query_version(self):
1927
1906
        """query_version on a SmartClientProtocolTwo should return a number.
1928
 
        
 
1907
 
1929
1908
        The protocol provides the query_version because the domain level clients
1930
1909
        may all need to be able to probe for capabilities.
1931
1910
        """
2270
2249
        self.assertEqual(4, smart_protocol.next_read_size())
2271
2250
 
2272
2251
 
2273
 
class NoOpRequest(_mod_request.SmartServerRequest):
2274
 
 
2275
 
    def do(self):
2276
 
        return _mod_request.SuccessfulSmartServerResponse(())
2277
 
 
2278
 
dummy_registry = {'ARG': NoOpRequest}
2279
 
 
2280
 
 
2281
2252
class LoggingMessageHandler(object):
2282
2253
 
2283
2254
    def __init__(self):
2358
2329
            '\0\0\0\x07' # length prefix
2359
2330
            'l3:ARGe' # ['ARG']
2360
2331
            )
2361
 
        self.assertEqual([('structure', ['ARG'])], event_log)
 
2332
        self.assertEqual([('structure', ('ARG',))], event_log)
2362
2333
 
2363
2334
    def test_decode_multiple_bytes(self):
2364
2335
        """The protocol can decode a multiple 'bytes' message parts."""
2375
2346
            [('bytes', 'first'), ('bytes', 'second')], event_log)
2376
2347
 
2377
2348
 
2378
 
class TestConventionalResponseHandler(tests.TestCase):
 
2349
class TestConventionalResponseHandlerBodyStream(tests.TestCase):
2379
2350
 
2380
2351
    def make_response_handler(self, response_bytes):
2381
2352
        from bzrlib.smart.message import ConventionalResponseHandler
2392
2363
            protocol_decoder, medium_request)
2393
2364
        return response_handler
2394
2365
 
2395
 
    def test_body_stream_interrupted_by_error(self):
2396
 
        interrupted_body_stream = (
2397
 
            'oS' # successful response
2398
 
            's\0\0\0\x02le' # empty args
2399
 
            'b\0\0\0\x09chunk one' # first chunk
2400
 
            'b\0\0\0\x09chunk two' # second chunk
2401
 
            'oE' # error flag
2402
 
            's\0\0\0\x0el5:error3:abce' # bencoded error
2403
 
            'e' # message end
2404
 
            )
 
2366
    def test_interrupted_by_error(self):
2405
2367
        response_handler = self.make_response_handler(interrupted_body_stream)
2406
2368
        stream = response_handler.read_streamed_body()
2407
 
        self.assertEqual('chunk one', stream.next())
2408
 
        self.assertEqual('chunk two', stream.next())
 
2369
        self.assertEqual('aaa', stream.next())
 
2370
        self.assertEqual('bbb', stream.next())
2409
2371
        exc = self.assertRaises(errors.ErrorFromSmartServer, stream.next)
2410
 
        self.assertEqual(('error', 'abc'), exc.error_tuple)
 
2372
        self.assertEqual(('error', 'Boom!'), exc.error_tuple)
2411
2373
 
2412
 
    def test_body_stream_interrupted_by_connection_lost(self):
 
2374
    def test_interrupted_by_connection_lost(self):
2413
2375
        interrupted_body_stream = (
2414
2376
            'oS' # successful response
2415
2377
            's\0\0\0\x02le' # empty args
2427
2389
        self.assertRaises(
2428
2390
            errors.ConnectionReset, response_handler.read_body_bytes)
2429
2391
 
 
2392
    def test_multiple_bytes_parts(self):
 
2393
        multiple_bytes_parts = (
 
2394
            'oS' # successful response
 
2395
            's\0\0\0\x02le' # empty args
 
2396
            'b\0\0\0\x0bSome bytes\n' # some bytes
 
2397
            'b\0\0\0\x0aMore bytes' # more bytes
 
2398
            'e' # message end
 
2399
            )
 
2400
        response_handler = self.make_response_handler(multiple_bytes_parts)
 
2401
        self.assertEqual(
 
2402
            'Some bytes\nMore bytes', response_handler.read_body_bytes())
 
2403
        response_handler = self.make_response_handler(multiple_bytes_parts)
 
2404
        self.assertEqual(
 
2405
            ['Some bytes\n', 'More bytes'],
 
2406
            list(response_handler.read_streamed_body()))
 
2407
 
 
2408
 
 
2409
class FakeResponder(object):
 
2410
 
 
2411
    response_sent = False
 
2412
 
 
2413
    def send_error(self, exc):
 
2414
        raise exc
 
2415
 
 
2416
    def send_response(self, response):
 
2417
        pass
 
2418
 
 
2419
 
 
2420
class TestConventionalRequestHandlerBodyStream(tests.TestCase):
 
2421
    """Tests for ConventionalRequestHandler's handling of request bodies."""
 
2422
 
 
2423
    def make_request_handler(self, request_bytes):
 
2424
        """Make a ConventionalRequestHandler for the given bytes using test
 
2425
        doubles for the request_handler and the responder.
 
2426
        """
 
2427
        from bzrlib.smart.message import ConventionalRequestHandler
 
2428
        request_handler = InstrumentedRequestHandler()
 
2429
        request_handler.response = _mod_request.SuccessfulSmartServerResponse(('arg', 'arg'))
 
2430
        responder = FakeResponder()
 
2431
        message_handler = ConventionalRequestHandler(request_handler, responder)
 
2432
        protocol_decoder = protocol.ProtocolThreeDecoder(message_handler)
 
2433
        # put decoder in desired state (waiting for message parts)
 
2434
        protocol_decoder.state_accept = protocol_decoder._state_accept_expecting_message_part
 
2435
        protocol_decoder.accept_bytes(request_bytes)
 
2436
        return request_handler
 
2437
 
 
2438
    def test_multiple_bytes_parts(self):
 
2439
        """Each bytes part triggers a call to the request_handler's
 
2440
        accept_body method.
 
2441
        """
 
2442
        multiple_bytes_parts = (
 
2443
            's\0\0\0\x07l3:fooe' # args
 
2444
            'b\0\0\0\x0bSome bytes\n' # some bytes
 
2445
            'b\0\0\0\x0aMore bytes' # more bytes
 
2446
            'e' # message end
 
2447
            )
 
2448
        request_handler = self.make_request_handler(multiple_bytes_parts)
 
2449
        accept_body_calls = [
 
2450
            call_info[1] for call_info in request_handler.calls
 
2451
            if call_info[0] == 'accept_body']
 
2452
        self.assertEqual(
 
2453
            ['Some bytes\n', 'More bytes'], accept_body_calls)
 
2454
 
 
2455
    def test_error_flag_after_body(self):
 
2456
        body_then_error = (
 
2457
            's\0\0\0\x07l3:fooe' # request args
 
2458
            'b\0\0\0\x0bSome bytes\n' # some bytes
 
2459
            'b\0\0\0\x0aMore bytes' # more bytes
 
2460
            'oE' # error flag
 
2461
            's\0\0\0\x07l3:bare' # error args
 
2462
            'e' # message end
 
2463
            )
 
2464
        request_handler = self.make_request_handler(body_then_error)
 
2465
        self.assertEqual(
 
2466
            [('post_body_error_received', ('bar',)), ('end_received',)],
 
2467
            request_handler.calls[-2:])
 
2468
 
2430
2469
 
2431
2470
class TestMessageHandlerErrors(tests.TestCase):
2432
2471
    """Tests for v3 that unrecognised (but well-formed) requests/responses are
2476
2515
 
2477
2516
    def __init__(self):
2478
2517
        self.calls = []
2479
 
 
2480
 
    def body_chunk_received(self, chunk_bytes):
2481
 
        self.calls.append(('body_chunk_received', chunk_bytes))
 
2518
        self.finished_reading = False
2482
2519
 
2483
2520
    def no_body_received(self):
2484
2521
        self.calls.append(('no_body_received',))
2485
2522
 
2486
 
    def prefixed_body_received(self, body_bytes):
2487
 
        self.calls.append(('prefixed_body_received', body_bytes))
2488
 
 
2489
2523
    def end_received(self):
2490
2524
        self.calls.append(('end_received',))
 
2525
        self.finished_reading = True
 
2526
 
 
2527
    def dispatch_command(self, cmd, args):
 
2528
        self.calls.append(('dispatch_command', cmd, args))
 
2529
 
 
2530
    def accept_body(self, bytes):
 
2531
        self.calls.append(('accept_body', bytes))
 
2532
 
 
2533
    def end_of_body(self):
 
2534
        self.calls.append(('end_of_body',))
 
2535
        self.finished_reading = True
 
2536
 
 
2537
    def post_body_error_received(self, error_args):
 
2538
        self.calls.append(('post_body_error_received', error_args))
2491
2539
 
2492
2540
 
2493
2541
class StubRequest(object):
2529
2577
        # The message handler has been invoked with all the parts of the
2530
2578
        # trivial response: empty headers, status byte, no args, end.
2531
2579
        self.assertEqual(
2532
 
            [('headers', {}), ('byte', 'S'), ('structure', []), ('end',)],
 
2580
            [('headers', {}), ('byte', 'S'), ('structure', ()), ('end',)],
2533
2581
            response_handler.event_log)
2534
2582
 
2535
2583
    def test_incomplete_message(self):
2643
2691
        self.assertEqual(
2644
2692
            ['accept_bytes', 'finished_writing'], medium_request.calls)
2645
2693
 
 
2694
    def test_call_with_body_stream_smoke_test(self):
 
2695
        """A smoke test for ProtocolThreeRequester.call_with_body_stream.
 
2696
 
 
2697
        This test checks that a particular simple invocation of
 
2698
        call_with_body_stream emits the correct bytes for that invocation.
 
2699
        """
 
2700
        requester, output = self.make_client_encoder_and_output()
 
2701
        requester.set_headers({'header name': 'header value'})
 
2702
        stream = ['chunk 1', 'chunk two']
 
2703
        requester.call_with_body_stream(('one arg',), stream)
 
2704
        self.assertEquals(
 
2705
            'bzr message 3 (bzr 1.6)\n' # protocol version
 
2706
            '\x00\x00\x00\x1fd11:header name12:header valuee' # headers
 
2707
            's\x00\x00\x00\x0bl7:one arge' # args
 
2708
            'b\x00\x00\x00\x07chunk 1' # a prefixed body chunk
 
2709
            'b\x00\x00\x00\x09chunk two' # a prefixed body chunk
 
2710
            'e', # end
 
2711
            output.getvalue())
 
2712
 
 
2713
    def test_call_with_body_stream_empty_stream(self):
 
2714
        """call_with_body_stream with an empty stream."""
 
2715
        requester, output = self.make_client_encoder_and_output()
 
2716
        requester.set_headers({})
 
2717
        stream = []
 
2718
        requester.call_with_body_stream(('one arg',), stream)
 
2719
        self.assertEquals(
 
2720
            'bzr message 3 (bzr 1.6)\n' # protocol version
 
2721
            '\x00\x00\x00\x02de' # headers
 
2722
            's\x00\x00\x00\x0bl7:one arge' # args
 
2723
            # no body chunks
 
2724
            'e', # end
 
2725
            output.getvalue())
 
2726
 
 
2727
    def test_call_with_body_stream_error(self):
 
2728
        """call_with_body_stream will abort the streamed body with an
 
2729
        error if the stream raises an error during iteration.
 
2730
 
 
2731
        The resulting request will still be a complete message.
 
2732
        """
 
2733
        requester, output = self.make_client_encoder_and_output()
 
2734
        requester.set_headers({})
 
2735
        def stream_that_fails():
 
2736
            yield 'aaa'
 
2737
            yield 'bbb'
 
2738
            raise Exception('Boom!')
 
2739
        self.assertRaises(Exception, requester.call_with_body_stream,
 
2740
            ('one arg',), stream_that_fails())
 
2741
        self.assertEquals(
 
2742
            'bzr message 3 (bzr 1.6)\n' # protocol version
 
2743
            '\x00\x00\x00\x02de' # headers
 
2744
            's\x00\x00\x00\x0bl7:one arge' # args
 
2745
            'b\x00\x00\x00\x03aaa' # body
 
2746
            'b\x00\x00\x00\x03bbb' # more body
 
2747
            'oE' # error flag
 
2748
            's\x00\x00\x00\x09l5:errore' # error args: ('error',)
 
2749
            'e', # end
 
2750
            output.getvalue())
 
2751
 
2646
2752
 
2647
2753
class StubMediumRequest(object):
2648
2754
    """A stub medium request that tracks the number of times accept_bytes is
2660
2766
        self.calls.append('finished_writing')
2661
2767
 
2662
2768
 
 
2769
interrupted_body_stream = (
 
2770
    'oS' # status flag (success)
 
2771
    's\x00\x00\x00\x08l4:argse' # args struct ('args,')
 
2772
    'b\x00\x00\x00\x03aaa' # body part ('aaa')
 
2773
    'b\x00\x00\x00\x03bbb' # body part ('bbb')
 
2774
    'oE' # status flag (error)
 
2775
    's\x00\x00\x00\x10l5:error5:Boom!e' # err struct ('error', 'Boom!')
 
2776
    'e' # EOM
 
2777
    )
 
2778
 
 
2779
 
2663
2780
class TestResponseEncodingProtocolThree(tests.TestCase):
2664
2781
 
2665
2782
    def make_response_encoder(self):
2681
2798
            # end of message
2682
2799
            'e')
2683
2800
 
 
2801
    def test_send_broken_body_stream(self):
 
2802
        encoder, out_stream = self.make_response_encoder()
 
2803
        encoder._headers = {}
 
2804
        def stream_that_fails():
 
2805
            yield 'aaa'
 
2806
            yield 'bbb'
 
2807
            raise Exception('Boom!')
 
2808
        response = _mod_request.SuccessfulSmartServerResponse(
 
2809
            ('args',), body_stream=stream_that_fails())
 
2810
        encoder.send_response(response)
 
2811
        expected_response = (
 
2812
            'bzr message 3 (bzr 1.6)\n'  # protocol marker
 
2813
            '\x00\x00\x00\x02de' # headers dict (empty)
 
2814
            + interrupted_body_stream)
 
2815
        self.assertEqual(expected_response, out_stream.getvalue())
 
2816
 
2684
2817
 
2685
2818
class TestResponseEncoderBufferingProtocolThree(tests.TestCase):
2686
2819
    """Tests for buffering of responses.
2697
2830
        self.assertEqual(
2698
2831
            expected_count, len(self.writes),
2699
2832
            "Too many writes: %r" % (self.writes,))
2700
 
        
 
2833
 
2701
2834
    def test_send_error_writes_just_once(self):
2702
2835
        """An error response is written to the medium all at once."""
2703
2836
        self.responder.send_error(Exception('An exception string.'))
2719
2852
        self.responder.send_response(response)
2720
2853
        self.assertWriteCount(1)
2721
2854
 
2722
 
    def test_send_response_with_body_stream_writes_once_per_chunk(self):
2723
 
        """A normal response with a stream body is written to the medium
2724
 
        writes to the medium once per chunk.
2725
 
        """
 
2855
    def test_send_response_with_body_stream_buffers_writes(self):
 
2856
        """A normal response with a stream body writes to the medium once."""
2726
2857
        # Construct a response with stream with 2 chunks in it.
2727
2858
        response = _mod_request.SuccessfulSmartServerResponse(
2728
2859
            ('arg', 'arg'), body_stream=['chunk1', 'chunk2'])
2729
2860
        self.responder.send_response(response)
2730
 
        # We will write 3 times: exactly once for each chunk, plus a final
2731
 
        # write to end the response.
2732
 
        self.assertWriteCount(3)
 
2861
        # We will write just once, despite the multiple chunks, due to
 
2862
        # buffering.
 
2863
        self.assertWriteCount(1)
 
2864
 
 
2865
    def test_send_response_with_body_stream_flushes_buffers_sometimes(self):
 
2866
        """When there are many chunks (>100), multiple writes will occur rather
 
2867
        than buffering indefinitely.
 
2868
        """
 
2869
        # Construct a response with stream with 40 chunks in it.  Every chunk
 
2870
        # triggers 3 buffered writes, so we expect > 100 buffered writes, but <
 
2871
        # 200.
 
2872
        body_stream = ['chunk %d' % count for count in range(40)]
 
2873
        response = _mod_request.SuccessfulSmartServerResponse(
 
2874
            ('arg', 'arg'), body_stream=body_stream)
 
2875
        self.responder.send_response(response)
 
2876
        # The write buffer is flushed every 100 buffered writes, so we expect 2
 
2877
        # actual writes.
 
2878
        self.assertWriteCount(2)
2733
2879
 
2734
2880
 
2735
2881
class TestSmartClientUnicode(tests.TestCase):
2772
2918
 
2773
2919
class MockMedium(medium.SmartClientMedium):
2774
2920
    """A mock medium that can be used to test _SmartClient.
2775
 
    
 
2921
 
2776
2922
    It can be given a series of requests to expect (and responses it should
2777
2923
    return for them).  It can also be told when the client is expected to
2778
2924
    disconnect a medium.  Expectations must be satisfied in the order they are
2790
2936
        super(MockMedium, self).__init__('dummy base')
2791
2937
        self._mock_request = _MockMediumRequest(self)
2792
2938
        self._expected_events = []
2793
 
        
 
2939
 
2794
2940
    def expect_request(self, request_bytes, response_bytes,
2795
2941
                       allow_partial_read=False):
2796
2942
        """Expect 'request_bytes' to be sent, and reply with 'response_bytes'.
2799
2945
        called to send the request.  Similarly, no assumption is made about how
2800
2946
        many times read_bytes/read_line are called by protocol code to read a
2801
2947
        response.  e.g.::
2802
 
        
 
2948
 
2803
2949
            request.accept_bytes('ab')
2804
2950
            request.accept_bytes('cd')
2805
2951
            request.finished_writing()
2806
2952
 
2807
2953
        and::
2808
 
        
 
2954
 
2809
2955
            request.accept_bytes('abcd')
2810
2956
            request.finished_writing()
2811
2957
 
2996
3142
    def test_first_response_is_error(self):
2997
3143
        """If the server replies with an error, then the version detection
2998
3144
        should be complete.
2999
 
        
 
3145
 
3000
3146
        This test is very similar to test_version_two_server, but catches a bug
3001
3147
        we had in the case where the first reply was an error response.
3002
3148
        """
3042
3188
 
3043
3189
class LengthPrefixedBodyDecoder(tests.TestCase):
3044
3190
 
3045
 
    # XXX: TODO: make accept_reading_trailer invoke translate_response or 
 
3191
    # XXX: TODO: make accept_reading_trailer invoke translate_response or
3046
3192
    # something similar to the ProtocolBase method.
3047
3193
 
3048
3194
    def test_construct(self):
3084
3230
        self.assertEqual(1, decoder.next_read_size())
3085
3231
        self.assertEqual('', decoder.read_pending_data())
3086
3232
        self.assertEqual('blarg', decoder.unused_data)
3087
 
        
 
3233
 
3088
3234
    def test_accept_bytes_all_at_once_with_excess(self):
3089
3235
        decoder = protocol.LengthPrefixedBodyDecoder()
3090
3236
        decoder.accept_bytes('1\nadone\nunused')
3109
3255
 
3110
3256
class TestChunkedBodyDecoder(tests.TestCase):
3111
3257
    """Tests for ChunkedBodyDecoder.
3112
 
    
 
3258
 
3113
3259
    This is the body decoder used for protocol version two.
3114
3260
    """
3115
3261
 
3141
3287
        self.assertTrue(decoder.finished_reading)
3142
3288
        self.assertEqual(chunk_content, decoder.read_next_chunk())
3143
3289
        self.assertEqual('', decoder.unused_data)
3144
 
        
 
3290
 
3145
3291
    def test_incomplete_chunk(self):
3146
3292
        """When there are less bytes in the chunk than declared by the length,
3147
3293
        then we haven't finished reading yet.
3412
3558
        self.assertNotEquals(type(r), type(t))
3413
3559
 
3414
3560
 
3415
 
# TODO: Client feature that does get_bundle and then installs that into a
3416
 
# branch; this can be used in place of the regular pull/fetch operation when
3417
 
# coming from a smart server.
3418
 
#
3419
 
# TODO: Eventually, want to do a 'branch' command by fetching the whole
3420
 
# history as one big bundle.  How?  
3421
 
#
3422
 
# The branch command does 'br_from.sprout', which tries to preserve the same
3423
 
# format.  We don't necessarily even want that.  
3424
 
#
3425
 
# It might be simpler to handle cmd_pull first, which does a simpler fetch()
3426
 
# operation from one branch into another.  It already has some code for
3427
 
# pulling from a bundle, which it does by trying to see if the destination is
3428
 
# a bundle file.  So it seems the logic for pull ought to be:
3429
 
3430
 
#  - if it's a smart server, get a bundle from there and install that
3431
 
#  - if it's a bundle, install that
3432
 
#  - if it's a branch, pull from there
3433
 
#
3434
 
# Getting a bundle from a smart server is a bit different from reading a
3435
 
# bundle from a URL:
3436
 
#
3437
 
#  - we can reasonably remember the URL we last read from 
3438
 
#  - you can specify a revision number to pull, and we need to pass it across
3439
 
#    to the server as a limit on what will be requested
3440
 
#
3441
 
# TODO: Given a URL, determine whether it is a smart server or not (or perhaps
3442
 
# otherwise whether it's a bundle?)  Should this be a property or method of
3443
 
# the transport?  For the ssh protocol, we always know it's a smart server.
3444
 
# For http, we potentially need to probe.  But if we're explicitly given
3445
 
# bzr+http:// then we can skip that for now.