/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_smart_transport.py

  • Committer: John Arbash Meinel
  • Date: 2007-04-12 20:36:40 UTC
  • mfrom: (2413 +trunk)
  • mto: This revision was merged to the branch mainline in revision 2566.
  • Revision ID: john@arbash-meinel.com-20070412203640-z1jld315288moxvy
[merge] bzr.dev 2413

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006 Canonical Ltd
 
1
# Copyright (C) 2006, 2007 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
30
30
        tests,
31
31
        urlutils,
32
32
        )
 
33
from bzrlib.smart import (
 
34
        medium,
 
35
        protocol,
 
36
        request,
 
37
        server,
 
38
        vfs,
 
39
)
33
40
from bzrlib.tests.HTTPTestUtil import (
34
41
        HTTPServerWithSmarts,
35
42
        SmartRequestHandler,
38
45
        get_transport,
39
46
        local,
40
47
        memory,
41
 
        smart,
 
48
        remote,
42
49
        )
43
50
from bzrlib.transport.http import SmartClientHTTPMediumRequest
44
51
 
85
92
        sock.bind(('127.0.0.1', 0))
86
93
        sock.listen(1)
87
94
        port = sock.getsockname()[1]
88
 
        medium = smart.SmartTCPClientMedium('127.0.0.1', port)
89
 
        return sock, medium
 
95
        client_medium = medium.SmartTCPClientMedium('127.0.0.1', port)
 
96
        return sock, client_medium
90
97
 
91
98
    def receive_bytes_on_server(self, sock, bytes):
