/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_http.py

 * ``Repository.get_data_stream`` is now deprecated in favour of
   ``Repository.get_data_stream_for_search`` which allows less network
   traffic when requesting data streams over a smart server. (Robert Collins)

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2005, 2006 Canonical Ltd
 
2
#
 
3
# This program is free software; you can redistribute it and/or modify
 
4
# it under the terms of the GNU General Public License as published by
 
5
# the Free Software Foundation; either version 2 of the License, or
 
6
# (at your option) any later version.
 
7
#
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU General Public License for more details.
 
12
#
 
13
# You should have received a copy of the GNU General Public License
 
14
# along with this program; if not, write to the Free Software
 
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
16
 
 
17
"""Tests for HTTP implementations.
 
18
 
 
19
This module defines a load_tests() method that parametrize tests classes for
 
20
transport implementation, http protocol versions and authentication schemes.
 
21
"""
 
22
 
 
23
# TODO: Should be renamed to bzrlib.transport.http.tests?
 
24
# TODO: What about renaming to bzrlib.tests.transport.http ?
 
25
 
 
26
from cStringIO import StringIO
 
27
import httplib
 
28
import os
 
29
import select
 
30
import SimpleHTTPServer
 
31
import socket
 
32
import sys
 
33
import threading
 
34
 
 
35
import bzrlib
 
36
from bzrlib import (
 
37
    config,
 
38
    errors,
 
39
    osutils,
 
40
    tests,
 
41
    transport,
 
42
    ui,
 
43
    urlutils,
 
44
    )
 
45
from bzrlib.tests import (
 
46
    http_server,
 
47
    http_utils,
 
48
    )
 
49
from bzrlib.transport import (
 
50
    http,
 
51
    remote,
 
52
    )
 
53
from bzrlib.transport.http import (
 
54
    _urllib,
 
55
    _urllib2_wrappers,
 
56
    )
 
57
 
 
58
 
 
59
try:
 
60
    from bzrlib.transport.http._pycurl import PyCurlTransport
 
61
    pycurl_present = True
 
62
except errors.DependencyNotPresent:
 
63
    pycurl_present = False
 
64
 
 
65
 
 
66
class TransportAdapter(tests.TestScenarioApplier):
 
67
    """Generate the same test for each transport implementation."""
 
68
 
 
69
    def __init__(self):
 
70
        transport_scenarios = [
 
71
            ('urllib', dict(_transport=_urllib.HttpTransport_urllib,
 
72
                            _server=http_server.HttpServer_urllib,
 
73
                            _qualified_prefix='http+urllib',)),
 
74
            ]
 
75
        if pycurl_present:
 
76
            transport_scenarios.append(
 
77
                ('pycurl', dict(_transport=PyCurlTransport,
 
78
                                _server=http_server.HttpServer_PyCurl,
 
79
                                _qualified_prefix='http+pycurl',)))
 
80
        self.scenarios = transport_scenarios
 
81
 
 
82
 
 
83
class TransportProtocolAdapter(TransportAdapter):
 
84
    """Generate the same test for each protocol implementation.
 
85
 
 
86
    In addition to the transport adaptatation that we inherit from.
 
87
    """
 
88
 
 
89
    def __init__(self):
 
90
        super(TransportProtocolAdapter, self).__init__()
 
91
        protocol_scenarios = [
 
92
            ('HTTP/1.0',  dict(_protocol_version='HTTP/1.0')),
 
93
            ('HTTP/1.1',  dict(_protocol_version='HTTP/1.1')),
 
94
            ]
 
95
        self.scenarios = tests.multiply_scenarios(self.scenarios,
 
96
                                                  protocol_scenarios)
 
97
 
 
98
 
 
99
class TransportProtocolAuthenticationAdapter(TransportProtocolAdapter):
 
100
    """Generate the same test for each authentication scheme implementation.
 
101
 
 
102
    In addition to the protocol adaptatation that we inherit from.
 
103
    """
 
104
 
 
105
    def __init__(self):
 
106
        super(TransportProtocolAuthenticationAdapter, self).__init__()
 
107
        auth_scheme_scenarios = [
 
108
            ('basic', dict(_auth_scheme='basic')),
 
109
            ('digest', dict(_auth_scheme='digest')),
 
110
            ]
 
111
 
 
112
        self.scenarios = tests.multiply_scenarios(self.scenarios,
 
113
                                                  auth_scheme_scenarios)
 
114
 
 
115
def load_tests(standard_tests, module, loader):
 
116
    """Multiply tests for http clients and protocol versions."""
 
117
    # one for each transport
 
118
    t_adapter = TransportAdapter()
 
119
    t_classes= (TestHttpTransportRegistration,
 
120
                TestHttpTransportUrls,
 
121
                )
 
122
    is_testing_for_transports = tests.condition_isinstance(t_classes)
 
123
 
 
124
    # multiplied by one for each protocol version
 
125
    tp_adapter = TransportProtocolAdapter()
 
126
    tp_classes= (SmartHTTPTunnellingTest,
 
127
                 TestDoCatchRedirections,
 
128
                 TestHTTPConnections,
 
129
                 TestHTTPRedirections,
 
130
                 TestHTTPSilentRedirections,
 
131
                 TestLimitedRangeRequestServer,
 
132
                 TestPost,
 
133
                 TestProxyHttpServer,
 
134
                 TestRanges,
 
135
                 TestSpecificRequestHandler,
 
136
                 )
 
137
    is_also_testing_for_protocols = tests.condition_isinstance(tp_classes)
 
138
 
 
139
    # multiplied by one for each authentication scheme
 
140
    tpa_adapter = TransportProtocolAuthenticationAdapter()
 
141
    tpa_classes = (TestAuth,
 
142
                   )
 
143
    is_also_testing_for_authentication = tests.condition_isinstance(
 
144
        tpa_classes)
 
145
 
 
146
    result = loader.suiteClass()
 
147
    for test_class in tests.iter_suite_tests(standard_tests):
 
148
        # Each test class is either standalone or testing for some combination
 
149
        # of transport, protocol version, authentication scheme. Use the right
 
150
        # adpater (or none) depending on the class.
 
151
        if is_testing_for_transports(test_class):
 
152
            result.addTests(t_adapter.adapt(test_class))
 
153
        elif is_also_testing_for_protocols(test_class):
 
154
            result.addTests(tp_adapter.adapt(test_class))
 
155
        elif is_also_testing_for_authentication(test_class):
 
156
            result.addTests(tpa_adapter.adapt(test_class))
 
157
        else:
 
158
            result.addTest(test_class)
 
159
    return result
 
160
 
 
161
 
 
162
class FakeManager(object):
 
163
 
 
164
    def __init__(self):
 
165
        self.credentials = []
 
166
 
 
167
    def add_password(self, realm, host, username, password):
 
168
        self.credentials.append([realm, host, username, password])
 
169
 
 
170
 
 
171
class RecordingServer(object):
 
172
    """A fake HTTP server.
 
173
    
 
174
    It records the bytes sent to it, and replies with a 200.
 
175
    """
 
176
 
 
177
    def __init__(self, expect_body_tail=None):
 
178
        """Constructor.
 
179
 
 
180
        :type expect_body_tail: str
 
181
        :param expect_body_tail: a reply won't be sent until this string is
 
182
            received.
 
183
        """
 
184
        self._expect_body_tail = expect_body_tail
 
185
        self.host = None
 
186
        self.port = None
 
187
        self.received_bytes = ''
 
188
 
 
189
    def setUp(self):
 
190
        self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
191
        self._sock.bind(('127.0.0.1', 0))
 
192
        self.host, self.port = self._sock.getsockname()
 
193
        self._ready = threading.Event()
 
194
        self._thread = threading.Thread(target=self._accept_read_and_reply)
 
195
        self._thread.setDaemon(True)
 
196
        self._thread.start()
 
197
        self._ready.wait(5)
 
198
 
 
199
    def _accept_read_and_reply(self):
 
200
        self._sock.listen(1)
 
201
        self._ready.set()
 
202
        self._sock.settimeout(5)
 
203
        try:
 
204
            conn, address = self._sock.accept()
 
205
            # On win32, the accepted connection will be non-blocking to start
 
206
            # with because we're using settimeout.
 
207
            conn.setblocking(True)
 
208
            while not self.received_bytes.endswith(self._expect_body_tail):
 
209
                self.received_bytes += conn.recv(4096)
 
210
            conn.sendall('HTTP/1.1 200 OK\r\n')
 
211
        except socket.timeout:
 
212
            # Make sure the client isn't stuck waiting for us to e.g. accept.
 
213
            self._sock.close()
 
214
        except socket.error:
 
215
            # The client may have already closed the socket.
 
216
            pass
 
217
 
 
218
    def tearDown(self):
 
219
        try:
 
220
            self._sock.close()
 
221
        except socket.error:
 
222
            # We might have already closed it.  We don't care.
 
223
            pass
 
224
        self.host = None
 
225
        self.port = None
 
226
 
 
227
 
 
228
class TestHTTPServer(tests.TestCase):
 
229
    """Test the HTTP servers implementations."""
 
230
 
 
231
    def test_invalid_protocol(self):
 
232
        class BogusRequestHandler(http_server.TestingHTTPRequestHandler):
 
233
 
 
234
            protocol_version = 'HTTP/0.1'
 
235
 
 
236
        server = http_server.HttpServer(BogusRequestHandler)
 
237
        try:
 
238
            self.assertRaises(httplib.UnknownProtocol,server.setUp)
 
239
        except:
 
240
            server.tearDown()
 
241
            self.fail('HTTP Server creation did not raise UnknownProtocol')
 
242
 
 
243
    def test_force_invalid_protocol(self):
 
244
        server = http_server.HttpServer(protocol_version='HTTP/0.1')
 
245
        try:
 
246
            self.assertRaises(httplib.UnknownProtocol,server.setUp)
 
247
        except:
 
248
            server.tearDown()
 
249
            self.fail('HTTP Server creation did not raise UnknownProtocol')
 
250
 
 
251
    def test_server_start_and_stop(self):
 
252
        server = http_server.HttpServer()
 
253
        server.setUp()
 
254
        self.assertTrue(server._http_running)
 