92
99
        """Accept a connection on sock and read 3 bytes.
108
115
        # this just ensures that the constructor stays parameter-free which
109
116
        # is important for reuse : some subclasses will dynamically connect,
110
117
        # others are always on, etc.
111
 
        medium = smart.SmartClientStreamMedium()
 
118
        client_medium = medium.SmartClientStreamMedium()
112
119
 
113
120
    def test_construct_smart_client_medium(self):
114
121
        # the base client medium takes no parameters
115
 
        medium = smart.SmartClientMedium()
 
122
        client_medium = medium.SmartClientMedium()
116
123
    
117
124
    def test_construct_smart_simple_pipes_client_medium(self):
118
125
        # the SimplePipes client medium takes two pipes:
119
126
        # readable pipe, writeable pipe.
120
127
        # Constructing one should just save these and do nothing.
121
128
        # We test this by passing in None.
122
 
        medium = smart.SmartSimplePipesClientMedium(None, None)
 
129
        client_medium = medium.SmartSimplePipesClientMedium(None, None)
123
130
        
124
131
    def test_simple_pipes_client_request_type(self):
125
132
        # SimplePipesClient should use SmartClientStreamMediumRequest objects.
126
 
        medium = smart.SmartSimplePipesClientMedium(None, None)
127
 
        request = medium.get_request()
128
 
        self.assertIsInstance(request, smart.SmartClientStreamMediumRequest)
 
133
        client_medium = medium.SmartSimplePipesClientMedium(None, None)
 
134
        request = client_medium.get_request()
 
135
        self.assertIsInstance(request, medium.SmartClientStreamMediumRequest)
129
136
 
130
137
    def test_simple_pipes_client_get_concurrent_requests(self):
131
138
        # the simple_pipes client does not support pipelined requests:
135
142
        # classes - as the sibling classes share this logic, they do not have
136
143
        # explicit tests for this.
137
144
        output = StringIO()
138
 
        medium = smart.SmartSimplePipesClientMedium(None, output)
139
 
        request = medium.get_request()
 
145
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
146
        request = client_medium.get_request()
140
147
        request.finished_writing()
141
148
        request.finished_reading()
142
 
        request2 = medium.get_request()
 
149
        request2 = client_medium.get_request()
143
150
        request2.finished_writing()
144
151
        request2.finished_reading()
145
152
 
146
153
    def test_simple_pipes_client__accept_bytes_writes_to_writable(self):
147
154
        # accept_bytes writes to the writeable pipe.
148
155
        output = StringIO()
149
 
        medium = smart.SmartSimplePipesClientMedium(None, output)
150
 
        medium._accept_bytes('abc')
 
156
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
157
        client_medium._accept_bytes('abc')
151
158
        self.assertEqual('abc', output.getvalue())
152
159
    
153
160
    def test_simple_pipes_client_disconnect_does_nothing(self):
154
161
        # calling disconnect does nothing.
155
162
        input = StringIO()
156
163
        output = StringIO()
157
 
        medium = smart.SmartSimplePipesClientMedium(input, output)
 
164
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
158
165
        # send some bytes to ensure disconnecting after activity still does not
159
166
        # close.
160
 
        medium._accept_bytes('abc')
161
 
        medium.disconnect()
 
167
        client_medium._accept_bytes('abc')
 
168
        client_medium.disconnect()
162
169
        self.assertFalse(input.closed)
163
170
        self.assertFalse(output.closed)
164
171
 
167
174
        # accept_bytes writes to.
168
175
        input = StringIO()
169
176
        output = StringIO()
170
 
        medium = smart.SmartSimplePipesClientMedium(input, output)
171
 
        medium._accept_bytes('abc')
172
 
        medium.disconnect()
173
 
        medium._accept_bytes('abc')
 
177
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
178
        client_medium._accept_bytes('abc')
 
179
        client_medium.disconnect()
 
180
        client_medium._accept_bytes('abc')
174
181
        self.assertFalse(input.closed)
175
182
        self.assertFalse(output.closed)
176
183
        self.assertEqual('abcabc', output.getvalue())
178
185
    def test_simple_pipes_client_ignores_disconnect_when_not_connected(self):
179
186
        # Doing a disconnect on a new (and thus unconnected) SimplePipes medium
180
187
        # does nothing.
181
 
        medium = smart.SmartSimplePipesClientMedium(None, None)
182
 
        medium.disconnect()
 
188
        client_medium = medium.SmartSimplePipesClientMedium(None, None)
 
189
        client_medium.disconnect()
183
190
 
184
191
    def test_simple_pipes_client_can_always_read(self):
185
192
        # SmartSimplePipesClientMedium is never disconnected, so read_bytes
186
193
        # always tries to read from the underlying pipe.
187
194
        input = StringIO('abcdef')
188
 
        medium = smart.SmartSimplePipesClientMedium(input, None)
189
 
        self.assertEqual('abc', medium.read_bytes(3))
190
 
        medium.disconnect()
191
 
        self.assertEqual('def', medium.read_bytes(3))
 
195
        client_medium = medium.SmartSimplePipesClientMedium(input, None)
 
196
        self.assertEqual('abc', client_medium.read_bytes(3))
 
197
        client_medium.disconnect()
 
198
        self.assertEqual('def', client_medium.read_bytes(3))
192
199
        
193
200
    def test_simple_pipes_client_supports__flush(self):
194
201
        # invoking _flush on a SimplePipesClient should flush the output 
200
207
        flush_calls = []
201
208
        def logging_flush(): flush_calls.append('flush')
202
209
        output.flush = logging_flush
203
 
        medium = smart.SmartSimplePipesClientMedium(input, output)
 
210
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
204
211
        # this call is here to ensure we only flush once, not on every
205
212
        # _accept_bytes call.
206
 
        medium._accept_bytes('abc')
207
 
        medium._flush()
208
 
        medium.disconnect()
 
213
        client_medium._accept_bytes('abc')
 
214
        client_medium._flush()
 
215
        client_medium.disconnect()
209
216
        self.assertEqual(['flush'], flush_calls)
210
217
 
211
218
    def test_construct_smart_ssh_client_medium(self):
219
226
        unopened_port = sock.getsockname()[1]
220
227
        # having vendor be invalid means that if it tries to connect via the
221
228
        # vendor it will blow up.
222
 
        medium = smart.SmartSSHClientMedium('127.0.0.1', unopened_port,
 
229
        client_medium = medium.SmartSSHClientMedium('127.0.0.1', unopened_port,
223
230
            username=None, password=None, vendor="not a vendor")
224
231
        sock.close()
225
232
 
228
235
        # it bytes.
229
236
        output = StringIO()
230
237
        vendor = StringIOSSHVendor(StringIO(), output)
231
 
        medium = smart.SmartSSHClientMedium('a hostname', 'a port', 'a username',
232
 
            'a password', vendor)
233
 
        medium._accept_bytes('abc')
 
238
        client_medium = medium.SmartSSHClientMedium(
 
239
            'a hostname', 'a port', 'a username', 'a password', vendor)
 
240
        client_medium._accept_bytes('abc')
234
241
        self.assertEqual('abc', output.getvalue())
235
242
        self.assertEqual([('connect_ssh', 'a username', 'a password',
236
243
            'a hostname', 'a port',
247
254
            osutils.set_or_unset_env('BZR_REMOTE_PATH', orig_bzr_remote_path)
248
255
        self.addCleanup(cleanup_environ)
249
256
        os.environ['BZR_REMOTE_PATH'] = 'fugly'
250
 
        medium = smart.SmartSSHClientMedium('a hostname', 'a port', 'a username',
 
257
        client_medium = medium.SmartSSHClientMedium('a hostname', 'a port', 'a username',
251
258
            'a password', vendor)
252
 
        medium._accept_bytes('abc')
 
259
        client_medium._accept_bytes('abc')
253
260
        self.assertEqual('abc', output.getvalue())
254
261
        self.assertEqual([('connect_ssh', 'a username', 'a password',
255
262
            'a hostname', 'a port',
262
269
        input = StringIO()
263
270
        output = StringIO()
264
271
        vendor = StringIOSSHVendor(input, output)
265
 
        medium = smart.SmartSSHClientMedium('a hostname', vendor=vendor)
266
 
        medium._accept_bytes('abc')
267
 
        medium.disconnect()
 
272
        client_medium = medium.SmartSSHClientMedium('a hostname', vendor=vendor)
 
273
        client_medium._accept_bytes('abc')
 
274
        client_medium.disconnect()
268
275
        self.assertTrue(input.closed)
269
276
        self.assertTrue(output.closed)
270
277
        self.assertEqual([
282
289
        input = StringIO()
283
290
        output = StringIO()
284
291
        vendor = StringIOSSHVendor(input, output)
285
 
        medium = smart.SmartSSHClientMedium('a hostname', vendor=vendor)
286
 
        medium._accept_bytes('abc')
287
 
        medium.disconnect()
 
292
        client_medium = medium.SmartSSHClientMedium('a hostname', vendor=vendor)
 
293
        client_medium._accept_bytes('abc')
 
294
        client_medium.disconnect()
288
295
        # the disconnect has closed output, so we need a new output for the
289
296
        # new connection to write to.
290
297
        input2 = StringIO()
291
298
        output2 = StringIO()
292
299
        vendor.read_from = input2
293
300
        vendor.write_to = output2
294
 
        medium._accept_bytes('abc')
295
 
        medium.disconnect()
 
301
        client_medium._accept_bytes('abc')
 
302
        client_medium.disconnect()
296
303
        self.assertTrue(input.closed)
297
304
        self.assertTrue(output.closed)
298
305
        self.assertTrue(input2.closed)
310
317
    def test_ssh_client_ignores_disconnect_when_not_connected(self):
311
318
        # Doing a disconnect on a new (and thus unconnected) SSH medium
312
319
        # does not fail.  It's ok to disconnect an unconnected medium.
313
 
        medium = smart.SmartSSHClientMedium(None)
314
 
        medium.disconnect()
 
320
        client_medium = medium.SmartSSHClientMedium(None)
 
321
        client_medium.disconnect()
315
322
 
316
323
    def test_ssh_client_raises_on_read_when_not_connected(self):
317
324
        # Doing a read on a new (and thus unconnected) SSH medium raises
318
325
        # MediumNotConnected.
319
 
        medium = smart.SmartSSHClientMedium(None)
320
 
        self.assertRaises(errors.MediumNotConnected, medium.read_bytes, 0)
321
 
        self.assertRaises(errors.MediumNotConnected, medium.read_bytes, 1)
 
326
        client_medium = medium.SmartSSHClientMedium(None)
 
327
        self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes, 0)
 
328
        self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes, 1)
322
329
 
323
330
    def test_ssh_client_supports__flush(self):
324
331
        # invoking _flush on a SSHClientMedium should flush the output 
331
338
        def logging_flush(): flush_calls.append('flush')
332
339
        output.flush = logging_flush
333
340
        vendor = StringIOSSHVendor(input, output)
334
 
        medium = smart.SmartSSHClientMedium('a hostname', vendor=vendor)
 
341
        client_medium = medium.SmartSSHClientMedium('a hostname', vendor=vendor)
335
342
        # this call is here to ensure we only flush once, not on every
336
343
        # _accept_bytes call.
337
 
        medium._accept_bytes('abc')
338
 
        medium._flush()
339
 
        medium.disconnect()
 
344
        client_medium._accept_bytes('abc')
 
345
        client_medium._flush()
 
346
        client_medium.disconnect()
340
347
        self.assertEqual(['flush'], flush_calls)
341
348
        
342
349
    def test_construct_smart_tcp_client_medium(self):
345
352
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
346
353
        sock.bind(('127.0.0.1', 0))
347
354
        unopened_port = sock.getsockname()[1]
348
 
        medium = smart.SmartTCPClientMedium('127.0.0.1', unopened_port)
 
355
        client_medium = medium.SmartTCPClientMedium('127.0.0.1', unopened_port)
349
356
        sock.close()
350
357
 
351
358
    def test_tcp_client_connects_on_first_use(self):
378
385
    def test_tcp_client_ignores_disconnect_when_not_connected(self):
379
386
        # Doing a disconnect on a new (and thus unconnected) TCP medium
380
387
        # does not fail.  It's ok to disconnect an unconnected medium.
381
 
        medium = smart.SmartTCPClientMedium(None, None)
382
 
        medium.disconnect()
 
388
        client_medium = medium.SmartTCPClientMedium(None, None)
 
389
        client_medium.disconnect()
383
390
 
384
391
    def test_tcp_client_raises_on_read_when_not_connected(self):
385
392
        # Doing a read on a new (and thus unconnected) TCP medium raises
386
393
        # MediumNotConnected.
387
 
        medium = smart.SmartTCPClientMedium(None, None)
388
 
        self.assertRaises(errors.MediumNotConnected, medium.read_bytes, 0)
389
 
        self.assertRaises(errors.MediumNotConnected, medium.read_bytes, 1)
 
394
        client_medium = medium.SmartTCPClientMedium(None, None)
 
395
        self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes, 0)
 
396
        self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes, 1)
390
397
 
391
398
    def test_tcp_client_supports__flush(self):
392
399
        # invoking _flush on a TCPClientMedium should do something useful.
421
428
        # WritingCompleted to prevent bad assumptions on stream environments
422
429
        # breaking the needs of message-based environments.
423
430
        output = StringIO()
424
 
        medium = smart.SmartSimplePipesClientMedium(None, output)
425
 
        request = smart.SmartClientStreamMediumRequest(medium)
 
431
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
432
        request = medium.SmartClientStreamMediumRequest(client_medium)
426
433
        request.finished_writing()
427
434
        self.assertRaises(errors.WritingCompleted, request.accept_bytes, None)
428
435
 
432
439
        # and checking that the pipes get the data.
433
440
        input = StringIO()
434
441
        output = StringIO()
435
 
        medium = smart.SmartSimplePipesClientMedium(input, output)
436
 
        request = smart.SmartClientStreamMediumRequest(medium)
 
442
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
443
        request = medium.SmartClientStreamMediumRequest(client_medium)
437
444
        request.accept_bytes('123')
438
445
        request.finished_writing()
439
446
        request.finished_reading()
444
451
        # constructing a SmartClientStreamMediumRequest on a StreamMedium sets
445
452
        # the current request to the new SmartClientStreamMediumRequest
446
453
        output = StringIO()
447
 
        medium = smart.SmartSimplePipesClientMedium(None, output)
448
 
        request = smart.SmartClientStreamMediumRequest(medium)
449
 
        self.assertIs(medium._current_request, request)
 
454
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
455
        request = medium.SmartClientStreamMediumRequest(client_medium)
 
456
        self.assertIs(client_medium._current_request, request)
450
457
 
451
458
    def test_construct_while_another_request_active_throws(self):
452
459
        # constructing a SmartClientStreamMediumRequest on a StreamMedium with
453
460
        # a non-None _current_request raises TooManyConcurrentRequests.
454
461
        output = StringIO()
455
 
        medium = smart.SmartSimplePipesClientMedium(None, output)
456
 
        medium._current_request = "a"
 
462
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
463
        client_medium._current_request = "a"
457
464
        self.assertRaises(errors.TooManyConcurrentRequests,
458
 
            smart.SmartClientStreamMediumRequest, medium)
 
465
            medium.SmartClientStreamMediumRequest, client_medium)
459
466
 
460
467
    def test_finished_read_clears_current_request(self):
461
468
        # calling finished_reading clears the current request from the requests
462
469
        # medium
463
470
        output = StringIO()
464
 
        medium = smart.SmartSimplePipesClientMedium(None, output)
465
 
        request = smart.SmartClientStreamMediumRequest(medium)
 
471
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
472
        request = medium.SmartClientStreamMediumRequest(client_medium)
466
473
        request.finished_writing()
467
474
        request.finished_reading()
468
 
        self.assertEqual(None, medium._current_request)
 
475
        self.assertEqual(None, client_medium._current_request)
469
476
 
470
477
    def test_finished_read_before_finished_write_errors(self):
471
478
        # calling finished_reading before calling finished_writing triggers a
472
479
        # WritingNotComplete error.
473
 
        medium = smart.SmartSimplePipesClientMedium(None, None)
474
 
        request = smart.SmartClientStreamMediumRequest(medium)
 
480
        client_medium = medium.SmartSimplePipesClientMedium(None, None)
 
481
        request = medium.SmartClientStreamMediumRequest(client_medium)
475
482
        self.assertRaises(errors.WritingNotComplete, request.finished_reading)
476
483
        
477
484
    def test_read_bytes(self):
483
490
        # smoke tests.
484
491
        input = StringIO('321')
485
492
        output = StringIO()
486
 
        medium = smart.SmartSimplePipesClientMedium(input, output)
487
 
        request = smart.SmartClientStreamMediumRequest(medium)
 
493
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
494
        request = medium.SmartClientStreamMediumRequest(client_medium)
488
495
        request.finished_writing()
489
496
        self.assertEqual('321', request.read_bytes(3))
490
497
        request.finished_reading()
496
503
        # WritingNotComplete error because the Smart protocol is designed to be
497
504
        # compatible with strict message based protocols like HTTP where the
498
505
        # request cannot be submitted until the writing has completed.
499
 
        medium = smart.SmartSimplePipesClientMedium(None, None)
500
 
        request = smart.SmartClientStreamMediumRequest(medium)
 
506
        client_medium = medium.SmartSimplePipesClientMedium(None, None)
 
507
        request = medium.SmartClientStreamMediumRequest(client_medium)
501
508
        self.assertRaises(errors.WritingNotComplete, request.read_bytes, None)
502
509
 
503
510
    def test_read_bytes_after_finished_reading_errors(self):
505
512
        # ReadingCompleted to prevent bad assumptions on stream environments
506
513
        # breaking the needs of message-based environments.
507
514
        output = StringIO()
508
 
        medium = smart.SmartSimplePipesClientMedium(None, output)
509
 
        request = smart.SmartClientStreamMediumRequest(medium)
 
515
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
516
        request = medium.SmartClientStreamMediumRequest(client_medium)
510
517
        request.finished_writing()
511
518
        request.finished_reading()
512
519
        self.assertRaises(errors.ReadingCompleted, request.read_bytes, None)
520
527
        # the default or a parameterized class, but rather use the
521
528
        # TestCaseWithTransport infrastructure to set up a smart server and
522
529
        # transport.
523
 
        self.transport_server = smart.SmartTCPServer_for_testing
 
530
        self.transport_server = server.SmartTCPServer_for_testing
524
531
 
525
532
    def test_plausible_url(self):
526
533
        self.assert_(self.get_url().startswith('bzr://'))
527
534
 
528
535
    def test_probe_transport(self):
529
536
        t = self.get_transport()
530
 
        self.assertIsInstance(t, smart.SmartTransport)
 
537
        self.assertIsInstance(t, remote.SmartTransport)
531
538
 
532
539
    def test_get_medium_from_transport(self):
533
540
        """Remote transport has a medium always, which it can return."""
534
541
        t = self.get_transport()
535
 
        medium = t.get_smart_medium()
536
 
        self.assertIsInstance(medium, smart.SmartClientMedium)
 
542
        smart_medium = t.get_smart_medium()
 
543
        self.assertIsInstance(smart_medium, medium.SmartClientMedium)
537
544
 
538
545
 
539
546
class ErrorRaisingProtocol(object):
568
575
 
569
576
class TestSmartServerStreamMedium(tests.TestCase):
570
577
 
 
578
    def setUp(self):
 
579
        super(TestSmartServerStreamMedium, self).setUp()
 
580
        self._captureVar('BZR_NO_SMART_VFS', None)
 
581
 
571
582
    def portable_socket_pair(self):
572
583
        """Return a pair of TCP sockets connected to each other.
573
584
        
588
599
        to_server = StringIO('hello\n')
589
600
        from_server = StringIO()
590
601
        transport = local.LocalTransport(urlutils.local_path_to_url('/'))
591
 
        server = smart.SmartServerPipeStreamMedium(
 
602
        server = medium.SmartServerPipeStreamMedium(
592
603
            to_server, from_server, transport)
593
 
        protocol = smart.SmartServerRequestProtocolOne(transport,
 
604
        smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
594
605
                from_server.write)
595
 
        server._serve_one_request(protocol)
 
606
        server._serve_one_request(smart_protocol)
596
607
        self.assertEqual('ok\0011\n',
597
608
                         from_server.getvalue())
598
609
 
601
612
        transport.put_bytes('testfile', 'contents\nof\nfile\n')
602
613
        to_server = StringIO('get\001./testfile\n')
603
614
        from_server = StringIO()
604
 
        server = smart.SmartServerPipeStreamMedium(
 
615
        server = medium.SmartServerPipeStreamMedium(
605
616
            to_server, from_server, transport)
606
 
        protocol = smart.SmartServerRequestProtocolOne(transport,
 
617
        smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
607
618
                from_server.write)
608
 
        server._serve_one_request(protocol)
 
619
        server._serve_one_request(smart_protocol)
609
620
        self.assertEqual('ok\n'
610
621
                         '17\n'
611
622
                         'contents\nof\nfile\n'
619
630
        transport.put_bytes(utf8_filename, 'contents\nof\nfile\n')
620
631
        to_server = StringIO('get\001' + utf8_filename + '\n')
621
632
        from_server = StringIO()
622
 
        server = smart.SmartServerPipeStreamMedium(
 
633
        server = medium.SmartServerPipeStreamMedium(
623
634
            to_server, from_server, transport)
624
 
        protocol = smart.SmartServerRequestProtocolOne(transport,
 
635
        smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
625
636
                from_server.write)
626
 
        server._serve_one_request(protocol)
 
637
        server._serve_one_request(smart_protocol)
627
638
        self.assertEqual('ok\n'
628
639
                         '17\n'
629
640
                         'contents\nof\nfile\n'
634
645
        sample_request_bytes = 'command\n9\nbulk datadone\n'
635
646
        to_server = StringIO(sample_request_bytes)
636
647
        from_server = StringIO()
637
 
        server = smart.SmartServerPipeStreamMedium(to_server, from_server, None)
 
648
        server = medium.SmartServerPipeStreamMedium(
 
649
            to_server, from_server, None)
638
650
        sample_protocol = SampleRequest(expected_bytes=sample_request_bytes)
639
651
        server._serve_one_request(sample_protocol)
640
652
        self.assertEqual('', from_server.getvalue())
644
656
    def test_socket_stream_with_bulk_data(self):
645
657
        sample_request_bytes = 'command\n9\nbulk datadone\n'
646
658
        server_sock, client_sock = self.portable_socket_pair()
647
 
        server = smart.SmartServerSocketStreamMedium(
 
659
        server = medium.SmartServerSocketStreamMedium(
648
660
            server_sock, None)
649
661
        sample_protocol = SampleRequest(expected_bytes=sample_request_bytes)
650
662
        client_sock.sendall(sample_request_bytes)
657
669
    def test_pipe_like_stream_shutdown_detection(self):
658
670
        to_server = StringIO('')
659
671
        from_server = StringIO()
660
 
        server = smart.SmartServerPipeStreamMedium(to_server, from_server, None)
 
672
        server = medium.SmartServerPipeStreamMedium(to_server, from_server, None)
661
673
        server._serve_one_request(SampleRequest('x'))
662
674
        self.assertTrue(server.finished)
663
675
        
664
676
    def test_socket_stream_shutdown_detection(self):
665
677
        server_sock, client_sock = self.portable_socket_pair()
666
678
        client_sock.close()
667
 
        server = smart.SmartServerSocketStreamMedium(
 
679
        server = medium.SmartServerSocketStreamMedium(
668
680
            server_sock, None)
669
681
        server._serve_one_request(SampleRequest('x'))
670
682
        self.assertTrue(server.finished)
676
688
        sample_request_bytes = 'command\n'
677
689
        to_server = StringIO(sample_request_bytes * 2)
678
690
        from_server = StringIO()
679
 
        server = smart.SmartServerPipeStreamMedium(to_server, from_server, None)
 
691
        server = medium.SmartServerPipeStreamMedium(
 
692
            to_server, from_server, None)
680
693
        first_protocol = SampleRequest(expected_bytes=sample_request_bytes)
681
694
        server._serve_one_request(first_protocol)
682
695
        self.assertEqual(0, first_protocol.next_read_size())
696
709
        # been received separately.
697
710
        sample_request_bytes = 'command\n'
698
711
        server_sock, client_sock = self.portable_socket_pair()
699
 
        server = smart.SmartServerSocketStreamMedium(
 
712
        server = medium.SmartServerSocketStreamMedium(
700
713
            server_sock, None)
701
714
        first_protocol = SampleRequest(expected_bytes=sample_request_bytes)
702
715
        # Put two whole requests on the wire.
723
736
        def close():
724
737
            self.closed = True
725
738
        from_server.close = close
726
 
        server = smart.SmartServerPipeStreamMedium(to_server, from_server, None)
 
739
        server = medium.SmartServerPipeStreamMedium(
 
740
            to_server, from_server, None)
727
741
        fake_protocol = ErrorRaisingProtocol(Exception('boom'))
728
742
        server._serve_one_request(fake_protocol)
729
743
        self.assertEqual('', from_server.getvalue())
735
749
        # not discard the contents.
736
750
        from StringIO import StringIO
737
751
        server_sock, client_sock = self.portable_socket_pair()
738
 
        server = smart.SmartServerSocketStreamMedium(
 
752
        server = medium.SmartServerSocketStreamMedium(
739
753
            server_sock, None)
740
754
        fake_protocol = ErrorRaisingProtocol(Exception('boom'))
741
755
        server._serve_one_request(fake_protocol)
749
763
        # not discard the contents.
750
764
        to_server = StringIO('')
751
765
        from_server = StringIO()
752
 
        server = smart.SmartServerPipeStreamMedium(to_server, from_server, None)
 
766
        server = medium.SmartServerPipeStreamMedium(
 
767
            to_server, from_server, None)
753
768
        fake_protocol = ErrorRaisingProtocol(KeyboardInterrupt('boom'))
754
769
        self.assertRaises(
755
770
            KeyboardInterrupt, server._serve_one_request, fake_protocol)
757
772
 
758
773
    def test_socket_stream_keyboard_interrupt_handling(self):
759
774
        server_sock, client_sock = self.portable_socket_pair()
760
 
        server = smart.SmartServerSocketStreamMedium(
 
775
        server = medium.SmartServerSocketStreamMedium(
761
776
            server_sock, None)
762
777
        fake_protocol = ErrorRaisingProtocol(KeyboardInterrupt('boom'))
763
778
        self.assertRaises(
770
785
 
771
786
    def test_get_error_unexpected(self):
772
787
        """Error reported by server with no specific representation"""
 
788
        self._captureVar('BZR_NO_SMART_VFS', None)
773
789
        class FlakyTransport(object):
 
790
            base = 'a_url'
774
791
            def get_bytes(self, path):
775
792
                raise Exception("some random exception from inside server")
776
 
        server = smart.SmartTCPServer(backing_transport=FlakyTransport())
777
 
        server.start_background_thread()
 
793
        smart_server = server.SmartTCPServer(backing_transport=FlakyTransport())
 
794
        smart_server.start_background_thread()
778
795
        try:
779
 
            transport = smart.SmartTCPTransport(server.get_url())
 
796
            transport = remote.SmartTCPTransport(smart_server.get_url())
780
797
            try:
781
798
                transport.get('something')
782
799
            except errors.TransportError, e:
784
801
            else:
785
802
                self.fail("get did not raise expected error")
786
803
        finally:
787
 
            server.stop_background_thread()
 
804
            smart_server.stop_background_thread()
788
805
 
789
806
 
790
807
class SmartTCPTests(tests.TestCase):
805
822
        if readonly:
806
823
            self.real_backing_transport = self.backing_transport
807
824
            self.backing_transport = get_transport("readonly+" + self.backing_transport.abspath('.'))
808
 
        self.server = smart.SmartTCPServer(self.backing_transport)
 
825
        self.server = server.SmartTCPServer(self.backing_transport)
809
826
        self.server.start_background_thread()
810
 
        self.transport = smart.SmartTCPTransport(self.server.get_url())
 
827
        self.transport = remote.SmartTCPTransport(self.server.get_url())
 
828
        self.addCleanup(self.tearDownServer)
811
829
 
812
 
    def tearDown(self):
 
830
    def tearDownServer(self):
813
831
        if getattr(self, 'transport', None):
814
832
            self.transport.disconnect()
 
833
            del self.transport
815
834
        if getattr(self, 'server', None):
816
835
            self.server.stop_background_thread()
817
 
        super(SmartTCPTests, self).tearDown()
818
 
        
 
836
            del self.server
 
837
 
 
838
 
 
839
class TestServerSocketUsage(SmartTCPTests):
 
840
 
 
841
    def test_server_setup_teardown(self):
 
842
        """It should be safe to teardown the server with no requests."""
 
843
        self.setUpServer()
 
844
        server = self.server
 
845
        transport = remote.SmartTCPTransport(self.server.get_url())
 
846
        self.tearDownServer()
 
847
        self.assertRaises(errors.ConnectionError, transport.has, '.')
 
848
 
 
849
    def test_server_closes_listening_sock_on_shutdown_after_request(self):
 
850
        """The server should close its listening socket when it's stopped."""
 
851
        self.setUpServer()
 
852
        server = self.server
 
853
        self.transport.has('.')
 
854
        self.tearDownServer()
 
855
        # if the listening socket has closed, we should get a BADFD error
 
856
        # when connecting, rather than a hang.
 
857
        transport = remote.SmartTCPTransport(server.get_url())
 
858
        self.assertRaises(errors.ConnectionError, transport.has, '.')
 
859
 
819
860
 
820
861
class WritableEndToEndTests(SmartTCPTests):
821
862
    """Client to server tests that require a writable transport."""
830
871
 
831
872
    def test_smart_transport_has(self):
832
873
        """Checking for file existence over smart."""
 
874
        self._captureVar('BZR_NO_SMART_VFS', None)
833
875
        self.backing_transport.put_bytes("foo", "contents of foo\n")
834
876
        self.assertTrue(self.transport.has("foo"))
835
877
        self.assertFalse(self.transport.has("non-foo"))
836
878
 
837
879
    def test_smart_transport_get(self):
838
880
        """Read back a file over smart."""
 
881
        self._captureVar('BZR_NO_SMART_VFS', None)
839
882
        self.backing_transport.put_bytes("foo", "contents\nof\nfoo\n")
840
883
        fp = self.transport.get("foo")
841
884
        self.assertEqual('contents\nof\nfoo\n', fp.read())
845
888
        # The path in a raised NoSuchFile exception should be the precise path
846
889
        # asked for by the client. This gives meaningful and unsurprising errors
847
890
        # for users.
 
891
        self._captureVar('BZR_NO_SMART_VFS', None)
848
892
        try:
849
893
            self.transport.get('not%20a%20file')
850
894
        except errors.NoSuchFile, e:
871
915
 
872
916
    def test_open_dir(self):
873
917
        """Test changing directory"""
 
918
        self._captureVar('BZR_NO_SMART_VFS', None)
874
919
        transport = self.transport
875
920
        self.backing_transport.mkdir('toffee')
876
921
        self.backing_transport.mkdir('toffee/apple')
898
943
 
899
944
    def test_mkdir_error_readonly(self):
900
945
        """TransportNotPossible should be preserved from the backing transport."""
 
946
        self._captureVar('BZR_NO_SMART_VFS', None)
901
947
        self.setUpServer(readonly=True)
902
948
        self.assertRaises(errors.TransportNotPossible, self.transport.mkdir,
903
949
            'foo')
 
950
 
 
951
 
 
952
class TestServerHooks(SmartTCPTests):
 
953
 
 
954
    def capture_server_call(self, backing_url, public_url):
 
955
        """Record a server_started|stopped hook firing."""
 
956
        self.hook_calls.append((backing_url, public_url))
 
957
 
 
958
    def test_server_started_hook(self):
 
959
        """The server_started hook fires when the server is started."""
 
960
        self.hook_calls = []
 
961
        server.SmartTCPServer.hooks.install_hook('server_started',
 
962
            self.capture_server_call)
 
963
        self.setUpServer()
 
964
        # at this point, the server will be starting a thread up.
 
965
        # there is no indicator at the moment, so bodge it by doing a request.
 
966
        self.transport.has('.')
 
967
        self.assertEqual([(self.backing_transport.base, self.transport.base)],
 
968
            self.hook_calls)
 
969
 
 
970
    def test_server_stopped_hook_simple(self):
 
971
        """The server_stopped hook fires when the server is stopped."""
 
972
        self.hook_calls = []
 
973
        server.SmartTCPServer.hooks.install_hook('server_stopped',
 
974
            self.capture_server_call)
 
975
        self.setUpServer()
 
976
        result = [(self.backing_transport.base, self.transport.base)]
 
977
        # check the stopping message isn't emitted up front.
 
978
        self.assertEqual([], self.hook_calls)
 
979
        # nor after a single message
 
980
        self.transport.has('.')
 
981
        self.assertEqual([], self.hook_calls)
 
982
        # clean up the server
 
983
        self.tearDownServer()
 
984
        # now it should have fired.
 
985
        self.assertEqual(result, self.hook_calls)
 
986
 
 
987
# TODO: test that when the server suffers an exception, it calls the
 
988
# server-stopped hook.
 
989
 
 
990
 
 
991
class SmartServerCommandTests(tests.TestCaseWithTransport):
 
992
    """Tests that call directly into the command objects, bypassing the network
 
993
    and the request dispatching.
 
994
    """
904
995
        
905
 
 
906
 
class SmartServerRequestHandlerTests(tests.TestCaseWithTransport):
907
 
    """Test that call directly into the handler logic, bypassing the network."""
908
 
 
909
 
    def test_construct_request_handler(self):
910
 
        """Constructing a request handler should be easy and set defaults."""
911
 
        handler = smart.SmartServerRequestHandler(None)
912
 
        self.assertFalse(handler.finished_reading)
913
 
 
914
996
    def test_hello(self):
915
 
        handler = smart.SmartServerRequestHandler(None)
916
 
        handler.dispatch_command('hello', ())
917
 
        self.assertEqual(('ok', '1'), handler.response.args)
918
 
        self.assertEqual(None, handler.response.body)
 
997
        cmd = request.HelloRequest(None)
 
998
        response = cmd.execute()
 
999
        self.assertEqual(('ok', '1'), response.args)
 
1000
        self.assertEqual(None, response.body)
919
1001
        
920
1002
    def test_get_bundle(self):
921
1003
        from bzrlib.bundle import serializer
924
1006
        wt.add('hello')
925
1007
        rev_id = wt.commit('add hello')
926
1008
        
927
 
        handler = smart.SmartServerRequestHandler(self.get_transport())
928
 
        handler.dispatch_command('get_bundle', ('.', rev_id))
929
 
        bundle = serializer.read_bundle(StringIO(handler.response.body))
930
 
        self.assertEqual((), handler.response.args)
 
1009
        cmd = request.GetBundleRequest(self.get_transport())
 
1010
        response = cmd.execute('.', rev_id)
 
1011
        bundle = serializer.read_bundle(StringIO(response.body))
 
1012
        self.assertEqual((), response.args)
 
1013
 
 
1014
 
 
1015
class SmartServerRequestHandlerTests(tests.TestCaseWithTransport):
 
1016
    """Test that call directly into the handler logic, bypassing the network."""
 
1017
 
 
1018
    def setUp(self):
 
1019
        super(SmartServerRequestHandlerTests, self).setUp()
 
1020
        self._captureVar('BZR_NO_SMART_VFS', None)
 
1021
 
 
1022
    def build_handler(self, transport):
 
1023
        """Returns a handler for the commands in protocol version one."""
 
1024
        return request.SmartServerRequestHandler(transport, request.request_handlers)
 
1025
 
 
1026
    def test_construct_request_handler(self):
 
1027
        """Constructing a request handler should be easy and set defaults."""
 
1028
        handler = request.SmartServerRequestHandler(None, None)
 
1029
        self.assertFalse(handler.finished_reading)
 
1030
 
 
1031
    def test_hello(self):
 
1032
        handler = self.build_handler(None)
 
1033
        handler.dispatch_command('hello', ())
 
1034
        self.assertEqual(('ok', '1'), handler.response.args)
 
1035
        self.assertEqual(None, handler.response.body)
 
1036
        
 
1037
    def test_disable_vfs_handler_classes_via_environment(self):
 
1038
        # VFS handler classes will raise an error from "execute" if BZR_NO_SMART_VFS
 
1039
        # is set.
 
1040
        handler = vfs.HasRequest(None)
 
1041
        # set environment variable after construction to make sure it's
 
1042
        # examined.
 
1043
        # Note that we can safely clobber BZR_NO_SMART_VFS here, because setUp has
 
1044
        # called _captureVar, so it will be restored to the right state
 
1045
        # afterwards.
 
1046
        os.environ['BZR_NO_SMART_VFS'] = ''
 
1047
        self.assertRaises(errors.DisabledMethod, handler.execute)
931
1048
 
932
1049
    def test_readonly_exception_becomes_transport_not_possible(self):
933
1050
        """The response for a read-only error is ('ReadOnlyError')."""
934
 
        handler = smart.SmartServerRequestHandler(self.get_readonly_transport())
 
1051
        handler = self.build_handler(self.get_readonly_transport())
935
1052
        # send a mkdir for foo, with no explicit mode - should fail.
936
1053
        handler.dispatch_command('mkdir', ('foo', ''))
937
1054
        # and the failure should be an explicit ReadOnlyError
943
1060
 
944
1061
    def test_hello_has_finished_body_on_dispatch(self):
945
1062
        """The 'hello' command should set finished_reading."""
946
 
        handler = smart.SmartServerRequestHandler(None)
 
1063
        handler = self.build_handler(None)
947
1064
        handler.dispatch_command('hello', ())
948
1065
        self.assertTrue(handler.finished_reading)
949
1066
        self.assertNotEqual(None, handler.response)
950
1067
 
951
1068
    def test_put_bytes_non_atomic(self):
952
1069
        """'put_...' should set finished_reading after reading the bytes."""
953
 
        handler = smart.SmartServerRequestHandler(self.get_transport())
 
1070
        handler = self.build_handler(self.get_transport())
954
1071
        handler.dispatch_command('put_non_atomic', ('a-file', '', 'F', ''))
955
1072
        self.assertFalse(handler.finished_reading)
956
1073
        handler.accept_body('1234')
964
1081
    def test_readv_accept_body(self):
965
1082
        """'readv' should set finished_reading after reading offsets."""
966
1083
        self.build_tree(['a-file'])
967
 
        handler = smart.SmartServerRequestHandler(self.get_readonly_transport())
 
1084
        handler = self.build_handler(self.get_readonly_transport())
968
1085
        handler.dispatch_command('readv', ('a-file', ))
969
1086
        self.assertFalse(handler.finished_reading)
970
1087
        handler.accept_body('2,')
979
1096
    def test_readv_short_read_response_contents(self):
980
1097
        """'readv' when a short read occurs sets the response appropriately."""
981
1098
        self.build_tree(['a-file'])
982
 
        handler = smart.SmartServerRequestHandler(self.get_readonly_transport())
 
1099
        handler = self.build_handler(self.get_readonly_transport())
983
1100
        handler.dispatch_command('readv', ('a-file', ))
984
1101
        # read beyond the end of the file.
985
1102
        handler.accept_body('100,1')
990
1107
        self.assertEqual(None, handler.response.body)
991
1108
 
992
1109
 
993
 
class SmartTransportRegistration(tests.TestCase):
 
1110
class RemoteTransportRegistration(tests.TestCase):
994
1111
 
995
1112
    def test_registration(self):
996
1113
        t = get_transport('bzr+ssh://example.com/path')
997
 
        self.assertIsInstance(t, smart.SmartSSHTransport)
 
1114
        self.assertIsInstance(t, remote.SmartSSHTransport)
998
1115
        self.assertEqual('example.com', t._host)
999
1116
 
1000
1117
 
1001
 
class TestSmartTransport(tests.TestCase):
 
1118
class TestRemoteTransport(tests.TestCase):
1002
1119
        
1003
1120
    def test_use_connection_factory(self):
1004
 
        # We want to be able to pass a client as a parameter to SmartTransport.
 
1121
        # We want to be able to pass a client as a parameter to RemoteTransport.
1005
1122
        input = StringIO("ok\n3\nbardone\n")
1006
1123
        output = StringIO()
1007
 
        medium = smart.SmartSimplePipesClientMedium(input, output)
1008
 
        transport = smart.SmartTransport('bzr://localhost/', medium=medium)
 
1124
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1125
        transport = remote.SmartTransport(
 
1126
            'bzr://localhost/', medium=client_medium)
1009
1127
 
1010
1128
        # We want to make sure the client is used when the first remote
1011
1129
        # method is called.  No data should have been sent, or read.
1023
1141
 
1024
1142
    def test__translate_error_readonly(self):
1025
1143
        """Sending a ReadOnlyError to _translate_error raises TransportNotPossible."""
1026
 
        medium = smart.SmartClientMedium()
1027
 
        transport = smart.SmartTransport('bzr://localhost/', medium=medium)
 
1144
        client_medium = medium.SmartClientMedium()
 
1145
        transport = remote.SmartTransport(
 
1146
            'bzr://localhost/', medium=client_medium)
1028
1147
        self.assertRaises(errors.TransportNotPossible,
1029
1148
            transport._translate_error, ("ReadOnlyError", ))
1030
1149
 
1031
1150
 
1032
 
class InstrumentedServerProtocol(smart.SmartServerStreamMedium):
 
1151
class InstrumentedServerProtocol(medium.SmartServerStreamMedium):
1033
1152
    """A smart server which is backed by memory and saves its write requests."""
1034
1153
 
1035
1154
    def __init__(self, write_output_list):
1036
 
        smart.SmartServerStreamMedium.__init__(self, memory.MemoryTransport())
 
1155
        medium.SmartServerStreamMedium.__init__(self, memory.MemoryTransport())
1037
1156
        self._write_output_list = write_output_list
1038
1157
 
1039
1158
 
1052
1171
 
1053
1172
    def setUp(self):
1054
1173
        super(TestSmartProtocol, self).setUp()
 
1174
        # XXX: self.server_to_client doesn't seem to be used.  If so,
 
1175
        # InstrumentedServerProtocol is redundant too.
1055
1176
        self.server_to_client = []
1056
1177
        self.to_server = StringIO()
1057
1178
        self.to_client = StringIO()
1058
 
        self.client_medium = smart.SmartSimplePipesClientMedium(self.to_client,
 
1179
        self.client_medium = medium.SmartSimplePipesClientMedium(self.to_client,
1059
1180
            self.to_server)
1060
 
        self.client_protocol = smart.SmartClientRequestProtocolOne(
 
1181
        self.client_protocol = protocol.SmartClientRequestProtocolOne(
1061
1182
            self.client_medium)
1062
1183
        self.smart_server = InstrumentedServerProtocol(self.server_to_client)
1063
 
        self.smart_server_request = smart.SmartServerRequestHandler(None)
 
1184
        self.smart_server_request = request.SmartServerRequestHandler(
 
1185
            None, request.request_handlers)
1064
1186
 
1065
1187
    def assertOffsetSerialisation(self, expected_offsets, expected_serialised,
1066
 
        client, smart_server_request):
 
1188
        client):
1067
1189
        """Check that smart (de)serialises offsets as expected.
1068
1190
        
1069
1191
        We check both serialisation and deserialisation at the same time
1072
1194
        
1073
1195
        :param expected_offsets: a readv offset list.
1074
1196
        :param expected_serialised: an expected serial form of the offsets.
1075
 
        :param smart_server_request: a SmartServerRequestHandler instance.
1076
1197
        """
1077
 
        # XXX: 'smart_server_request' should be a SmartServerRequestProtocol in
1078
 
        # future.
1079
 
        offsets = smart_server_request._deserialise_offsets(expected_serialised)
 
1198
        # XXX: '_deserialise_offsets' should be a method of the
 
1199
        # SmartServerRequestProtocol in future.
 
1200
        readv_cmd = vfs.ReadvRequest(None)
 
1201
        offsets = readv_cmd._deserialise_offsets(expected_serialised)
1080
1202
        self.assertEqual(expected_offsets, offsets)
1081
1203
        serialised = client._serialise_offsets(offsets)
1082
1204
        self.assertEqual(expected_serialised, serialised)
1083
1205
 
1084
1206
    def build_protocol_waiting_for_body(self):
1085
1207
        out_stream = StringIO()
1086
 
        protocol = smart.SmartServerRequestProtocolOne(None, out_stream.write)
1087
 
        protocol.has_dispatched = True
1088
 
        protocol.request = smart.SmartServerRequestHandler(None)
1089
 
        def handle_end_of_bytes():
1090
 
            self.end_received = True
1091
 
            self.assertEqual('abcdefg', protocol.request._body_bytes)
1092
 
            protocol.request.response = smart.SmartServerResponse(('ok', ))
1093
 
        protocol.request._end_of_body_handler = handle_end_of_bytes
 
1208
        smart_protocol = protocol.SmartServerRequestProtocolOne(None,
 
1209
                out_stream.write)
 
1210
        smart_protocol.has_dispatched = True
 
1211
        smart_protocol.request = self.smart_server_request
 
1212
        class FakeCommand(object):
 
1213
            def do_body(cmd, body_bytes):
 
1214
                self.end_received = True
 
1215
                self.assertEqual('abcdefg', body_bytes)
 
1216
                return request.SmartServerResponse(('ok', ))
 
1217
        smart_protocol.request._command = FakeCommand()
1094
1218
        # Call accept_bytes to make sure that internal state like _body_decoder
1095
1219
        # is initialised.  This test should probably be given a clearer
1096
1220
        # interface to work with that will not cause this inconsistency.
1097
1221
        #   -- Andrew Bennetts, 2006-09-28
1098
 
        protocol.accept_bytes('')
1099
 
        return protocol
 
1222
        smart_protocol.accept_bytes('')
 
1223
        return smart_protocol
1100
1224
 
1101
1225
    def test_construct_version_one_server_protocol(self):
1102
 
        protocol = smart.SmartServerRequestProtocolOne(None, None)
1103
 
        self.assertEqual('', protocol.excess_buffer)
1104
 
        self.assertEqual('', protocol.in_buffer)
1105
 
        self.assertFalse(protocol.has_dispatched)
1106
 
        self.assertEqual(1, protocol.next_read_size())
 
1226
        smart_protocol = protocol.SmartServerRequestProtocolOne(None, None)
 
1227
        self.assertEqual('', smart_protocol.excess_buffer)
 
1228
        self.assertEqual('', smart_protocol.in_buffer)
 
1229
        self.assertFalse(smart_protocol.has_dispatched)
 
1230
        self.assertEqual(1, smart_protocol.next_read_size())
1107
1231
 
1108
1232
    def test_construct_version_one_client_protocol(self):
1109
1233
        # we can construct a client protocol from a client medium request
1110
1234
        output = StringIO()
1111
 
        medium = smart.SmartSimplePipesClientMedium(None, output)
1112
 
        request = medium.get_request()
1113
 
        client_protocol = smart.SmartClientRequestProtocolOne(request)
 
1235
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
1236
        request = client_medium.get_request()
 
1237
        client_protocol = protocol.SmartClientRequestProtocolOne(request)
1114
1238
 
1115
1239
    def test_server_offset_serialisation(self):
1116
1240
        """The Smart protocol serialises offsets as a comma and \n string.
1119
1243
        one with the order of reads not increasing (an out of order read), and
1120
1244
        one that should coalesce.
1121
1245
        """
1122
 
        self.assertOffsetSerialisation([], '',
1123
 
            self.client_protocol, self.smart_server_request)
1124
 
        self.assertOffsetSerialisation([(1,2)], '1,2',
1125
 
            self.client_protocol, self.smart_server_request)
 
1246
        self.assertOffsetSerialisation([], '', self.client_protocol)
 
1247
        self.assertOffsetSerialisation([(1,2)], '1,2', self.client_protocol)
1126
1248
        self.assertOffsetSerialisation([(10,40), (0,5)], '10,40\n0,5',
1127
 
            self.client_protocol, self.smart_server_request)
 
1249
            self.client_protocol)
1128
1250
        self.assertOffsetSerialisation([(1,2), (3,4), (100, 200)],
1129
 
            '1,2\n3,4\n100,200', self.client_protocol, self.smart_server_request)
 
1251
            '1,2\n3,4\n100,200', self.client_protocol)
1130
1252
 
1131
1253
    def test_accept_bytes_of_bad_request_to_protocol(self):
1132
1254
        out_stream = StringIO()
1133
 
        protocol = smart.SmartServerRequestProtocolOne(None, out_stream.write)
1134
 
        protocol.accept_bytes('abc')
1135
 
        self.assertEqual('abc', protocol.in_buffer)
1136
 
        protocol.accept_bytes('\n')
1137
 
        self.assertEqual("error\x01Generic bzr smart protocol error: bad request"
1138
 
            " 'abc'\n", out_stream.getvalue())
1139
 
        self.assertTrue(protocol.has_dispatched)
1140
 
        self.assertEqual(0, protocol.next_read_size())
 
1255
        smart_protocol = protocol.SmartServerRequestProtocolOne(
 
1256
            None, out_stream.write)
 
1257
        smart_protocol.accept_bytes('abc')
 
1258
        self.assertEqual('abc', smart_protocol.in_buffer)
 
1259
        smart_protocol.accept_bytes('\n')
 
1260
        self.assertEqual(
 
1261
            "error\x01Generic bzr smart protocol error: bad request 'abc'\n",
 
1262
            out_stream.getvalue())
 
1263
        self.assertTrue(smart_protocol.has_dispatched)
 
1264
        self.assertEqual(0, smart_protocol.next_read_size())
1141
1265
 
1142
1266
    def test_accept_body_bytes_to_protocol(self):
1143
1267
        protocol = self.build_protocol_waiting_for_body()
1150
1274
        self.assertTrue(self.end_received)
1151
1275
 
1152
1276
    def test_accept_request_and_body_all_at_once(self):
 
1277
        self._captureVar('BZR_NO_SMART_VFS', None)
1153
1278
        mem_transport = memory.MemoryTransport()
1154
1279
        mem_transport.put_bytes('foo', 'abcdefghij')
1155
1280
        out_stream = StringIO()
1156
 
        protocol = smart.SmartServerRequestProtocolOne(mem_transport,
 
1281
        smart_protocol = protocol.SmartServerRequestProtocolOne(mem_transport,
1157
1282
                out_stream.write)
1158
 
        protocol.accept_bytes('readv\x01foo\n3\n3,3done\n')
1159
 
        self.assertEqual(0, protocol.next_read_size())
 
1283
        smart_protocol.accept_bytes('readv\x01foo\n3\n3,3done\n')
 
1284
        self.assertEqual(0, smart_protocol.next_read_size())
1160
1285
        self.assertEqual('readv\n3\ndefdone\n', out_stream.getvalue())
1161
 
        self.assertEqual('', protocol.excess_buffer)
1162
 
        self.assertEqual('', protocol.in_buffer)
 
1286
        self.assertEqual('', smart_protocol.excess_buffer)
 
1287
        self.assertEqual('', smart_protocol.in_buffer)
1163
1288
 
1164
1289
    def test_accept_excess_bytes_are_preserved(self):
1165
1290
        out_stream = StringIO()
1166
 
        protocol = smart.SmartServerRequestProtocolOne(None, out_stream.write)
1167
 
        protocol.accept_bytes('hello\nhello\n')
 
1291
        smart_protocol = protocol.SmartServerRequestProtocolOne(
 
1292
            None, out_stream.write)
 
1293
        smart_protocol.accept_bytes('hello\nhello\n')
1168
1294
        self.assertEqual("ok\x011\n", out_stream.getvalue())
1169
 
        self.assertEqual("hello\n", protocol.excess_buffer)
1170
 
        self.assertEqual("", protocol.in_buffer)
 
1295
        self.assertEqual("hello\n", smart_protocol.excess_buffer)
 
1296
        self.assertEqual("", smart_protocol.in_buffer)
1171
1297
 
1172
1298
    def test_accept_excess_bytes_after_body(self):
1173
1299
        protocol = self.build_protocol_waiting_for_body()
1181
1307
 
1182
1308
    def test_accept_excess_bytes_after_dispatch(self):
1183
1309
        out_stream = StringIO()
1184
 
        protocol = smart.SmartServerRequestProtocolOne(None, out_stream.write)
1185
 
        protocol.accept_bytes('hello\n')
 
1310
        smart_protocol = protocol.SmartServerRequestProtocolOne(
 
1311
            None, out_stream.write)
 
1312
        smart_protocol.accept_bytes('hello\n')
1186
1313
        self.assertEqual("ok\x011\n", out_stream.getvalue())
1187
 
        protocol.accept_bytes('hel')
1188
 
        self.assertEqual("hel", protocol.excess_buffer)
1189
 
        protocol.accept_bytes('lo\n')
1190
 
        self.assertEqual("hello\n", protocol.excess_buffer)
1191
 
        self.assertEqual("", protocol.in_buffer)
 
1314
        smart_protocol.accept_bytes('hel')
 
1315
        self.assertEqual("hel", smart_protocol.excess_buffer)
 
1316
        smart_protocol.accept_bytes('lo\n')
 
1317
        self.assertEqual("hello\n", smart_protocol.excess_buffer)
 
1318
        self.assertEqual("", smart_protocol.in_buffer)
1192
1319
 
1193
1320
    def test__send_response_sets_finished_reading(self):
1194
 
        protocol = smart.SmartServerRequestProtocolOne(None, lambda x: None)
1195
 
        self.assertEqual(1, protocol.next_read_size())
1196
 
        protocol._send_response(('x',))
1197
 
        self.assertEqual(0, protocol.next_read_size())
 
1321
        smart_protocol = protocol.SmartServerRequestProtocolOne(
 
1322
            None, lambda x: None)
 
1323
        self.assertEqual(1, smart_protocol.next_read_size())
 
1324
        smart_protocol._send_response(('x',))
 
1325
        self.assertEqual(0, smart_protocol.next_read_size())
1198
1326
 
1199
1327
    def test_query_version(self):
1200
1328
        """query_version on a SmartClientProtocolOne should return a number.
1208
1336
        # the error if the response is a non-understood version.
1209
1337
        input = StringIO('ok\x011\n')
1210
1338
        output = StringIO()
1211
 
        medium = smart.SmartSimplePipesClientMedium(input, output)
1212
 
        protocol = smart.SmartClientRequestProtocolOne(medium.get_request())
1213
 
        self.assertEqual(1, protocol.query_version())
 
1339
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1340
        request = client_medium.get_request()
 
1341
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
 
1342
        self.assertEqual(1, smart_protocol.query_version())
1214
1343
 
1215
1344
    def assertServerToClientEncoding(self, expected_bytes, expected_tuple,
1216
1345
            input_tuples):
1221
1350
        # expected bytes
1222
1351
        for input_tuple in input_tuples:
1223
1352
            server_output = StringIO()
1224
 
            server_protocol = smart.SmartServerRequestProtocolOne(
 
1353
            server_protocol = protocol.SmartServerRequestProtocolOne(
1225
1354
                None, server_output.write)
1226
1355
            server_protocol._send_response(input_tuple)
1227
1356
            self.assertEqual(expected_bytes, server_output.getvalue())
1228
 
        # check the decoding of the client protocol from expected_bytes:
 
1357
        # check the decoding of the client smart_protocol from expected_bytes:
1229
1358
        input = StringIO(expected_bytes)
1230
1359
        output = StringIO()
1231
 
        medium = smart.SmartSimplePipesClientMedium(input, output)
1232
 
        protocol = smart.SmartClientRequestProtocolOne(medium.get_request())
1233
 
        protocol.call('foo')
1234
 
        self.assertEqual(expected_tuple, protocol.read_response_tuple())
 
1360
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1361
        request = client_medium.get_request()
 
1362
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
 
1363
        smart_protocol.call('foo')
 
1364
        self.assertEqual(expected_tuple, smart_protocol.read_response_tuple())
1235
1365
 
1236
1366
    def test_client_call_empty_response(self):
1237
1367
        # protocol.call() can get back an empty tuple as a response. This occurs
1246
1376
            [('a', 'b', '34')])
1247
1377
 
1248
1378
    def test_client_call_with_body_bytes_uploads(self):
1249
 
        # protocol.call_with_upload should length-prefix the bytes onto the 
 
1379
        # protocol.call_with_body_bytes should length-prefix the bytes onto the
1250
1380
        # wire.
1251
1381
        expected_bytes = "foo\n7\nabcdefgdone\n"
1252
1382
        input = StringIO("\n")
1253
1383
        output = StringIO()
1254
 
        medium = smart.SmartSimplePipesClientMedium(input, output)
1255
 
        protocol = smart.SmartClientRequestProtocolOne(medium.get_request())
1256
 
        protocol.call_with_body_bytes(('foo', ), "abcdefg")
 
1384
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1385
        request = client_medium.get_request()
 
1386
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
 
1387
        smart_protocol.call_with_body_bytes(('foo', ), "abcdefg")
1257
1388
        self.assertEqual(expected_bytes, output.getvalue())
1258
1389
 
1259
1390
    def test_client_call_with_body_readv_array(self):
1262
1393
        expected_bytes = "foo\n7\n1,2\n5,6done\n"
1263
1394
        input = StringIO("\n")
1264
1395
        output = StringIO()
1265
 
        medium = smart.SmartSimplePipesClientMedium(input, output)
1266
 
        protocol = smart.SmartClientRequestProtocolOne(medium.get_request())
1267
 
        protocol.call_with_body_readv_array(('foo', ), [(1,2),(5,6)])
 
1396
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1397
        request = client_medium.get_request()
 
1398
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
 
1399
        smart_protocol.call_with_body_readv_array(('foo', ), [(1,2),(5,6)])
1268
1400
        self.assertEqual(expected_bytes, output.getvalue())
1269
1401
 
1270
1402
    def test_client_read_body_bytes_all(self):
1274
1406
        server_bytes = "ok\n7\n1234567done\n"
1275
1407
        input = StringIO(server_bytes)
1276
1408
        output = StringIO()
1277
 
        medium = smart.SmartSimplePipesClientMedium(input, output)
1278
 
        protocol = smart.SmartClientRequestProtocolOne(medium.get_request())
1279
 
        protocol.call('foo')
1280
 
        protocol.read_response_tuple(True)
1281
 
        self.assertEqual(expected_bytes, protocol.read_body_bytes())
 
1409
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1410
        request = client_medium.get_request()
 
1411
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
 
1412
        smart_protocol.call('foo')
 
1413
        smart_protocol.read_response_tuple(True)
 
1414
        self.assertEqual(expected_bytes, smart_protocol.read_body_bytes())
1282
1415
 
1283
1416
    def test_client_read_body_bytes_incremental(self):
1284
1417
        # test reading a few bytes at a time from the body
1290
1423
        server_bytes = "ok\n7\n1234567done\n"
1291
1424
        input = StringIO(server_bytes)
1292
1425
        output = StringIO()
1293
 
        medium = smart.SmartSimplePipesClientMedium(input, output)
1294
 
        protocol = smart.SmartClientRequestProtocolOne(medium.get_request())
1295
 
        protocol.call('foo')
1296
 
        protocol.read_response_tuple(True)
1297
 
        self.assertEqual(expected_bytes[0:2], protocol.read_body_bytes(2))
1298
 
        self.assertEqual(expected_bytes[2:4], protocol.read_body_bytes(2))
1299
 
        self.assertEqual(expected_bytes[4:6], protocol.read_body_bytes(2))
1300
 
        self.assertEqual(expected_bytes[6], protocol.read_body_bytes())
 
1426
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1427
        request = client_medium.get_request()
 
1428
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
 
1429
        smart_protocol.call('foo')
 
1430
        smart_protocol.read_response_tuple(True)
 
1431
        self.assertEqual(expected_bytes[0:2], smart_protocol.read_body_bytes(2))
 
1432
        self.assertEqual(expected_bytes[2:4], smart_protocol.read_body_bytes(2))
 
1433
        self.assertEqual(expected_bytes[4:6], smart_protocol.read_body_bytes(2))
 
1434
        self.assertEqual(expected_bytes[6], smart_protocol.read_body_bytes())
1301
1435
 
1302
1436
    def test_client_cancel_read_body_does_not_eat_body_bytes(self):
1303
1437
        # cancelling the expected body needs to finish the request, but not
1306
1440
        server_bytes = "ok\n7\n1234567done\n"
1307
1441
        input = StringIO(server_bytes)
1308
1442
        output = StringIO()
1309
 
        medium = smart.SmartSimplePipesClientMedium(input, output)
1310
 
        protocol = smart.SmartClientRequestProtocolOne(medium.get_request())
1311
 
        protocol.call('foo')
1312
 
        protocol.read_response_tuple(True)
1313
 
        protocol.cancel_read_body()
 
1443
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1444
        request = client_medium.get_request()
 
1445
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
 
1446
        smart_protocol.call('foo')
 
1447
        smart_protocol.read_response_tuple(True)
 
1448
        smart_protocol.cancel_read_body()
1314
1449
        self.assertEqual(3, input.tell())
1315
 
        self.assertRaises(errors.ReadingCompleted, protocol.read_body_bytes)
 
1450
        self.assertRaises(
 
1451
            errors.ReadingCompleted, smart_protocol.read_body_bytes)
1316
1452
 
1317
1453
 
1318
1454
class LengthPrefixedBodyDecoder(tests.TestCase):
1321
1457
    # something similar to the ProtocolBase method.
1322
1458
 
1323
1459
    def test_construct(self):
1324
 
        decoder = smart.LengthPrefixedBodyDecoder()
 
1460
        decoder = protocol.LengthPrefixedBodyDecoder()
1325
1461
        self.assertFalse(decoder.finished_reading)
1326
1462
        self.assertEqual(6, decoder.next_read_size())
1327
1463
        self.assertEqual('', decoder.read_pending_data())
1328
1464
        self.assertEqual('', decoder.unused_data)
1329
1465
 
1330
1466
    def test_accept_bytes(self):
1331
 
        decoder = smart.LengthPrefixedBodyDecoder()
 
1467
        decoder = protocol.LengthPrefixedBodyDecoder()
1332
1468
        decoder.accept_bytes('')
1333
1469
        self.assertFalse(decoder.finished_reading)
1334
1470
        self.assertEqual(6, decoder.next_read_size())
1361
1497
        self.assertEqual('blarg', decoder.unused_data)
1362
1498
        
1363
1499
    def test_accept_bytes_all_at_once_with_excess(self):
1364
 
        decoder = smart.LengthPrefixedBodyDecoder()
 
1500
        decoder = protocol.LengthPrefixedBodyDecoder()
1365
1501
        decoder.accept_bytes('1\nadone\nunused')
1366
1502
        self.assertTrue(decoder.finished_reading)
1367
1503
        self.assertEqual(1, decoder.next_read_size())
1369
1505
        self.assertEqual('unused', decoder.unused_data)
1370
1506
 
1371
1507
    def test_accept_bytes_exact_end_of_body(self):
1372
 
        decoder = smart.LengthPrefixedBodyDecoder()
 
1508
        decoder = protocol.LengthPrefixedBodyDecoder()
1373
1509
        decoder.accept_bytes('1\na')
1374
1510
        self.assertFalse(decoder.finished_reading)
1375
1511
        self.assertEqual(5, decoder.next_read_size())
1393
1529
 
1394
1530
class HTTPTunnellingSmokeTest(tests.TestCaseWithTransport):
1395
1531
    
 
1532
    def setUp(self):
 
1533
        super(HTTPTunnellingSmokeTest, self).setUp()
 
1534
        # We use the VFS layer as part of HTTP tunnelling tests.
 
1535
        self._captureVar('BZR_NO_SMART_VFS', None)
 
1536
 
1396
1537
    def _test_bulk_data(self, url_protocol):
1397
1538
        # We should be able to send and receive bulk data in a single message.
1398
1539
        # The 'readv' command in the smart protocol both sends and receives bulk
1399
1540
        # data, so we use that.
1400
1541
        self.build_tree(['data-file'])
1401
 
        http_server = HTTPServerWithSmarts()
1402
 
        http_server._url_protocol = url_protocol
1403
 
        http_server.setUp()
1404
 
        self.addCleanup(http_server.tearDown)
1405
 
 
1406
 
        http_transport = get_transport(http_server.get_url())
1407
 
 
 
1542
        self.transport_readonly_server = HTTPServerWithSmarts
 
1543
 
 
1544
        http_transport = self.get_readonly_transport()
1408
1545
        medium = http_transport.get_smart_medium()
1409
1546
        #remote_transport = RemoteTransport('fake_url', medium)
1410
 
        remote_transport = smart.SmartTransport('/', medium=medium)
 
1547
        remote_transport = remote.SmartTransport('/', medium=medium)
1411
1548
        self.assertEqual(
1412
1549
            [(0, "c")], list(remote_transport.readv("data-file", [(0,1)])))
1413
1550
 
1432
1569
    def _test_http_send_smart_request(self, url_protocol):
1433
1570
        http_server = HTTPServerWithSmarts()
1434
1571
        http_server._url_protocol = url_protocol
1435
 
        http_server.setUp()
 
1572
        http_server.setUp(self.get_vfs_only_server())
1436
1573
        self.addCleanup(http_server.tearDown)
1437
1574
 
1438
1575
        post_body = 'hello\n'
1454
1591
        self._test_http_send_smart_request('http+urllib')
1455
1592
 
1456
1593
    def test_http_server_with_smarts(self):
1457
 
        http_server = HTTPServerWithSmarts()
1458
 
        http_server.setUp()
1459
 
        self.addCleanup(http_server.tearDown)
 
1594
        self.transport_readonly_server = HTTPServerWithSmarts
1460
1595
 
1461
1596
        post_body = 'hello\n'
1462
1597
        expected_reply_body = 'ok\x011\n'
1463
1598
 
1464
 
        smart_server_url = http_server.get_url() + '.bzr/smart'
 
1599
        smart_server_url = self.get_readonly_url('.bzr/smart')
1465
1600
        reply = urllib2.urlopen(smart_server_url, post_body).read()
1466
1601
 
1467
1602
        self.assertEqual(expected_reply_body, reply)
1468
1603
 
1469
1604
    def test_smart_http_server_post_request_handler(self):
1470
 
        http_server = HTTPServerWithSmarts()
1471
 
        http_server.setUp()
1472
 
        self.addCleanup(http_server.tearDown)
1473
 
        httpd = http_server._get_httpd()
 
1605
        self.transport_readonly_server = HTTPServerWithSmarts
 
1606
        httpd = self.get_readonly_server()._get_httpd()
1474
1607
 
1475
1608
        socket = SampleSocket(
1476
1609
            'POST /.bzr/smart HTTP/1.0\r\n'
1512
1645
        else:
1513
1646
            return self.writefile
1514
1647
 
1515
 
        
 
1648
 
1516
1649
# TODO: Client feature that does get_bundle and then installs that into a
1517
1650
# branch; this can be used in place of the regular pull/fetch operation when
1518
1651
# coming from a smart server.