255
        server.tearDown()
 
256
        self.assertFalse(server._http_running)
 
257
 
 
258
    def test_create_http_server_one_zero(self):
 
259
        class RequestHandlerOneZero(http_server.TestingHTTPRequestHandler):
 
260
 
 
261
            protocol_version = 'HTTP/1.0'
 
262
 
 
263
        server = http_server.HttpServer(RequestHandlerOneZero)
 
264
        server.setUp()
 
265
        self.addCleanup(server.tearDown)
 
266
        self.assertIsInstance(server._httpd, http_server.TestingHTTPServer)
 
267
 
 
268
    def test_create_http_server_one_one(self):
 
269
        class RequestHandlerOneOne(http_server.TestingHTTPRequestHandler):
 
270
 
 
271
            protocol_version = 'HTTP/1.1'
 
272
 
 
273
        server = http_server.HttpServer(RequestHandlerOneOne)
 
274
        server.setUp()
 
275
        self.addCleanup(server.tearDown)
 
276
        self.assertIsInstance(server._httpd,
 
277
                              http_server.TestingThreadingHTTPServer)
 
278
 
 
279
    def test_create_http_server_force_one_one(self):
 
280
        class RequestHandlerOneZero(http_server.TestingHTTPRequestHandler):
 
281
 
 
282
            protocol_version = 'HTTP/1.0'
 
283
 
 
284
        server = http_server.HttpServer(RequestHandlerOneZero,
 
285
                                        protocol_version='HTTP/1.1')
 
286
        server.setUp()
 
287
        self.addCleanup(server.tearDown)
 
288
        self.assertIsInstance(server._httpd,
 
289
                              http_server.TestingThreadingHTTPServer)
 
290
 
 
291
    def test_create_http_server_force_one_zero(self):
 
292
        class RequestHandlerOneOne(http_server.TestingHTTPRequestHandler):
 
293
 
 
294
            protocol_version = 'HTTP/1.1'
 
295
 
 
296
        server = http_server.HttpServer(RequestHandlerOneOne,
 
297
                                        protocol_version='HTTP/1.0')
 
298
        server.setUp()
 
299
        self.addCleanup(server.tearDown)
 
300
        self.assertIsInstance(server._httpd,
 
301
                              http_server.TestingHTTPServer)
 
302
 
 
303
 
 
304
class TestWithTransport_pycurl(object):
 
305
    """Test case to inherit from if pycurl is present"""
 
306
 
 
307
    def _get_pycurl_maybe(self):
 
308
        try:
 
309
            from bzrlib.transport.http._pycurl import PyCurlTransport
 
310
            return PyCurlTransport
 
311
        except errors.DependencyNotPresent:
 
312
            raise tests.TestSkipped('pycurl not present')
 
313
 
 
314
    _transport = property(_get_pycurl_maybe)
 
315
 
 
316
 
 
317
class TestHttpUrls(tests.TestCase):
 
318
 
 
319
    # TODO: This should be moved to authorization tests once they
 
320
    # are written.
 
321
 
 
322
    def test_url_parsing(self):
 
323
        f = FakeManager()
 
324
        url = http.extract_auth('http://example.com', f)
 
325
        self.assertEquals('http://example.com', url)
 
326
        self.assertEquals(0, len(f.credentials))
 
327
        url = http.extract_auth(
 
328
            'http://user:pass@www.bazaar-vcs.org/bzr/bzr.dev', f)
 
329
        self.assertEquals('http://www.bazaar-vcs.org/bzr/bzr.dev', url)
 
330
        self.assertEquals(1, len(f.credentials))
 
331
        self.assertEquals([None, 'www.bazaar-vcs.org', 'user', 'pass'],
 
332
                          f.credentials[0])
 
333
 
 
334
 
 
335
class TestHttpTransportUrls(tests.TestCase):
 
336
    """Test the http urls."""
 
337
 
 
338
    def test_abs_url(self):
 
339
        """Construction of absolute http URLs"""
 
340
        t = self._transport('http://bazaar-vcs.org/bzr/bzr.dev/')
 
341
        eq = self.assertEqualDiff
 
342
        eq(t.abspath('.'), 'http://bazaar-vcs.org/bzr/bzr.dev')
 
343
        eq(t.abspath('foo/bar'), 'http://bazaar-vcs.org/bzr/bzr.dev/foo/bar')
 
344
        eq(t.abspath('.bzr'), 'http://bazaar-vcs.org/bzr/bzr.dev/.bzr')
 
345
        eq(t.abspath('.bzr/1//2/./3'),
 
346
           'http://bazaar-vcs.org/bzr/bzr.dev/.bzr/1/2/3')
 
347
 
 
348
    def test_invalid_http_urls(self):
 
349
        """Trap invalid construction of urls"""
 
350
        t = self._transport('http://bazaar-vcs.org/bzr/bzr.dev/')
 
351
        self.assertRaises(errors.InvalidURL,
 
352
                          self._transport,
 
353
                          'http://http://bazaar-vcs.org/bzr/bzr.dev/')
 
354
 
 
355
    def test_http_root_urls(self):
 
356
        """Construction of URLs from server root"""
 
357
        t = self._transport('http://bzr.ozlabs.org/')
 
358
        eq = self.assertEqualDiff
 
359
        eq(t.abspath('.bzr/tree-version'),
 
360
           'http://bzr.ozlabs.org/.bzr/tree-version')
 
361
 
 
362
    def test_http_impl_urls(self):
 
363
        """There are servers which ask for particular clients to connect"""
 
364
        server = self._server()
 
365
        try:
 
366
            server.setUp()
 
367
            url = server.get_url()
 
368
            self.assertTrue(url.startswith('%s://' % self._qualified_prefix))
 
369
        finally:
 
370
            server.tearDown()
 
371
 
 
372
 
 
373
class TestHttps_pycurl(TestWithTransport_pycurl, tests.TestCase):
 
374
 
 
375
    # TODO: This should really be moved into another pycurl
 
376
    # specific test. When https tests will be implemented, take
 
377
    # this one into account.
 
378
    def test_pycurl_without_https_support(self):
 
379
        """Test that pycurl without SSL do not fail with a traceback.
 
380
 
 
381
        For the purpose of the test, we force pycurl to ignore
 
382
        https by supplying a fake version_info that do not
 
383
        support it.
 
384
        """
 
385
        try:
 
386
            import pycurl
 
387
        except ImportError:
 
388
            raise tests.TestSkipped('pycurl not present')
 
389
 
 
390
        version_info_orig = pycurl.version_info
 
391
        try:
 
392
            # Now that we have pycurl imported, we can fake its version_info
 
393
            # This was taken from a windows pycurl without SSL
 
394
            # (thanks to bialix)
 
395
            pycurl.version_info = lambda : (2,
 
396
                                            '7.13.2',
 
397
                                            462082,
 
398
                                            'i386-pc-win32',
 
399
                                            2576,
 
400
                                            None,
 
401
                                            0,
 
402
                                            None,
 
403
                                            ('ftp', 'gopher', 'telnet',
 
404
                                             'dict', 'ldap', 'http', 'file'),
 
405
                                            None,
 
406
                                            0,
 
407
                                            None)
 
408
            self.assertRaises(errors.DependencyNotPresent, self._transport,
 
409
                              'https://launchpad.net')
 
410
        finally:
 
411
            # Restore the right function
 
412
            pycurl.version_info = version_info_orig
 
413
 
 
414
 
 
415
class TestHTTPConnections(http_utils.TestCaseWithWebserver):
 
416
    """Test the http connections."""
 
417
 
 
418
    def setUp(self):
 
419
        http_utils.TestCaseWithWebserver.setUp(self)
 
420
        self.build_tree(['foo/', 'foo/bar'], line_endings='binary',
 
421
                        transport=self.get_transport())
 
422
 
 
423
    def test_http_has(self):
 
424
        server = self.get_readonly_server()
 
425
        t = self._transport(server.get_url())
 
426
        self.assertEqual(t.has('foo/bar'), True)
 
427
        self.assertEqual(len(server.logs), 1)
 
428
        self.assertContainsRe(server.logs[0],
 
429
            r'"HEAD /foo/bar HTTP/1.." (200|302) - "-" "bzr/')
 
430
 
 
431
    def test_http_has_not_found(self):
 
432
        server = self.get_readonly_server()
 
433
        t = self._transport(server.get_url())
 
434
        self.assertEqual(t.has('not-found'), False)
 
435
        self.assertContainsRe(server.logs[1],
 
436
            r'"HEAD /not-found HTTP/1.." 404 - "-" "bzr/')
 
437
 
 
438
    def test_http_get(self):
 
439
        server = self.get_readonly_server()
 
440
        t = self._transport(server.get_url())
 
441
        fp = t.get('foo/bar')
 
442
        self.assertEqualDiff(
 
443
            fp.read(),
 
444
            'contents of foo/bar\n')
 
445
        self.assertEqual(len(server.logs), 1)
 
446
        self.assertTrue(server.logs[0].find(
 
447
            '"GET /foo/bar HTTP/1.1" 200 - "-" "bzr/%s'
 
448
            % bzrlib.__version__) > -1)
 
449
 
 
450
    def test_get_smart_medium(self):
 
451
        # For HTTP, get_smart_medium should return the transport object.
 
452
        server = self.get_readonly_server()
 
453
        http_transport = self._transport(server.get_url())
 
454
        medium = http_transport.get_smart_medium()
 
455
        self.assertIs(medium, http_transport)
 
456
 
 
457
    def test_has_on_bogus_host(self):
 
458
        # Get a free address and don't 'accept' on it, so that we
 
459
        # can be sure there is no http handler there, but set a
 
460
        # reasonable timeout to not slow down tests too much.
 
461
        default_timeout = socket.getdefaulttimeout()
 
462
        try:
 
463
            socket.setdefaulttimeout(2)
 
464
            s = socket.socket()
 
465
            s.bind(('localhost', 0))
 
466
            t = self._transport('http://%s:%s/' % s.getsockname())
 
467
            self.assertRaises(errors.ConnectionError, t.has, 'foo/bar')
 
468
        finally:
 
469
            socket.setdefaulttimeout(default_timeout)
 
470
 
 
471
 
 
472
class TestHttpTransportRegistration(tests.TestCase):
 
473
    """Test registrations of various http implementations"""
 
474
 
 
475
    def test_http_registered(self):
 
476
        t = transport.get_transport('%s://foo.com/' % self._qualified_prefix)
 
477
        self.assertIsInstance(t, transport.Transport)
 
478
        self.assertIsInstance(t, self._transport)
 
479
 
 
480
 
 
481
class TestPost(tests.TestCase):
 
482
 
 
483
    def test_post_body_is_received(self):
 
484
        server = RecordingServer(expect_body_tail='end-of-body')
 
485
        server.setUp()
 
486
        self.addCleanup(server.tearDown)
 
487
        scheme = self._qualified_prefix
 
488
        url = '%s://%s:%s/' % (scheme, server.host, server.port)
 
489
        http_transport = self._transport(url)
 
490
        code, response = http_transport._post('abc def end-of-body')
 
491
        self.assertTrue(
 
492
            server.received_bytes.startswith('POST /.bzr/smart HTTP/1.'))
 
493
        self.assertTrue('content-length: 19\r' in server.received_bytes.lower())
 
494
        # The transport should not be assuming that the server can accept
 
495
        # chunked encoding the first time it connects, because HTTP/1.1, so we
 
496
        # check for the literal string.
 
497
        self.assertTrue(
 
498
            server.received_bytes.endswith('\r\n\r\nabc def end-of-body'))
 
499
 
 
500
 
 
501
class TestRangeHeader(tests.TestCase):
 
502
    """Test range_header method"""
 
503
 
 
504
    def check_header(self, value, ranges=[], tail=0):
 
505
        offsets = [ (start, end - start + 1) for start, end in ranges]
 
506
        coalesce = transport.Transport._coalesce_offsets
 
507
        coalesced = list(coalesce(offsets, limit=0, fudge_factor=0))
 
508
        range_header = http.HttpTransportBase._range_header
 
509
        self.assertEqual(value, range_header(coalesced, tail))
 
510
 
 
511
    def test_range_header_single(self):
 
512
        self.check_header('0-9', ranges=[(0,9)])
 
513
        self.check_header('100-109', ranges=[(100,109)])
 
514
 
 
515
    def test_range_header_tail(self):
 
516
        self.check_header('-10', tail=10)
 
517
        self.check_header('-50', tail=50)
 
518
 
 
519
    def test_range_header_multi(self):
 
520
        self.check_header('0-9,100-200,300-5000',
 
521
                          ranges=[(0,9), (100, 200), (300,5000)])
 
522
 
 
523
    def test_range_header_mixed(self):
 
524
        self.check_header('0-9,300-5000,-50',
 
525
                          ranges=[(0,9), (300,5000)],
 
526
                          tail=50)
 
527
 
 
528
 
 
529
class TestSpecificRequestHandler(http_utils.TestCaseWithWebserver):
 
530
    """Tests a specific request handler.
 
531
 
 
532
    Daughter classes are expected to override _req_handler_class
 
533
    """
 
534
 
 
535
    # Provide a useful default
 
536
    _req_handler_class = http_server.TestingHTTPRequestHandler
 
537
 
 
538
    def create_transport_readonly_server(self):
 
539
        return http_server.HttpServer(self._req_handler_class,
 
540
                                      protocol_version=self._protocol_version)
 
541
 
 
542
    def _testing_pycurl(self):
 
543
        return pycurl_present and self._transport == PyCurlTransport
 
544
 
 
545
 
 
546
class WallRequestHandler(http_server.TestingHTTPRequestHandler):
 
547
    """Whatever request comes in, close the connection"""
 
548
 
 
549
    def handle_one_request(self):
 
550
        """Handle a single HTTP request, by abruptly closing the connection"""
 
551
        self.close_connection = 1
 
552
 
 
553
 
 
554
class TestWallServer(TestSpecificRequestHandler):
 
555
    """Tests exceptions during the connection phase"""
 
556
 
 
557
    _req_handler_class = WallRequestHandler
 
558
 
 
559
    def test_http_has(self):
 
560
        server = self.get_readonly_server()
 
561
        t = self._transport(server.get_url())
 
562
        # Unfortunately httplib (see HTTPResponse._read_status
 
563
        # for details) make no distinction between a closed
 
564
        # socket and badly formatted status line, so we can't
 
565
        # just test for ConnectionError, we have to test
 
566
        # InvalidHttpResponse too.
 
567
        self.assertRaises((errors.ConnectionError, errors.InvalidHttpResponse),
 
568
                          t.has, 'foo/bar')
 
569
 
 
570
    def test_http_get(self):
 
571
        server = self.get_readonly_server()
 
572
        t = self._transport(server.get_url())
 
573
        self.assertRaises((errors.ConnectionError, errors.InvalidHttpResponse),
 
574
                          t.get, 'foo/bar')
 
575
 
 
576
 
 
577
class BadStatusRequestHandler(http_server.TestingHTTPRequestHandler):
 
578
    """Whatever request comes in, returns a bad status"""
 
579
 
 
580
    def parse_request(self):
 
581
        """Fakes handling a single HTTP request, returns a bad status"""
 
582
        ignored = http_server.TestingHTTPRequestHandler.parse_request(self)
 
583
        self.send_response(0, "Bad status")
 
584
        self.close_connection = 1
 
585
        return False
 
586
 
 
587
 
 
588
class TestBadStatusServer(TestSpecificRequestHandler):
 
589
    """Tests bad status from server."""
 
590
 
 
591
    _req_handler_class = BadStatusRequestHandler
 
592
 
 
593
    def test_http_has(self):
 
594
        server = self.get_readonly_server()
 
595
        t = self._transport(server.get_url())
 
596
        self.assertRaises(errors.InvalidHttpResponse, t.has, 'foo/bar')
 
597
 
 
598
    def test_http_get(self):
 
599
        server = self.get_readonly_server()
 
600
        t = self._transport(server.get_url())
 
601
        self.assertRaises(errors.InvalidHttpResponse, t.get, 'foo/bar')
 
602
 
 
603
 
 
604
class InvalidStatusRequestHandler(http_server.TestingHTTPRequestHandler):
 
605
    """Whatever request comes in, returns an invalid status"""
 
606
 
 
607
    def parse_request(self):
 
608
        """Fakes handling a single HTTP request, returns a bad status"""
 
609
        ignored = http_server.TestingHTTPRequestHandler.parse_request(self)
 
610
        self.wfile.write("Invalid status line\r\n")
 
611
        return False
 
612
 
 
613
 
 
614
class TestInvalidStatusServer(TestBadStatusServer):
 
615
    """Tests invalid status from server.
 
616
 
 
617
    Both implementations raises the same error as for a bad status.
 
618
    """
 
619
 
 
620
    _req_handler_class = InvalidStatusRequestHandler
 
621
 
 
622
    def test_http_has(self):
 
623
        if self._testing_pycurl() and self._protocol_version == 'HTTP/1.1':
 
624
            raise tests.KnownFailure(
 
625
                'pycurl hangs if the server send back garbage')
 
626
        super(TestInvalidStatusServer, self).test_http_has()
 
627
 
 
628
    def test_http_get(self):
 
629
        if self._testing_pycurl() and self._protocol_version == 'HTTP/1.1':
 
630
            raise tests.KnownFailure(
 
631
                'pycurl hangs if the server send back garbage')
 
632
        super(TestInvalidStatusServer, self).test_http_get()
 
633
 
 
634
 
 
635
class BadProtocolRequestHandler(http_server.TestingHTTPRequestHandler):
 
636
    """Whatever request comes in, returns a bad protocol version"""
 
637
 
 
638
    def parse_request(self):
 
639
        """Fakes handling a single HTTP request, returns a bad status"""
 
640
        ignored = http_server.TestingHTTPRequestHandler.parse_request(self)
 
641
        # Returns an invalid protocol version, but curl just
 
642
        # ignores it and those cannot be tested.
 
643
        self.wfile.write("%s %d %s\r\n" % ('HTTP/0.0',
 
644
                                           404,
 
645
                                           'Look at my protocol version'))
 
646
        return False
 
647
 
 
648
 
 
649
class TestBadProtocolServer(TestSpecificRequestHandler):
 
650
    """Tests bad protocol from server."""
 
651
 
 
652
    _req_handler_class = BadProtocolRequestHandler
 
653
 
 
654
    def setUp(self):
 
655
        if pycurl_present and self._transport == PyCurlTransport:
 
656
            raise tests.TestNotApplicable(
 
657
                "pycurl doesn't check the protocol version")
 
658
        super(TestBadProtocolServer, self).setUp()
 
659
 
 
660
    def test_http_has(self):
 
661
        server = self.get_readonly_server()
 
662
        t = self._transport(server.get_url())
 
663
        self.assertRaises(errors.InvalidHttpResponse, t.has, 'foo/bar')
 
664
 
 
665
    def test_http_get(self):
 
666
        server = self.get_readonly_server()
 
667
        t = self._transport(server.get_url())
 
668
        self.assertRaises(errors.InvalidHttpResponse, t.get, 'foo/bar')
 
669
 
 
670
 
 
671
class ForbiddenRequestHandler(http_server.TestingHTTPRequestHandler):
 
672
    """Whatever request comes in, returns a 403 code"""
 
673
 
 
674
    def parse_request(self):
 
675
        """Handle a single HTTP request, by replying we cannot handle it"""
 
676
        ignored = http_server.TestingHTTPRequestHandler.parse_request(self)
 
677
        self.send_error(403)
 
678
        return False
 
679
 
 
680
 
 
681
class TestForbiddenServer(TestSpecificRequestHandler):
 
682
    """Tests forbidden server"""
 
683
 
 
684
    _req_handler_class = ForbiddenRequestHandler
 
685
 
 
686
    def test_http_has(self):
 
687
        server = self.get_readonly_server()
 
688
        t = self._transport(server.get_url())
 
689
        self.assertRaises(errors.TransportError, t.has, 'foo/bar')
 
690
 
 
691
    def test_http_get(self):
 
692
        server = self.get_readonly_server()
 
693
        t = self._transport(server.get_url())
 
694
        self.assertRaises(errors.TransportError, t.get, 'foo/bar')
 
695
 
 
696
 
 
697
class TestRecordingServer(tests.TestCase):
 
698
 
 
699
    def test_create(self):
 
700
        server = RecordingServer(expect_body_tail=None)
 
701
        self.assertEqual('', server.received_bytes)
 
702
        self.assertEqual(None, server.host)
 
703
        self.assertEqual(None, server.port)
 
704
 
 
705
    def test_setUp_and_tearDown(self):
 
706
        server = RecordingServer(expect_body_tail=None)
 
707
        server.setUp()
 
708
        try:
 
709
            self.assertNotEqual(None, server.host)
 
710
            self.assertNotEqual(None, server.port)
 
711
        finally:
 
712
            server.tearDown()
 
713
        self.assertEqual(None, server.host)
 
714
        self.assertEqual(None, server.port)
 
715
 
 
716
    def test_send_receive_bytes(self):
 
717
        server = RecordingServer(expect_body_tail='c')
 
718
        server.setUp()
 
719
        self.addCleanup(server.tearDown)
 
720
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
721
        sock.connect((server.host, server.port))
 
722
        sock.sendall('abc')
 
723
        self.assertEqual('HTTP/1.1 200 OK\r\n',
 
724
                         osutils.recv_all(sock, 4096))
 
725
        self.assertEqual('abc', server.received_bytes)
 
726
 
 
727
 
 
728
class TestRangeRequestServer(TestSpecificRequestHandler):
 
729
    """Tests readv requests against server.
 
730
 
 
731
    We test against default "normal" server.
 
732
    """
 
733
 
 
734
    def setUp(self):
 
735
        super(TestRangeRequestServer, self).setUp()
 
736
        self.build_tree_contents([('a', '0123456789')],)
 
737
 
 
738
    def test_readv(self):
 
739
        server = self.get_readonly_server()
 
740
        t = self._transport(server.get_url())
 
741
        l = list(t.readv('a', ((0, 1), (1, 1), (3, 2), (9, 1))))
 
742
        self.assertEqual(l[0], (0, '0'))
 
743
        self.assertEqual(l[1], (1, '1'))
 
744
        self.assertEqual(l[2], (3, '34'))
 
745
        self.assertEqual(l[3], (9, '9'))
 
746
 
 
747
    def test_readv_out_of_order(self):
 
748
        server = self.get_readonly_server()
 
749
        t = self._transport(server.get_url())
 
750
        l = list(t.readv('a', ((1, 1), (9, 1), (0, 1), (3, 2))))
 
751
        self.assertEqual(l[0], (1, '1'))
 
752
        self.assertEqual(l[1], (9, '9'))
 
753
        self.assertEqual(l[2], (0, '0'))
 
754
        self.assertEqual(l[3], (3, '34'))
 
755
 
 
756
    def test_readv_invalid_ranges(self):
 
757
        server = self.get_readonly_server()
 
758
        t = self._transport(server.get_url())
 
759
 
 
760
        # This is intentionally reading off the end of the file
 
761
        # since we are sure that it cannot get there
 
762
        self.assertListRaises((errors.InvalidRange, errors.ShortReadvError,),
 
763
                              t.readv, 'a', [(1,1), (8,10)])
 
764
 
 
765
        # This is trying to seek past the end of the file, it should
 
766
        # also raise a special error
 
767
        self.assertListRaises((errors.InvalidRange, errors.ShortReadvError,),
 
768
                              t.readv, 'a', [(12,2)])
 
769
 
 
770
    def test_readv_multiple_get_requests(self):
 
771
        server = self.get_readonly_server()
 
772
        t = self._transport(server.get_url())
 
773
        # force transport to issue multiple requests
 
774
        t._max_readv_combine = 1
 
775
        t._max_get_ranges = 1
 
776
        l = list(t.readv('a', ((0, 1), (1, 1), (3, 2), (9, 1))))
 
777
        self.assertEqual(l[0], (0, '0'))
 
778
        self.assertEqual(l[1], (1, '1'))
 
779
        self.assertEqual(l[2], (3, '34'))
 
780
        self.assertEqual(l[3], (9, '9'))
 
781
        # The server should have issued 4 requests
 
782
        self.assertEqual(4, server.GET_request_nb)
 
783
 
 
784
    def test_readv_get_max_size(self):
 
785
        server = self.get_readonly_server()
 
786
        t = self._transport(server.get_url())
 
787
        # force transport to issue multiple requests by limiting the number of
 
788
        # bytes by request. Note that this apply to coalesced offsets only, a
 
789
        # single range will keep its size even if bigger than the limit.
 
790
        t._get_max_size = 2
 
791
        l = list(t.readv('a', ((0, 1), (1, 1), (2, 4), (6, 4))))
 
792
        self.assertEqual(l[0], (0, '0'))
 
793
        self.assertEqual(l[1], (1, '1'))
 
794
        self.assertEqual(l[2], (2, '2345'))
 
795
        self.assertEqual(l[3], (6, '6789'))
 
796
        # The server should have issued 3 requests
 
797
        self.assertEqual(3, server.GET_request_nb)
 
798
 
 
799
    def test_complete_readv_leave_pipe_clean(self):
 
800
        server = self.get_readonly_server()
 
801
        t = self._transport(server.get_url())
 
802
        # force transport to issue multiple requests
 
803
        t._get_max_size = 2
 
804
        l = list(t.readv('a', ((0, 1), (1, 1), (2, 4), (6, 4))))
 
805
        # The server should have issued 3 requests
 
806
        self.assertEqual(3, server.GET_request_nb)
 
807
        self.assertEqual('0123456789', t.get_bytes('a'))
 
808
        self.assertEqual(4, server.GET_request_nb)
 
809
 
 
810
    def test_incomplete_readv_leave_pipe_clean(self):
 
811
        server = self.get_readonly_server()
 
812
        t = self._transport(server.get_url())
 
813
        # force transport to issue multiple requests
 
814
        t._get_max_size = 2
 
815
        # Don't collapse readv results into a list so that we leave unread
 
816
        # bytes on the socket
 
817
        ireadv = iter(t.readv('a', ((0, 1), (1, 1), (2, 4), (6, 4))))
 
818
        self.assertEqual((0, '0'), ireadv.next())
 
819
        # The server should have issued one request so far 
 
820
        self.assertEqual(1, server.GET_request_nb)
 
821
        self.assertEqual('0123456789', t.get_bytes('a'))
 
822
        # get_bytes issued an additional request, the readv pending ones are
 
823
        # lost
 
824
        self.assertEqual(2, server.GET_request_nb)
 
825
 
 
826
 
 
827
class SingleRangeRequestHandler(http_server.TestingHTTPRequestHandler):
 
828
    """Always reply to range request as if they were single.
 
829
 
 
830
    Don't be explicit about it, just to annoy the clients.
 
831
    """
 
832
 
 
833
    def get_multiple_ranges(self, file, file_size, ranges):
 
834
        """Answer as if it was a single range request and ignores the rest"""
 
835
        (start, end) = ranges[0]
 
836
        return self.get_single_range(file, file_size, start, end)
 
837
 
 
838
 
 
839
class TestSingleRangeRequestServer(TestRangeRequestServer):
 
840
    """Test readv against a server which accept only single range requests"""
 
841
 
 
842
    _req_handler_class = SingleRangeRequestHandler
 
843
 
 
844
 
 
845
class SingleOnlyRangeRequestHandler(http_server.TestingHTTPRequestHandler):
 
846
    """Only reply to simple range requests, errors out on multiple"""
 
847
 
 
848
    def get_multiple_ranges(self, file, file_size, ranges):
 
849
        """Refuses the multiple ranges request"""
 
850
        if len(ranges) > 1:
 
851
            file.close()
 
852
            self.send_error(416, "Requested range not satisfiable")
 
853
            return
 
854
        (start, end) = ranges[0]
 
855
        return self.get_single_range(file, file_size, start, end)
 
856
 
 
857
 
 
858
class TestSingleOnlyRangeRequestServer(TestRangeRequestServer):
 
859
    """Test readv against a server which only accept single range requests"""
 
860
 
 
861
    _req_handler_class = SingleOnlyRangeRequestHandler
 
862
 
 
863
 
 
864
class NoRangeRequestHandler(http_server.TestingHTTPRequestHandler):
 
865
    """Ignore range requests without notice"""
 
866
 
 
867
    def do_GET(self):
 
868
        # Update the statistics
 
869
        self.server.test_case_server.GET_request_nb += 1
 
870
        # Just bypass the range handling done by TestingHTTPRequestHandler
 
871
        return SimpleHTTPServer.SimpleHTTPRequestHandler.do_GET(self)
 
872
 
 
873
 
 
874
class TestNoRangeRequestServer(TestRangeRequestServer):
 
875
    """Test readv against a server which do not accept range requests"""
 
876
 
 
877
    _req_handler_class = NoRangeRequestHandler
 
878
 
 
879
 
 
880
class MultipleRangeWithoutContentLengthRequestHandler(
 
881
    http_server.TestingHTTPRequestHandler):
 
882
    """Reply to multiple range requests without content length header."""
 
883
 
 
884
    def get_multiple_ranges(self, file, file_size, ranges):
 
885
        self.send_response(206)
 
886
        self.send_header('Accept-Ranges', 'bytes')
 
887
        boundary = "%d" % random.randint(0,0x7FFFFFFF)
 
888
        self.send_header("Content-Type",
 
889
                         "multipart/byteranges; boundary=%s" % boundary)
 
890
        self.end_headers()
 
891
        for (start, end) in ranges:
 
892
            self.wfile.write("--%s\r\n" % boundary)
 
893
            self.send_header("Content-type", 'application/octet-stream')
 
894
            self.send_header("Content-Range", "bytes %d-%d/%d" % (start,
 
895
                                                                  end,
 
896
                                                                  file_size))
 
897
            self.end_headers()
 
898
            self.send_range_content(file, start, end - start + 1)
 
899
        # Final boundary
 
900
        self.wfile.write("--%s\r\n" % boundary)
 
901
 
 
902
 
 
903
class TestMultipleRangeWithoutContentLengthServer(TestRangeRequestServer):
 
904
 
 
905
    _req_handler_class = MultipleRangeWithoutContentLengthRequestHandler
 
906
 
 
907
 
 
908
class TruncatedMultipleRangeRequestHandler(
 
909
    http_server.TestingHTTPRequestHandler):
 
910
    """Reply to multiple range requests truncating the last ones.
 
911
 
 
912
    This server generates responses whose Content-Length describes all the
 
913
    ranges, but fail to include the last ones leading to client short reads.
 
914
    This has been observed randomly with lighttpd (bug #179368).
 
915
    """
 
916
 
 
917
    _truncated_ranges = 2
 
918
 
 
919
    def get_multiple_ranges(self, file, file_size, ranges):
 
920
        self.send_response(206)
 
921
        self.send_header('Accept-Ranges', 'bytes')
 
922
        boundary = 'tagada'
 
923
        self.send_header('Content-Type',
 
924
                         'multipart/byteranges; boundary=%s' % boundary)
 
925
        boundary_line = '--%s\r\n' % boundary
 
926
        # Calculate the Content-Length
 
927
        content_length = 0
 
928
        for (start, end) in ranges:
 
929
            content_length += len(boundary_line)
 
930
            content_length += self._header_line_length(
 
931
                'Content-type', 'application/octet-stream')
 
932
            content_length += self._header_line_length(
 
933
                'Content-Range', 'bytes %d-%d/%d' % (start, end, file_size))
 
934
            content_length += len('\r\n') # end headers
 
935
            content_length += end - start # + 1
 
936
        content_length += len(boundary_line)
 
937
        self.send_header('Content-length', content_length)
 
938
        self.end_headers()
 
939
 
 
940
        # Send the multipart body
 
941
        cur = 0
 
942
        for (start, end) in ranges:
 
943
            self.wfile.write(boundary_line)
 
944
            self.send_header('Content-type', 'application/octet-stream')
 
945
            self.send_header('Content-Range', 'bytes %d-%d/%d'
 
946
                             % (start, end, file_size))
 
947
            self.end_headers()
 
948
            if cur + self._truncated_ranges >= len(ranges):
 
949
                # Abruptly ends the response and close the connection
 
950
                self.close_connection = 1
 
951
                return
 
952
            self.send_range_content(file, start, end - start + 1)
 
953
            cur += 1
 
954
        # No final boundary
 
955
        self.wfile.write(boundary_line)
 
956
 
 
957
 
 
958
class TestTruncatedMultipleRangeServer(TestSpecificRequestHandler):
 
959
 
 
960
    _req_handler_class = TruncatedMultipleRangeRequestHandler
 
961
 
 
962
    def setUp(self):
 
963
        super(TestTruncatedMultipleRangeServer, self).setUp()
 
964
        self.build_tree_contents([('a', '0123456789')],)
 
965
 
 
966
    def test_readv_with_short_reads(self):
 
967
        server = self.get_readonly_server()
 
968
        t = self._transport(server.get_url())
 
969
        # Force separate ranges for each offset
 
970
        t._bytes_to_read_before_seek = 0
 
971
        ireadv = iter(t.readv('a', ((0, 1), (2, 1), (4, 2), (9, 1))))
 
972
        self.assertEqual((0, '0'), ireadv.next())
 
973
        self.assertEqual((2, '2'), ireadv.next())
 
974
        if not self._testing_pycurl():
 
975
            # Only one request have been issued so far (except for pycurl that
 
976
            # try to read the whole response at once)
 
977
            self.assertEqual(1, server.GET_request_nb)
 
978
        self.assertEqual((4, '45'), ireadv.next())
 
979
        self.assertEqual((9, '9'), ireadv.next())
 
980
        # Both implementations issue 3 requests but:
 
981
        # - urllib does two multiple (4 ranges, then 2 ranges) then a single
 
982
        #   range,
 
983
        # - pycurl does two multiple (4 ranges, 4 ranges) then a single range
 
984
        self.assertEqual(3, server.GET_request_nb)
 
985
        # Finally the client have tried a single range request and stays in
 
986
        # that mode
 
987
        self.assertEqual('single', t._range_hint)
 
988
 
 
989
class LimitedRangeRequestHandler(http_server.TestingHTTPRequestHandler):
 
990
    """Errors out when range specifiers exceed the limit"""
 
991
 
 
992
    def get_multiple_ranges(self, file, file_size, ranges):
 
993
        """Refuses the multiple ranges request"""
 
994
        tcs = self.server.test_case_server
 
995
        if tcs.range_limit is not None and len(ranges) > tcs.range_limit:
 
996
            file.close()
 
997
            # Emulate apache behavior
 
998
            self.send_error(400, "Bad Request")
 
999
            return
 
1000
        return http_server.TestingHTTPRequestHandler.get_multiple_ranges(
 
1001
            self, file, file_size, ranges)
 
1002
 
 
1003
 
 
1004
class LimitedRangeHTTPServer(http_server.HttpServer):
 
1005
    """An HttpServer erroring out on requests with too much range specifiers"""
 
1006
 
 
1007
    def __init__(self, request_handler=LimitedRangeRequestHandler,
 
1008
                 protocol_version=None,
 
1009
                 range_limit=None):
 
1010
        http_server.HttpServer.__init__(self, request_handler,
 
1011
                                        protocol_version=protocol_version)
 
1012
        self.range_limit = range_limit
 
1013
 
 
1014
 
 
1015
class TestLimitedRangeRequestServer(http_utils.TestCaseWithWebserver):
 
1016
    """Tests readv requests against a server erroring out on too much ranges."""
 
1017
 
 
1018
    # Requests with more range specifiers will error out
 
1019
    range_limit = 3
 
1020
 
 
1021
    def create_transport_readonly_server(self):
 
1022
        return LimitedRangeHTTPServer(range_limit=self.range_limit,
 
1023
                                      protocol_version=self._protocol_version)
 
1024
 
 
1025
    def get_transport(self):
 
1026
        return self._transport(self.get_readonly_server().get_url())
 
1027
 
 
1028
    def setUp(self):
 
1029
        http_utils.TestCaseWithWebserver.setUp(self)
 
1030
        # We need to manipulate ranges that correspond to real chunks in the
 
1031
        # response, so we build a content appropriately.
 
1032
        filler = ''.join(['abcdefghij' for x in range(102)])
 
1033
        content = ''.join(['%04d' % v + filler for v in range(16)])
 
1034
        self.build_tree_contents([('a', content)],)
 
1035
 
 
1036
    def test_few_ranges(self):
 
1037
        t = self.get_transport()
 
1038
        l = list(t.readv('a', ((0, 4), (1024, 4), )))
 
1039
        self.assertEqual(l[0], (0, '0000'))
 
1040
        self.assertEqual(l[1], (1024, '0001'))
 
1041
        self.assertEqual(1, self.get_readonly_server().GET_request_nb)
 
1042
 
 
1043
    def test_more_ranges(self):
 
1044
        t = self.get_transport()
 
1045
        l = list(t.readv('a', ((0, 4), (1024, 4), (4096, 4), (8192, 4))))
 
1046
        self.assertEqual(l[0], (0, '0000'))
 
1047
        self.assertEqual(l[1], (1024, '0001'))
 
1048
        self.assertEqual(l[2], (4096, '0004'))
 
1049
        self.assertEqual(l[3], (8192, '0008'))
 
1050
        # The server will refuse to serve the first request (too much ranges),
 
1051
        # a second request will succeeds.
 
1052
        self.assertEqual(2, self.get_readonly_server().GET_request_nb)
 
1053
 
 
1054
 
 
1055
class TestHttpProxyWhiteBox(tests.TestCase):
 
1056
    """Whitebox test proxy http authorization.
 
1057
 
 
1058
    Only the urllib implementation is tested here.
 
1059
    """
 
1060
 
 
1061
    def setUp(self):
 
1062
        tests.TestCase.setUp(self)
 
1063
        self._old_env = {}
 
1064
 
 
1065
    def tearDown(self):
 
1066
        self._restore_env()
 
1067
 
 
1068
    def _install_env(self, env):
 
1069
        for name, value in env.iteritems():
 
1070
            self._old_env[name] = osutils.set_or_unset_env(name, value)
 
1071
 
 
1072
    def _restore_env(self):
 
1073
        for name, value in self._old_env.iteritems():
 
1074
            osutils.set_or_unset_env(name, value)
 
1075
 
 
1076
    def _proxied_request(self):
 
1077
        handler = _urllib2_wrappers.ProxyHandler()
 
1078
        request = _urllib2_wrappers.Request('GET','http://baz/buzzle')
 
1079
        handler.set_proxy(request, 'http')
 
1080
        return request
 
1081
 
 
1082
    def test_empty_user(self):
 
1083
        self._install_env({'http_proxy': 'http://bar.com'})
 
1084
        request = self._proxied_request()
 
1085
        self.assertFalse(request.headers.has_key('Proxy-authorization'))
 
1086
 
 
1087
    def test_invalid_proxy(self):
 
1088
        """A proxy env variable without scheme"""
 
1089
        self._install_env({'http_proxy': 'host:1234'})
 
1090
        self.assertRaises(errors.InvalidURL, self._proxied_request)
 
1091
 
 
1092
 
 
1093
class TestProxyHttpServer(http_utils.TestCaseWithTwoWebservers):
 
1094
    """Tests proxy server.
 
1095
 
 
1096
    Be aware that we do not setup a real proxy here. Instead, we
 
1097
    check that the *connection* goes through the proxy by serving
 
1098
    different content (the faked proxy server append '-proxied'
 
1099
    to the file names).
 
1100
    """
 
1101
 
 
1102
    # FIXME: We don't have an https server available, so we don't
 
1103
    # test https connections.
 
1104
 
 
1105
    def setUp(self):
 
1106
        super(TestProxyHttpServer, self).setUp()
 
1107
        self.build_tree_contents([('foo', 'contents of foo\n'),
 
1108
                                  ('foo-proxied', 'proxied contents of foo\n')])
 
1109
        # Let's setup some attributes for tests
 
1110
        self.server = self.get_readonly_server()
 
1111
        self.proxy_address = '%s:%d' % (self.server.host, self.server.port)
 
1112
        if self._testing_pycurl():
 
1113
            # Oh my ! pycurl does not check for the port as part of
 
1114
            # no_proxy :-( So we just test the host part
 
1115
            self.no_proxy_host = 'localhost'
 
1116
        else:
 
1117
            self.no_proxy_host = self.proxy_address
 
1118
        # The secondary server is the proxy
 
1119
        self.proxy = self.get_secondary_server()
 
1120
        self.proxy_url = self.proxy.get_url()
 
1121
        self._old_env = {}
 
1122
 
 
1123
    def _testing_pycurl(self):
 
1124
        return pycurl_present and self._transport == PyCurlTransport
 
1125
 
 
1126
    def create_transport_secondary_server(self):
 
1127
        """Creates an http server that will serve files with
 
1128
        '-proxied' appended to their names.
 
1129
        """
 
1130
        return http_utils.ProxyServer(protocol_version=self._protocol_version)
 
1131
 
 
1132
    def _install_env(self, env):
 
1133
        for name, value in env.iteritems():
 
1134
            self._old_env[name] = osutils.set_or_unset_env(name, value)
 
1135
 
 
1136
    def _restore_env(self):
 
1137
        for name, value in self._old_env.iteritems():
 
1138
            osutils.set_or_unset_env(name, value)
 
1139
 
 
1140
    def proxied_in_env(self, env):
 
1141
        self._install_env(env)
 
1142
        url = self.server.get_url()
 
1143
        t = self._transport(url)
 
1144
        try:
 
1145
            self.assertEqual(t.get('foo').read(), 'proxied contents of foo\n')
 
1146
        finally:
 
1147
            self._restore_env()
 
1148
 
 
1149
    def not_proxied_in_env(self, env):
 
1150
        self._install_env(env)
 
1151
        url = self.server.get_url()
 
1152
        t = self._transport(url)
 
1153
        try:
 
1154
            self.assertEqual(t.get('foo').read(), 'contents of foo\n')
 
1155
        finally:
 
1156
            self._restore_env()
 
1157
 
 
1158
    def test_http_proxy(self):
 
1159
        self.proxied_in_env({'http_proxy': self.proxy_url})
 
1160
 
 
1161
    def test_HTTP_PROXY(self):
 
1162
        if self._testing_pycurl():
 
1163
            # pycurl does not check HTTP_PROXY for security reasons
 
1164
            # (for use in a CGI context that we do not care
 
1165
            # about. Should we ?)
 
1166
            raise tests.TestNotApplicable(
 
1167
                'pycurl does not check HTTP_PROXY for security reasons')
 
1168
        self.proxied_in_env({'HTTP_PROXY': self.proxy_url})
 
1169
 
 
1170
    def test_all_proxy(self):
 
1171
        self.proxied_in_env({'all_proxy': self.proxy_url})
 
1172
 
 
1173
    def test_ALL_PROXY(self):
 
1174
        self.proxied_in_env({'ALL_PROXY': self.proxy_url})
 
1175
 
 
1176
    def test_http_proxy_with_no_proxy(self):
 
1177
        self.not_proxied_in_env({'http_proxy': self.proxy_url,
 
1178
                                 'no_proxy': self.no_proxy_host})
 
1179
 
 
1180
    def test_HTTP_PROXY_with_NO_PROXY(self):
 
1181
        if self._testing_pycurl():
 
1182
            raise tests.TestNotApplicable(
 
1183
                'pycurl does not check HTTP_PROXY for security reasons')
 
1184
        self.not_proxied_in_env({'HTTP_PROXY': self.proxy_url,
 
1185
                                 'NO_PROXY': self.no_proxy_host})
 
1186
 
 
1187
    def test_all_proxy_with_no_proxy(self):
 
1188
        self.not_proxied_in_env({'all_proxy': self.proxy_url,
 
1189
                                 'no_proxy': self.no_proxy_host})
 
1190
 
 
1191
    def test_ALL_PROXY_with_NO_PROXY(self):
 
1192
        self.not_proxied_in_env({'ALL_PROXY': self.proxy_url,
 
1193
                                 'NO_PROXY': self.no_proxy_host})
 
1194
 
 
1195
    def test_http_proxy_without_scheme(self):
 
1196
        if self._testing_pycurl():
 
1197
            # pycurl *ignores* invalid proxy env variables. If that ever change
 
1198
            # in the future, this test will fail indicating that pycurl do not
 
1199
            # ignore anymore such variables.
 
1200
            self.not_proxied_in_env({'http_proxy': self.proxy_address})
 
1201
        else:
 
1202
            self.assertRaises(errors.InvalidURL,
 
1203
                              self.proxied_in_env,
 
1204
                              {'http_proxy': self.proxy_address})
 
1205
 
 
1206
 
 
1207
class TestRanges(http_utils.TestCaseWithWebserver):
 
1208
    """Test the Range header in GET methods."""
 
1209
 
 
1210
    def setUp(self):
 
1211
        http_utils.TestCaseWithWebserver.setUp(self)
 
1212
        self.build_tree_contents([('a', '0123456789')],)
 
1213
        server = self.get_readonly_server()
 
1214
        self.transport = self._transport(server.get_url())
 
1215
 
 
1216
    def create_transport_readonly_server(self):
 
1217
        return http_server.HttpServer(protocol_version=self._protocol_version)
 
1218
 
 
1219
    def _file_contents(self, relpath, ranges):
 
1220
        offsets = [ (start, end - start + 1) for start, end in ranges]
 
1221
        coalesce = self.transport._coalesce_offsets
 
1222
        coalesced = list(coalesce(offsets, limit=0, fudge_factor=0))
 
1223
        code, data = self.transport._get(relpath, coalesced)
 
1224
        self.assertTrue(code in (200, 206),'_get returns: %d' % code)
 
1225
        for start, end in ranges:
 
1226
            data.seek(start)
 
1227
            yield data.read(end - start + 1)
 
1228
 
 
1229
    def _file_tail(self, relpath, tail_amount):
 
1230
        code, data = self.transport._get(relpath, [], tail_amount)
 
1231
        self.assertTrue(code in (200, 206),'_get returns: %d' % code)
 
1232
        data.seek(-tail_amount, 2)
 
1233
        return data.read(tail_amount)
 
1234
 
 
1235
    def test_range_header(self):
 
1236
        # Valid ranges
 
1237
        map(self.assertEqual,['0', '234'],
 
1238
            list(self._file_contents('a', [(0,0), (2,4)])),)
 
1239
 
 
1240
    def test_range_header_tail(self):
 
1241
        self.assertEqual('789', self._file_tail('a', 3))
 
1242
 
 
1243
    def test_syntactically_invalid_range_header(self):
 
1244
        self.assertListRaises(errors.InvalidHttpRange,
 
1245
                          self._file_contents, 'a', [(4, 3)])
 
1246
 
 
1247
    def test_semantically_invalid_range_header(self):
 
1248
        self.assertListRaises(errors.InvalidHttpRange,
 
1249
                          self._file_contents, 'a', [(42, 128)])
 
1250
 
 
1251
 
 
1252
class TestHTTPRedirections(http_utils.TestCaseWithRedirectedWebserver):
 
1253
    """Test redirection between http servers."""
 
1254
 
 
1255
    def create_transport_secondary_server(self):
 
1256
        """Create the secondary server redirecting to the primary server"""
 
1257
        new = self.get_readonly_server()
 
1258
 
 
1259
        redirecting = http_utils.HTTPServerRedirecting(
 
1260
            protocol_version=self._protocol_version)
 
1261
        redirecting.redirect_to(new.host, new.port)
 
1262
        return redirecting
 
1263
 
 
1264
    def setUp(self):
 
1265
        super(TestHTTPRedirections, self).setUp()
 
1266
        self.build_tree_contents([('a', '0123456789'),
 
1267
                                  ('bundle',
 
1268
                                  '# Bazaar revision bundle v0.9\n#\n')
 
1269
                                  ],)
 
1270
 
 
1271
        self.old_transport = self._transport(self.old_server.get_url())
 
1272
 
 
1273
    def test_redirected(self):
 
1274
        self.assertRaises(errors.RedirectRequested, self.old_transport.get, 'a')
 
1275
        t = self._transport(self.new_server.get_url())
 
1276
        self.assertEqual('0123456789', t.get('a').read())
 
1277
 
 
1278
    def test_read_redirected_bundle_from_url(self):
 
1279
        from bzrlib.bundle import read_bundle_from_url
 
1280
        url = self.old_transport.abspath('bundle')
 
1281
        bundle = read_bundle_from_url(url)
 
1282
        # If read_bundle_from_url was successful we get an empty bundle
 
1283
        self.assertEqual([], bundle.revisions)
 
1284
 
 
1285
 
 
1286
class RedirectedRequest(_urllib2_wrappers.Request):
 
1287
    """Request following redirections. """
 
1288
 
 
1289
    init_orig = _urllib2_wrappers.Request.__init__
 
1290
 
 
1291
    def __init__(self, method, url, *args, **kwargs):
 
1292
        """Constructor.
 
1293
 
 
1294
        """
 
1295
        # Since the tests using this class will replace
 
1296
        # _urllib2_wrappers.Request, we can't just call the base class __init__
 
1297
        # or we'll loop.
 
1298
        RedirectedRequest.init_orig(self, method, url, args, kwargs)
 
1299
        self.follow_redirections = True
 
1300
 
 
1301
 
 
1302
class TestHTTPSilentRedirections(http_utils.TestCaseWithRedirectedWebserver):
 
1303
    """Test redirections.
 
1304
 
 
1305
    http implementations do not redirect silently anymore (they
 
1306
    do not redirect at all in fact). The mechanism is still in
 
1307
    place at the _urllib2_wrappers.Request level and these tests
 
1308
    exercise it.
 
1309
 
 
1310
    For the pycurl implementation
 
1311
    the redirection have been deleted as we may deprecate pycurl
 
1312
    and I have no place to keep a working implementation.
 
1313
    -- vila 20070212
 
1314
    """
 
1315
 
 
1316
    def setUp(self):
 
1317
        if pycurl_present and self._transport == PyCurlTransport:
 
1318
            raise tests.TestNotApplicable(
 
1319
                "pycurl doesn't redirect silently annymore")
 
1320
        super(TestHTTPSilentRedirections, self).setUp()
 
1321
        self.setup_redirected_request()
 
1322
        self.addCleanup(self.cleanup_redirected_request)
 
1323
        self.build_tree_contents([('a','a'),
 
1324
                                  ('1/',),
 
1325
                                  ('1/a', 'redirected once'),
 
1326
                                  ('2/',),
 
1327
                                  ('2/a', 'redirected twice'),
 
1328
                                  ('3/',),
 
1329
                                  ('3/a', 'redirected thrice'),
 
1330
                                  ('4/',),
 
1331
                                  ('4/a', 'redirected 4 times'),
 
1332
                                  ('5/',),
 
1333
                                  ('5/a', 'redirected 5 times'),
 
1334
                                  ],)
 
1335
 
 
1336
        self.old_transport = self._transport(self.old_server.get_url())
 
1337
 
 
1338
    def setup_redirected_request(self):
 
1339
        self.original_class = _urllib2_wrappers.Request
 
1340
        _urllib2_wrappers.Request = RedirectedRequest
 
1341
 
 
1342
    def cleanup_redirected_request(self):
 
1343
        _urllib2_wrappers.Request = self.original_class
 
1344
 
 
1345
    def create_transport_secondary_server(self):
 
1346
        """Create the secondary server, redirections are defined in the tests"""
 
1347
        return http_utils.HTTPServerRedirecting(
 
1348
            protocol_version=self._protocol_version)
 
1349
 
 
1350
    def test_one_redirection(self):
 
1351
        t = self.old_transport
 
1352
 
 
1353
        req = RedirectedRequest('GET', t.abspath('a'))
 
1354
        req.follow_redirections = True
 
1355
        new_prefix = 'http://%s:%s' % (self.new_server.host,
 
1356
                                       self.new_server.port)
 
1357
        self.old_server.redirections = \
 
1358
            [('(.*)', r'%s/1\1' % (new_prefix), 301),]
 
1359
        self.assertEquals('redirected once',t._perform(req).read())
 
1360
 
 
1361
    def test_five_redirections(self):
 
1362
        t = self.old_transport
 
1363
 
 
1364
        req = RedirectedRequest('GET', t.abspath('a'))
 
1365
        req.follow_redirections = True
 
1366
        old_prefix = 'http://%s:%s' % (self.old_server.host,
 
1367
                                       self.old_server.port)
 
1368
        new_prefix = 'http://%s:%s' % (self.new_server.host,
 
1369
                                       self.new_server.port)
 
1370
        self.old_server.redirections = [
 
1371
            ('/1(.*)', r'%s/2\1' % (old_prefix), 302),
 
1372
            ('/2(.*)', r'%s/3\1' % (old_prefix), 303),
 
1373
            ('/3(.*)', r'%s/4\1' % (old_prefix), 307),
 
1374
            ('/4(.*)', r'%s/5\1' % (new_prefix), 301),
 
1375
            ('(/[^/]+)', r'%s/1\1' % (old_prefix), 301),
 
1376
            ]
 
1377
        self.assertEquals('redirected 5 times',t._perform(req).read())
 
1378
 
 
1379
 
 
1380
class TestDoCatchRedirections(http_utils.TestCaseWithRedirectedWebserver):
 
1381
    """Test transport.do_catching_redirections."""
 
1382
 
 
1383
    def setUp(self):
 
1384
        super(TestDoCatchRedirections, self).setUp()
 
1385
        self.build_tree_contents([('a', '0123456789'),],)
 
1386
 
 
1387
        self.old_transport = self._transport(self.old_server.get_url())
 
1388
 
 
1389
    def get_a(self, transport):
 
1390
        return transport.get('a')
 
1391
 
 
1392
    def test_no_redirection(self):
 
1393
        t = self._transport(self.new_server.get_url())
 
1394
 
 
1395
        # We use None for redirected so that we fail if redirected
 
1396
        self.assertEquals('0123456789',
 
1397
                          transport.do_catching_redirections(
 
1398
                self.get_a, t, None).read())
 
1399
 
 
1400
    def test_one_redirection(self):
 
1401
        self.redirections = 0
 
1402
 
 
1403
        def redirected(transport, exception, redirection_notice):
 
1404
            self.redirections += 1
 
1405
            dir, file = urlutils.split(exception.target)
 
1406
            return self._transport(dir)
 
1407
 
 
1408
        self.assertEquals('0123456789',
 
1409
                          transport.do_catching_redirections(
 
1410
                self.get_a, self.old_transport, redirected).read())
 
1411
        self.assertEquals(1, self.redirections)
 
1412
 
 
1413
    def test_redirection_loop(self):
 
1414
 
 
1415
        def redirected(transport, exception, redirection_notice):
 
1416
            # By using the redirected url as a base dir for the
 
1417
            # *old* transport, we create a loop: a => a/a =>
 
1418
            # a/a/a
 
1419
            return self.old_transport.clone(exception.target)
 
1420
 
 
1421
        self.assertRaises(errors.TooManyRedirections,
 
1422
                          transport.do_catching_redirections,
 
1423
                          self.get_a, self.old_transport, redirected)
 
1424
 
 
1425
 
 
1426
class TestAuth(http_utils.TestCaseWithWebserver):
 
1427
    """Test authentication scheme"""
 
1428
 
 
1429
    _auth_header = 'Authorization'
 
1430
    _password_prompt_prefix = ''
 
1431
 
 
1432
    def setUp(self):
 
1433
        super(TestAuth, self).setUp()
 
1434
        self.server = self.get_readonly_server()
 
1435
        self.build_tree_contents([('a', 'contents of a\n'),
 
1436
                                  ('b', 'contents of b\n'),])
 
1437
 
 
1438
    def create_transport_readonly_server(self):
 
1439
        if self._auth_scheme == 'basic':
 
1440
            server = http_utils.HTTPBasicAuthServer(
 
1441
                protocol_version=self._protocol_version)
 
1442
        else:
 
1443
            if self._auth_scheme != 'digest':
 
1444
                raise AssertionError('Unknown auth scheme: %r'
 
1445
                                     % self._auth_scheme)
 
1446
            server = http_utils.HTTPDigestAuthServer(
 
1447
                protocol_version=self._protocol_version)
 
1448
        return server
 
1449
 
 
1450
    def _testing_pycurl(self):
 
1451
        return pycurl_present and self._transport == PyCurlTransport
 
1452
 
 
1453
    def get_user_url(self, user=None, password=None):
 
1454
        """Build an url embedding user and password"""
 
1455
        url = '%s://' % self.server._url_protocol
 
1456
        if user is not None:
 
1457
            url += user
 
1458
            if password is not None:
 
1459
                url += ':' + password
 
1460
            url += '@'
 
1461
        url += '%s:%s/' % (self.server.host, self.server.port)
 
1462
        return url
 
1463
 
 
1464
    def get_user_transport(self, user=None, password=None):
 
1465
        return self._transport(self.get_user_url(user, password))
 
1466
 
 
1467
    def test_no_user(self):
 
1468
        self.server.add_user('joe', 'foo')
 
1469
        t = self.get_user_transport()
 
1470
        self.assertRaises(errors.InvalidHttpResponse, t.get, 'a')
 
1471
        # Only one 'Authentication Required' error should occur
 
1472
        self.assertEqual(1, self.server.auth_required_errors)
 
1473
 
 
1474
    def test_empty_pass(self):
 
1475
        self.server.add_user('joe', '')
 
1476
        t = self.get_user_transport('joe', '')
 
1477
        self.assertEqual('contents of a\n', t.get('a').read())
 
1478
        # Only one 'Authentication Required' error should occur
 
1479
        self.assertEqual(1, self.server.auth_required_errors)
 
1480
 
 
1481
    def test_user_pass(self):
 
1482
        self.server.add_user('joe', 'foo')
 
1483
        t = self.get_user_transport('joe', 'foo')
 
1484
        self.assertEqual('contents of a\n', t.get('a').read())
 
1485
        # Only one 'Authentication Required' error should occur
 
1486
        self.assertEqual(1, self.server.auth_required_errors)
 
1487
 
 
1488
    def test_unknown_user(self):
 
1489
        self.server.add_user('joe', 'foo')
 
1490
        t = self.get_user_transport('bill', 'foo')
 
1491
        self.assertRaises(errors.InvalidHttpResponse, t.get, 'a')
 
1492
        # Two 'Authentication Required' errors should occur (the
 
1493
        # initial 'who are you' and 'I don't know you, who are
 
1494
        # you').
 
1495
        self.assertEqual(2, self.server.auth_required_errors)
 
1496
 
 
1497
    def test_wrong_pass(self):
 
1498
        self.server.add_user('joe', 'foo')
 
1499
        t = self.get_user_transport('joe', 'bar')
 
1500
        self.assertRaises(errors.InvalidHttpResponse, t.get, 'a')
 
1501
        # Two 'Authentication Required' errors should occur (the
 
1502
        # initial 'who are you' and 'this is not you, who are you')
 
1503
        self.assertEqual(2, self.server.auth_required_errors)
 
1504
 
 
1505
    def test_prompt_for_password(self):
 
1506
        if self._testing_pycurl():
 
1507
            raise tests.TestNotApplicable(
 
1508
                'pycurl cannot prompt, it handles auth by embedding'
 
1509
                ' user:pass in urls only')
 
1510
 
 
1511
        self.server.add_user('joe', 'foo')
 
1512
        t = self.get_user_transport('joe', None)
 
1513
        stdout = tests.StringIOWrapper()
 
1514
        ui.ui_factory = tests.TestUIFactory(stdin='foo\n', stdout=stdout)
 
1515
        self.assertEqual('contents of a\n',t.get('a').read())
 
1516
        # stdin should be empty
 
1517
        self.assertEqual('', ui.ui_factory.stdin.readline())
 
1518
        self._check_password_prompt(t._unqualified_scheme, 'joe',
 
1519
                                    stdout.getvalue())
 
1520
        # And we shouldn't prompt again for a different request
 
1521
        # against the same transport.
 
1522
        self.assertEqual('contents of b\n',t.get('b').read())
 
1523
        t2 = t.clone()
 
1524
        # And neither against a clone
 
1525
        self.assertEqual('contents of b\n',t2.get('b').read())
 
1526
        # Only one 'Authentication Required' error should occur
 
1527
        self.assertEqual(1, self.server.auth_required_errors)
 
1528
 
 
1529
    def _check_password_prompt(self, scheme, user, actual_prompt):
 
1530
        expected_prompt = (self._password_prompt_prefix
 
1531
                           + ("%s %s@%s:%d, Realm: '%s' password: "
 
1532
                              % (scheme.upper(),
 
1533
                                 user, self.server.host, self.server.port,
 
1534
                                 self.server.auth_realm)))
 
1535
        self.assertEquals(expected_prompt, actual_prompt)
 
1536
 
 
1537
    def test_no_prompt_for_password_when_using_auth_config(self):
 
1538
        if self._testing_pycurl():
 
1539
            raise tests.TestNotApplicable(
 
1540
                'pycurl does not support authentication.conf'
 
1541
                ' since it cannot prompt')
 
1542
 
 
1543
        user =' joe'
 
1544
        password = 'foo'
 
1545
        stdin_content = 'bar\n'  # Not the right password
 
1546
        self.server.add_user(user, password)
 
1547
        t = self.get_user_transport(user, None)
 
1548
        ui.ui_factory = tests.TestUIFactory(stdin=stdin_content,
 
1549
                                            stdout=tests.StringIOWrapper())
 
1550
        # Create a minimal config file with the right password
 
1551
        conf = config.AuthenticationConfig()
 
1552
        conf._get_config().update(
 
1553
            {'httptest': {'scheme': 'http', 'port': self.server.port,
 
1554
                          'user': user, 'password': password}})
 
1555
        conf._save()
 
1556
        # Issue a request to the server to connect
 
1557
        self.assertEqual('contents of a\n',t.get('a').read())
 
1558
        # stdin should have  been left untouched
 
1559
        self.assertEqual(stdin_content, ui.ui_factory.stdin.readline())
 
1560
        # Only one 'Authentication Required' error should occur
 
1561
        self.assertEqual(1, self.server.auth_required_errors)
 
1562
 
 
1563
    def test_changing_nonce(self):
 
1564
        if self._auth_scheme != 'digest':
 
1565
            raise tests.TestNotApplicable('HTTP auth digest only test')
 
1566
        if self._testing_pycurl():
 
1567
            raise tests.KnownFailure(
 
1568
                'pycurl does not handle a nonce change')
 
1569
        self.server.add_user('joe', 'foo')
 
1570
        t = self.get_user_transport('joe', 'foo')
 
1571
        self.assertEqual('contents of a\n', t.get('a').read())
 
1572
        self.assertEqual('contents of b\n', t.get('b').read())
 
1573
        # Only one 'Authentication Required' error should have
 
1574
        # occured so far
 
1575
        self.assertEqual(1, self.server.auth_required_errors)
 
1576
        # The server invalidates the current nonce
 
1577
        self.server.auth_nonce = self.server.auth_nonce + '. No, now!'
 
1578
        self.assertEqual('contents of a\n', t.get('a').read())
 
1579
        # Two 'Authentication Required' errors should occur (the
 
1580
        # initial 'who are you' and a second 'who are you' with the new nonce)
 
1581
        self.assertEqual(2, self.server.auth_required_errors)
 
1582
 
 
1583
 
 
1584
 
 
1585
class TestProxyAuth(TestAuth):
 
1586
    """Test proxy authentication schemes."""
 
1587
 
 
1588
    _auth_header = 'Proxy-authorization'
 
1589
    _password_prompt_prefix='Proxy '
 
1590
 
 
1591
    def setUp(self):
 
1592
        super(TestProxyAuth, self).setUp()
 
1593
        self._old_env = {}
 
1594
        self.addCleanup(self._restore_env)
 
1595
        # Override the contents to avoid false positives
 
1596
        self.build_tree_contents([('a', 'not proxied contents of a\n'),
 
1597
                                  ('b', 'not proxied contents of b\n'),
 
1598
                                  ('a-proxied', 'contents of a\n'),
 
1599
                                  ('b-proxied', 'contents of b\n'),
 
1600
                                  ])
 
1601
 
 
1602
    def create_transport_readonly_server(self):
 
1603
        if self._auth_scheme == 'basic':
 
1604
            server = http_utils.ProxyBasicAuthServer(
 
1605
                protocol_version=self._protocol_version)
 
1606
        else:
 
1607
            if self._auth_scheme != 'digest':
 
1608
                raise AssertionError('Unknown auth scheme: %r'
 
1609
                                     % self._auth_scheme)
 
1610
            server = http_utils.ProxyDigestAuthServer(
 
1611
                protocol_version=self._protocol_version)
 
1612
        return server
 
1613
 
 
1614
    def get_user_transport(self, user=None, password=None):
 
1615
        self._install_env({'all_proxy': self.get_user_url(user, password)})
 
1616
        return self._transport(self.server.get_url())
 
1617
 
 
1618
    def _install_env(self, env):
 
1619
        for name, value in env.iteritems():
 
1620
            self._old_env[name] = osutils.set_or_unset_env(name, value)
 
1621
 
 
1622
    def _restore_env(self):
 
1623
        for name, value in self._old_env.iteritems():
 
1624
            osutils.set_or_unset_env(name, value)
 
1625
 
 
1626
    def test_empty_pass(self):
 
1627
        if self._testing_pycurl():
 
1628
            import pycurl
 
1629
            if pycurl.version_info()[1] < '7.16.0':
 
1630
                raise tests.KnownFailure(
 
1631
                    'pycurl < 7.16.0 does not handle empty proxy passwords')
 
1632
        super(TestProxyAuth, self).test_empty_pass()
 
1633
 
 
1634
 
 
1635
class SampleSocket(object):
 
1636
    """A socket-like object for use in testing the HTTP request handler."""
 
1637
 
 
1638
    def __init__(self, socket_read_content):
 
1639
        """Constructs a sample socket.
 
1640
 
 
1641
        :param socket_read_content: a byte sequence
 
1642
        """
 
1643
        # Use plain python StringIO so we can monkey-patch the close method to
 
1644
        # not discard the contents.
 
1645
        from StringIO import StringIO
 
1646
        self.readfile = StringIO(socket_read_content)
 
1647
        self.writefile = StringIO()
 
1648
        self.writefile.close = lambda: None
 
1649
 
 
1650
    def makefile(self, mode='r', bufsize=None):
 
1651
        if 'r' in mode:
 
1652
            return self.readfile
 
1653
        else:
 
1654
            return self.writefile
 
1655
 
 
1656
 
 
1657
class SmartHTTPTunnellingTest(tests.TestCaseWithTransport):
 
1658
 
 
1659
    def setUp(self):
 
1660
        super(SmartHTTPTunnellingTest, self).setUp()
 
1661
        # We use the VFS layer as part of HTTP tunnelling tests.
 
1662
        self._captureVar('BZR_NO_SMART_VFS', None)
 
1663
        self.transport_readonly_server = http_utils.HTTPServerWithSmarts
 
1664
 
 
1665
    def create_transport_readonly_server(self):
 
1666
        return http_utils.HTTPServerWithSmarts(
 
1667
            protocol_version=self._protocol_version)
 
1668
 
 
1669
    def test_bulk_data(self):
 
1670
        # We should be able to send and receive bulk data in a single message.
 
1671
        # The 'readv' command in the smart protocol both sends and receives
 
1672
        # bulk data, so we use that.
 
1673
        self.build_tree(['data-file'])
 
1674
        http_server = self.get_readonly_server()
 
1675
        http_transport = self._transport(http_server.get_url())
 
1676
        medium = http_transport.get_smart_medium()
 
1677
        # Since we provide the medium, the url below will be mostly ignored
 
1678
        # during the test, as long as the path is '/'.
 
1679
        remote_transport = remote.RemoteTransport('bzr://fake_host/',
 
1680
                                                  medium=medium)
 
1681
        self.assertEqual(
 
1682
            [(0, "c")], list(remote_transport.readv("data-file", [(0,1)])))
 
1683
 
 
1684
    def test_http_send_smart_request(self):
 
1685
 
 
1686
        post_body = 'hello\n'
 
1687
        expected_reply_body = 'ok\x012\n'
 
1688
 
 
1689
        http_server = self.get_readonly_server()
 
1690
        http_transport = self._transport(http_server.get_url())
 
1691
        medium = http_transport.get_smart_medium()
 
1692
        response = medium.send_http_smart_request(post_body)
 
1693
        reply_body = response.read()
 
1694
        self.assertEqual(expected_reply_body, reply_body)
 
1695
 
 
1696
    def test_smart_http_server_post_request_handler(self):
 
1697
        httpd = self.get_readonly_server()._get_httpd()
 
1698
 
 
1699
        socket = SampleSocket(
 
1700
            'POST /.bzr/smart %s \r\n' % self._protocol_version
 
1701
            # HTTP/1.1 posts must have a Content-Length (but it doesn't hurt
 
1702
            # for 1.0)
 
1703
            + 'Content-Length: 6\r\n'
 
1704
            '\r\n'
 
1705
            'hello\n')
 
1706
        # Beware: the ('localhost', 80) below is the
 
1707
        # client_address parameter, but we don't have one because
 
1708
        # we have defined a socket which is not bound to an
 
1709
        # address. The test framework never uses this client
 
1710
        # address, so far...
 
1711
        request_handler = http_utils.SmartRequestHandler(socket,
 
1712
                                                         ('localhost', 80),
 
1713
                                                         httpd)
 
1714
        response = socket.writefile.getvalue()
 
1715
        self.assertStartsWith(response, '%s 200 ' % self._protocol_version)
 
1716
        # This includes the end of the HTTP headers, and all the body.
 
1717
        expected_end_of_response = '\r\n\r\nok\x012\n'
 
1718
        self.assertEndsWith(response, expected_end_of_response)
 
1719
 
 
1720