/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to breezy/tests/test_server.py

[merge] robertc's integration, updated tests to check for retcode=3

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2010, 2011 Canonical Ltd
2
 
#
3
 
# This program is free software; you can redistribute it and/or modify
4
 
# it under the terms of the GNU General Public License as published by
5
 
# the Free Software Foundation; either version 2 of the License, or
6
 
# (at your option) any later version.
7
 
#
8
 
# This program is distributed in the hope that it will be useful,
9
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
10
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11
 
# GNU General Public License for more details.
12
 
#
13
 
# You should have received a copy of the GNU General Public License
14
 
# along with this program; if not, write to the Free Software
15
 
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16
 
 
17
 
import errno
18
 
import socket
19
 
try:
20
 
    import socketserver
21
 
except ImportError:
22
 
    import SocketServer as socketserver
23
 
import sys
24
 
import threading
25
 
 
26
 
 
27
 
from breezy import (
28
 
    cethread,
29
 
    errors,
30
 
    osutils,
31
 
    transport,
32
 
    urlutils,
33
 
    )
34
 
from breezy.transport import (
35
 
    chroot,
36
 
    pathfilter,
37
 
    )
38
 
from breezy.bzr.smart import (
39
 
    medium,
40
 
    server,
41
 
    )
42
 
 
43
 
 
44
 
def debug_threads():
45
 
    # FIXME: There is a dependency loop between breezy.tests and
46
 
    # breezy.tests.test_server that needs to be fixed. In the mean time
47
 
    # defining this function is enough for our needs. -- vila 20100611
48
 
    from breezy import tests
49
 
    return 'threads' in tests.selftest_debug_flags
50
 
 
51
 
 
52
 
class TestServer(transport.Server):
53
 
    """A Transport Server dedicated to tests.
54
 
 
55
 
    The TestServer interface provides a server for a given transport. We use
56
 
    these servers as loopback testing tools. For any given transport the
57
 
    Servers it provides must either allow writing, or serve the contents
58
 
    of osutils.getcwd() at the time start_server is called.
59
 
 
60
 
    Note that these are real servers - they must implement all the things
61
 
    that we want bzr transports to take advantage of.
62
 
    """
63
 
 
64
 
    def get_url(self):
65
 
        """Return a url for this server.
66
 
 
67
 
        If the transport does not represent a disk directory (i.e. it is
68
 
        a database like svn, or a memory only transport, it should return
69
 
        a connection to a newly established resource for this Server.
70
 
        Otherwise it should return a url that will provide access to the path
71
 
        that was osutils.getcwd() when start_server() was called.
72
 
 
73
 
        Subsequent calls will return the same resource.
74
 
        """
75
 
        raise NotImplementedError
76
 
 
77
 
    def get_bogus_url(self):
78
 
        """Return a url for this protocol, that will fail to connect.
79
 
 
80
 
        This may raise NotImplementedError to indicate that this server cannot
81
 
        provide bogus urls.
82
 
        """
83
 
        raise NotImplementedError
84
 
 
85
 
 
86
 
class LocalURLServer(TestServer):
87
 
    """A pretend server for local transports, using file:// urls.
88
 
 
89
 
    Of course no actual server is required to access the local filesystem, so
90
 
    this just exists to tell the test code how to get to it.
91
 
    """
92
 
 
93
 
    def start_server(self):
94
 
        pass
95
 
 
96
 
    def get_url(self):
97
 
        """See Transport.Server.get_url."""
98
 
        return urlutils.local_path_to_url('')
99
 
 
100
 
 
101
 
class DecoratorServer(TestServer):
102
 
    """Server for the TransportDecorator for testing with.
103
 
 
104
 
    To use this when subclassing TransportDecorator, override override the
105
 
    get_decorator_class method.
106
 
    """
107
 
 
108
 
    def start_server(self, server=None):
109
 
        """See breezy.transport.Server.start_server.
110
 
 
111
 
        :server: decorate the urls given by server. If not provided a
112
 
        LocalServer is created.
113
 
        """
114
 
        if server is not None:
115
 
            self._made_server = False
116
 
            self._server = server
117
 
        else:
118
 
            self._made_server = True
119
 
            self._server = LocalURLServer()
120
 
            self._server.start_server()
121
 
 
122
 
    def stop_server(self):
123
 
        if self._made_server:
124
 
            self._server.stop_server()
125
 
 
126
 
    def get_decorator_class(self):
127
 
        """Return the class of the decorators we should be constructing."""
128
 
        raise NotImplementedError(self.get_decorator_class)
129
 
 
130
 
    def get_url_prefix(self):
131
 
        """What URL prefix does this decorator produce?"""
132
 
        return self.get_decorator_class()._get_url_prefix()
133
 
 
134
 
    def get_bogus_url(self):
135
 
        """See breezy.transport.Server.get_bogus_url."""
136
 
        return self.get_url_prefix() + self._server.get_bogus_url()
137
 
 
138
 
    def get_url(self):
139
 
        """See breezy.transport.Server.get_url."""
140
 
        return self.get_url_prefix() + self._server.get_url()
141
 
 
142
 
 
143
 
class BrokenRenameServer(DecoratorServer):
144
 
    """Server for the BrokenRenameTransportDecorator for testing with."""
145
 
 
146
 
    def get_decorator_class(self):
147
 
        from breezy.transport import brokenrename
148
 
        return brokenrename.BrokenRenameTransportDecorator
149
 
 
150
 
 
151
 
class FakeNFSServer(DecoratorServer):
152
 
    """Server for the FakeNFSTransportDecorator for testing with."""
153
 
 
154
 
    def get_decorator_class(self):
155
 
        from breezy.transport import fakenfs
156
 
        return fakenfs.FakeNFSTransportDecorator
157
 
 
158
 
 
159
 
class FakeVFATServer(DecoratorServer):
160
 
    """A server that suggests connections through FakeVFATTransportDecorator
161
 
 
162
 
    For use in testing.
163
 
    """
164
 
 
165
 
    def get_decorator_class(self):
166
 
        from breezy.transport import fakevfat
167
 
        return fakevfat.FakeVFATTransportDecorator
168
 
 
169
 
 
170
 
class LogDecoratorServer(DecoratorServer):
171
 
    """Server for testing."""
172
 
 
173
 
    def get_decorator_class(self):
174
 
        from breezy.transport import log
175
 
        return log.TransportLogDecorator
176
 
 
177
 
 
178
 
class NoSmartTransportServer(DecoratorServer):
179
 
    """Server for the NoSmartTransportDecorator for testing with."""
180
 
 
181
 
    def get_decorator_class(self):
182
 
        from breezy.transport import nosmart
183
 
        return nosmart.NoSmartTransportDecorator
184
 
 
185
 
 
186
 
class ReadonlyServer(DecoratorServer):
187
 
    """Server for the ReadonlyTransportDecorator for testing with."""
188
 
 
189
 
    def get_decorator_class(self):
190
 
        from breezy.transport import readonly
191
 
        return readonly.ReadonlyTransportDecorator
192
 
 
193
 
 
194
 
class TraceServer(DecoratorServer):
195
 
    """Server for the TransportTraceDecorator for testing with."""
196
 
 
197
 
    def get_decorator_class(self):
198
 
        from breezy.transport import trace
199
 
        return trace.TransportTraceDecorator
200
 
 
201
 
 
202
 
class UnlistableServer(DecoratorServer):
203
 
    """Server for the UnlistableTransportDecorator for testing with."""
204
 
 
205
 
    def get_decorator_class(self):
206
 
        from breezy.transport import unlistable
207
 
        return unlistable.UnlistableTransportDecorator
208
 
 
209
 
 
210
 
class TestingPathFilteringServer(pathfilter.PathFilteringServer):
211
 
 
212
 
    def __init__(self):
213
 
        """TestingPathFilteringServer is not usable until start_server
214
 
        is called."""
215
 
 
216
 
    def start_server(self, backing_server=None):
217
 
        """Setup the Chroot on backing_server."""
218
 
        if backing_server is not None:
219
 
            self.backing_transport = transport.get_transport_from_url(
220
 
                backing_server.get_url())
221
 
        else:
222
 
            self.backing_transport = transport.get_transport_from_path('.')
223
 
        self.backing_transport.clone('added-by-filter').ensure_base()
224
 
        self.filter_func = lambda x: 'added-by-filter/' + x
225
 
        super(TestingPathFilteringServer, self).start_server()
226
 
 
227
 
    def get_bogus_url(self):
228
 
        raise NotImplementedError
229
 
 
230
 
 
231
 
class TestingChrootServer(chroot.ChrootServer):
232
 
 
233
 
    def __init__(self):
234
 
        """TestingChrootServer is not usable until start_server is called."""
235
 
        super(TestingChrootServer, self).__init__(None)
236
 
 
237
 
    def start_server(self, backing_server=None):
238
 
        """Setup the Chroot on backing_server."""
239
 
        if backing_server is not None:
240
 
            self.backing_transport = transport.get_transport_from_url(
241
 
                backing_server.get_url())
242
 
        else:
243
 
            self.backing_transport = transport.get_transport_from_path('.')
244
 
        super(TestingChrootServer, self).start_server()
245
 
 
246
 
    def get_bogus_url(self):
247
 
        raise NotImplementedError
248
 
 
249
 
 
250
 
class TestThread(cethread.CatchingExceptionThread):
251
 
 
252
 
    if not getattr(cethread.CatchingExceptionThread, 'is_alive', None):
253
 
        def is_alive(self):
254
 
            return self.isAlive()
255
 
 
256
 
    def join(self, timeout=5):
257
 
        """Overrides to use a default timeout.
258
 
 
259
 
        The default timeout is set to 5 and should expire only when a thread
260
 
        serving a client connection is hung.
261
 
        """
262
 
        super(TestThread, self).join(timeout)
263
 
        if timeout and self.is_alive():
264
 
            # The timeout expired without joining the thread, the thread is
265
 
            # therefore stucked and that's a failure as far as the test is
266
 
            # concerned. We used to hang here.
267
 
 
268
 
            # FIXME: we need to kill the thread, but as far as the test is
269
 
            # concerned, raising an assertion is too strong. On most of the
270
 
            # platforms, this doesn't occur, so just mentioning the problem is
271
 
            # enough for now -- vila 2010824
272
 
            sys.stderr.write('thread %s hung\n' % (self.name,))
273
 
            # raise AssertionError('thread %s hung' % (self.name,))
274
 
 
275
 
 
276
 
class TestingTCPServerMixin(object):
277
 
    """Mixin to support running socketserver.TCPServer in a thread.
278
 
 
279
 
    Tests are connecting from the main thread, the server has to be run in a
280
 
    separate thread.
281
 
    """
282
 
 
283
 
    def __init__(self):
284
 
        self.started = threading.Event()
285
 
        self.serving = None
286
 
        self.stopped = threading.Event()
287
 
        # We collect the resources used by the clients so we can release them
288
 
        # when shutting down
289
 
        self.clients = []
290
 
        self.ignored_exceptions = None
291
 
 
292
 
    def server_bind(self):
293
 
        self.socket.bind(self.server_address)
294
 
        self.server_address = self.socket.getsockname()
295
 
 
296
 
    def serve(self):
297
 
        self.serving = True
298
 
        # We are listening and ready to accept connections
299
 
        self.started.set()
300
 
        try:
301
 
            while self.serving:
302
 
                # Really a connection but the python framework is generic and
303
 
                # call them requests
304
 
                self.handle_request()
305
 
            # Let's close the listening socket
306
 
            self.server_close()
307
 
        finally:
308
 
            self.stopped.set()
309
 
 
310
 
    def handle_request(self):
311
 
        """Handle one request.
312
 
 
313
 
        The python version swallows some socket exceptions and we don't use
314
 
        timeout, so we override it to better control the server behavior.
315
 
        """
316
 
        request, client_address = self.get_request()
317
 
        if self.verify_request(request, client_address):
318
 
            try:
319
 
                self.process_request(request, client_address)
320
 
            except BaseException:
321
 
                self.handle_error(request, client_address)
322
 
        else:
323
 
            self.close_request(request)
324
 
 
325
 
    def get_request(self):
326
 
        return self.socket.accept()
327
 
 
328
 
    def verify_request(self, request, client_address):
329
 
        """Verify the request.
330
 
 
331
 
        Return True if we should proceed with this request, False if we should
332
 
        not even touch a single byte in the socket ! This is useful when we
333
 
        stop the server with a dummy last connection.
334
 
        """
335
 
        return self.serving
336
 
 
337
 
    def handle_error(self, request, client_address):
338
 
        # Stop serving and re-raise the last exception seen
339
 
        self.serving = False
340
 
        # The following can be used for debugging purposes, it will display the
341
 
        # exception and the traceback just when it occurs instead of waiting
342
 
        # for the thread to be joined.
343
 
        # socketserver.BaseServer.handle_error(self, request, client_address)
344
 
 
345
 
        # We call close_request manually, because we are going to raise an
346
 
        # exception. The socketserver implementation calls:
347
 
        #   handle_error(...)
348
 
        #   close_request(...)
349
 
        # But because we raise the exception, close_request will never be
350
 
        # triggered. This helps client not block waiting for a response when
351
 
        # the server gets an exception.
352
 
        self.close_request(request)
353
 
        raise
354
 
 
355
 
    def ignored_exceptions_during_shutdown(self, e):
356
 
        if sys.platform == 'win32':
357
 
            accepted_errnos = [errno.EBADF,
358
 
                               errno.EPIPE,
359
 
                               errno.WSAEBADF,
360
 
                               errno.WSAECONNRESET,
361
 
                               errno.WSAENOTCONN,
362
 
                               errno.WSAESHUTDOWN,
363
 
                               ]
364
 
        else:
365
 
            accepted_errnos = [errno.EBADF,
366
 
                               errno.ECONNRESET,
367
 
                               errno.ENOTCONN,
368
 
                               errno.EPIPE,
369
 
                               ]
370
 
        if isinstance(e, socket.error) and e.errno in accepted_errnos:
371
 
            return True
372
 
        return False
373
 
 
374
 
    # The following methods are called by the main thread
375
 
 
376
 
    def stop_client_connections(self):
377
 
        while self.clients:
378
 
            c = self.clients.pop()
379
 
            self.shutdown_client(c)
380
 
 
381
 
    def shutdown_socket(self, sock):
382
 
        """Properly shutdown a socket.
383
 
 
384
 
        This should be called only when no other thread is trying to use the
385
 
        socket.
386
 
        """
387
 
        try:
388
 
            sock.shutdown(socket.SHUT_RDWR)
389
 
            sock.close()
390
 
        except Exception as e:
391
 
            if self.ignored_exceptions(e):
392
 
                pass
393
 
            else:
394
 
                raise
395
 
 
396
 
    # The following methods are called by the main thread
397
 
 
398
 
    def set_ignored_exceptions(self, thread, ignored_exceptions):
399
 
        self.ignored_exceptions = ignored_exceptions
400
 
        thread.set_ignored_exceptions(self.ignored_exceptions)
401
 
 
402
 
    def _pending_exception(self, thread):
403
 
        """Raise server uncaught exception.
404
 
 
405
 
        Daughter classes can override this if they use daughter threads.
406
 
        """
407
 
        thread.pending_exception()
408
 
 
409
 
 
410
 
class TestingTCPServer(TestingTCPServerMixin, socketserver.TCPServer):
411
 
 
412
 
    def __init__(self, server_address, request_handler_class):
413
 
        TestingTCPServerMixin.__init__(self)
414
 
        socketserver.TCPServer.__init__(self, server_address,
415
 
                                        request_handler_class)
416
 
 
417
 
    def get_request(self):
418
 
        """Get the request and client address from the socket."""
419
 
        sock, addr = TestingTCPServerMixin.get_request(self)
420
 
        self.clients.append((sock, addr))
421
 
        return sock, addr
422
 
 
423
 
    # The following methods are called by the main thread
424
 
 
425
 
    def shutdown_client(self, client):
426
 
        sock, addr = client
427
 
        self.shutdown_socket(sock)
428
 
 
429
 
 
430
 
class TestingThreadingTCPServer(TestingTCPServerMixin,
431
 
                                socketserver.ThreadingTCPServer):
432
 
 
433
 
    def __init__(self, server_address, request_handler_class):
434
 
        TestingTCPServerMixin.__init__(self)
435
 
        socketserver.ThreadingTCPServer.__init__(self, server_address,
436
 
                                                 request_handler_class)
437
 
 
438
 
    def get_request(self):
439
 
        """Get the request and client address from the socket."""
440
 
        sock, addr = TestingTCPServerMixin.get_request(self)
441
 
        # The thread is not created yet, it will be updated in process_request
442
 
        self.clients.append((sock, addr, None))
443
 
        return sock, addr
444
 
 
445
 
    def process_request_thread(self, started, detached, stopped,
446
 
                               request, client_address):
447
 
        started.set()
448
 
        # We will be on our own once the server tells us we're detached
449
 
        detached.wait()
450
 
        socketserver.ThreadingTCPServer.process_request_thread(
451
 
            self, request, client_address)
452
 
        self.close_request(request)
453
 
        stopped.set()
454
 
 
455
 
    def process_request(self, request, client_address):
456
 
        """Start a new thread to process the request."""
457
 
        started = threading.Event()
458
 
        detached = threading.Event()
459
 
        stopped = threading.Event()
460
 
        t = TestThread(
461
 
            sync_event=stopped,
462
 
            name='%s -> %s' % (client_address, self.server_address),
463
 
            target=self.process_request_thread,
464
 
            args=(started, detached, stopped, request, client_address))
465
 
        # Update the client description
466
 
        self.clients.pop()
467
 
        self.clients.append((request, client_address, t))
468
 
        # Propagate the exception handler since we must use the same one as
469
 
        # TestingTCPServer for connections running in their own threads.
470
 
        t.set_ignored_exceptions(self.ignored_exceptions)
471
 
        t.start()
472
 
        started.wait()
473
 
        # If an exception occured during the thread start, it will get raised.
474
 
        t.pending_exception()
475
 
        if debug_threads():
476
 
            sys.stderr.write('Client thread %s started\n' % (t.name,))
477
 
        # Tell the thread, it's now on its own for exception handling.
478
 
        detached.set()
479
 
 
480
 
    # The following methods are called by the main thread
481
 
 
482
 
    def shutdown_client(self, client):
483
 
        sock, addr, connection_thread = client
484
 
        self.shutdown_socket(sock)
485
 
        if connection_thread is not None:
486
 
            # The thread has been created only if the request is processed but
487
 
            # after the connection is inited. This could happen during server
488
 
            # shutdown. If an exception occurred in the thread it will be
489
 
            # re-raised
490
 
            if debug_threads():
491
 
                sys.stderr.write('Client thread %s will be joined\n'
492
 
                                 % (connection_thread.name,))
493
 
            connection_thread.join()
494
 
 
495
 
    def set_ignored_exceptions(self, thread, ignored_exceptions):
496
 
        TestingTCPServerMixin.set_ignored_exceptions(self, thread,
497
 
                                                     ignored_exceptions)
498
 
        for sock, addr, connection_thread in self.clients:
499
 
            if connection_thread is not None:
500
 
                connection_thread.set_ignored_exceptions(
501
 
                    self.ignored_exceptions)
502
 
 
503
 
    def _pending_exception(self, thread):
504
 
        for sock, addr, connection_thread in self.clients:
505
 
            if connection_thread is not None:
506
 
                connection_thread.pending_exception()
507
 
        TestingTCPServerMixin._pending_exception(self, thread)
508
 
 
509
 
 
510
 
class TestingTCPServerInAThread(transport.Server):
511
 
    """A server in a thread that re-raise thread exceptions."""
512
 
 
513
 
    def __init__(self, server_address, server_class, request_handler_class):
514
 
        self.server_class = server_class
515
 
        self.request_handler_class = request_handler_class
516
 
        self.host, self.port = server_address
517
 
        self.server = None
518
 
        self._server_thread = None
519
 
 
520
 
    def __repr__(self):
521
 
        return "%s(%s:%s)" % (self.__class__.__name__, self.host, self.port)
522
 
 
523
 
    def create_server(self):
524
 
        return self.server_class((self.host, self.port),
525
 
                                 self.request_handler_class)
526
 
 
527
 
    def start_server(self):
528
 
        self.server = self.create_server()
529
 
        self._server_thread = TestThread(
530
 
            sync_event=self.server.started,
531
 
            target=self.run_server)
532
 
        self._server_thread.start()
533
 
        # Wait for the server thread to start (i.e. release the lock)
534
 
        self.server.started.wait()
535
 
        # Get the real address, especially the port
536
 
        self.host, self.port = self.server.server_address
537
 
        self._server_thread.name = self.server.server_address
538
 
        if debug_threads():
539
 
            sys.stderr.write('Server thread %s started\n'
540
 
                             % (self._server_thread.name,))
541
 
        # If an exception occured during the server start, it will get raised,
542
 
        # otherwise, the server is blocked on its accept() call.
543
 
        self._server_thread.pending_exception()
544
 
        # From now on, we'll use a different event to ensure the server can set
545
 
        # its exception
546
 
        self._server_thread.set_sync_event(self.server.stopped)
547
 
 
548
 
    def run_server(self):
549
 
        self.server.serve()
550
 
 
551
 
    def stop_server(self):
552
 
        if self.server is None:
553
 
            return
554
 
        try:
555
 
            # The server has been started successfully, shut it down now.  As
556
 
            # soon as we stop serving, no more connection are accepted except
557
 
            # one to get out of the blocking listen.
558
 
            self.set_ignored_exceptions(
559
 
                self.server.ignored_exceptions_during_shutdown)
560
 
            self.server.serving = False
561
 
            if debug_threads():
562
 
                sys.stderr.write('Server thread %s will be joined\n'
563
 
                                 % (self._server_thread.name,))
564
 
            # The server is listening for a last connection, let's give it:
565
 
            last_conn = None
566
 
            try:
567
 
                last_conn = osutils.connect_socket((self.host, self.port))
568
 
            except socket.error:
569
 
                # But ignore connection errors as the point is to unblock the
570
 
                # server thread, it may happen that it's not blocked or even
571
 
                # not started.
572
 
                pass
573
 
            # We start shutting down the clients while the server itself is
574
 
            # shutting down.
575
 
            self.server.stop_client_connections()
576
 
            # Now we wait for the thread running self.server.serve() to finish
577
 
            self.server.stopped.wait()
578
 
            if last_conn is not None:
579
 
                # Close the last connection without trying to use it. The
580
 
                # server will not process a single byte on that socket to avoid
581
 
                # complications (SSL starts with a handshake for example).
582
 
                last_conn.close()
583
 
            # Check for any exception that could have occurred in the server
584
 
            # thread
585
 
            try:
586
 
                self._server_thread.join()
587
 
            except Exception as e:
588
 
                if self.server.ignored_exceptions(e):
589
 
                    pass
590
 
                else:
591
 
                    raise
592
 
        finally:
593
 
            # Make sure we can be called twice safely, note that this means
594
 
            # that we will raise a single exception even if several occurred in
595
 
            # the various threads involved.
596
 
            self.server = None
597
 
 
598
 
    def set_ignored_exceptions(self, ignored_exceptions):
599
 
        """Install an exception handler for the server."""
600
 
        self.server.set_ignored_exceptions(self._server_thread,
601
 
                                           ignored_exceptions)
602
 
 
603
 
    def pending_exception(self):
604
 
        """Raise uncaught exception in the server."""
605
 
        self.server._pending_exception(self._server_thread)
606
 
 
607
 
 
608
 
class TestingSmartConnectionHandler(socketserver.BaseRequestHandler,
609
 
                                    medium.SmartServerSocketStreamMedium):
610
 
 
611
 
    def __init__(self, request, client_address, server):
612
 
        medium.SmartServerSocketStreamMedium.__init__(
613
 
            self, request, server.backing_transport,
614
 
            server.root_client_path,
615
 
            timeout=_DEFAULT_TESTING_CLIENT_TIMEOUT)
616
 
        request.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
617
 
        socketserver.BaseRequestHandler.__init__(self, request, client_address,
618
 
                                                 server)
619
 
 
620
 
    def handle(self):
621
 
        try:
622
 
            while not self.finished:
623
 
                server_protocol = self._build_protocol()
624
 
                self._serve_one_request(server_protocol)
625
 
        except errors.ConnectionTimeout:
626
 
            # idle connections aren't considered a failure of the server
627
 
            return
628
 
 
629
 
 
630
 
_DEFAULT_TESTING_CLIENT_TIMEOUT = 60.0
631
 
 
632
 
 
633
 
class TestingSmartServer(TestingThreadingTCPServer, server.SmartTCPServer):
634
 
 
635
 
    def __init__(self, server_address, request_handler_class,
636
 
                 backing_transport, root_client_path):
637
 
        TestingThreadingTCPServer.__init__(self, server_address,
638
 
                                           request_handler_class)
639
 
        server.SmartTCPServer.__init__(
640
 
            self, backing_transport,
641
 
            root_client_path, client_timeout=_DEFAULT_TESTING_CLIENT_TIMEOUT)
642
 
 
643
 
    def serve(self):
644
 
        self.run_server_started_hooks()
645
 
        try:
646
 
            TestingThreadingTCPServer.serve(self)
647
 
        finally:
648
 
            self.run_server_stopped_hooks()
649
 
 
650
 
    def get_url(self):
651
 
        """Return the url of the server"""
652
 
        return "bzr://%s:%d/" % self.server_address
653
 
 
654
 
 
655
 
class SmartTCPServer_for_testing(TestingTCPServerInAThread):
656
 
    """Server suitable for use by transport tests.
657
 
 
658
 
    This server is backed by the process's cwd.
659
 
    """
660
 
 
661
 
    def __init__(self, thread_name_suffix=''):
662
 
        self.client_path_extra = None
663
 
        self.thread_name_suffix = thread_name_suffix
664
 
        self.host = '127.0.0.1'
665
 
        self.port = 0
666
 
        super(SmartTCPServer_for_testing, self).__init__(
667
 
            (self.host, self.port),
668
 
            TestingSmartServer,
669
 
            TestingSmartConnectionHandler)
670
 
 
671
 
    def create_server(self):
672
 
        return self.server_class((self.host, self.port),
673
 
                                 self.request_handler_class,
674
 
                                 self.backing_transport,
675
 
                                 self.root_client_path)
676
 
 
677
 
    def start_server(self, backing_transport_server=None,
678
 
                     client_path_extra='/extra/'):
679
 
        """Set up server for testing.
680
 
 
681
 
        :param backing_transport_server: backing server to use.  If not
682
 
            specified, a LocalURLServer at the current working directory will
683
 
            be used.
684
 
        :param client_path_extra: a path segment starting with '/' to append to
685
 
            the root URL for this server.  For instance, a value of '/foo/bar/'
686
 
            will mean the root of the backing transport will be published at a
687
 
            URL like `bzr://127.0.0.1:nnnn/foo/bar/`, rather than
688
 
            `bzr://127.0.0.1:nnnn/`.  Default value is `extra`, so that tests
689
 
            by default will fail unless they do the necessary path translation.
690
 
        """
691
 
        if not client_path_extra.startswith('/'):
692
 
            raise ValueError(client_path_extra)
693
 
        self.root_client_path = self.client_path_extra = client_path_extra
694
 
        from breezy.transport.chroot import ChrootServer
695
 
        if backing_transport_server is None:
696
 
            backing_transport_server = LocalURLServer()
697
 
        self.chroot_server = ChrootServer(
698
 
            self.get_backing_transport(backing_transport_server))
699
 
        self.chroot_server.start_server()
700
 
        self.backing_transport = transport.get_transport_from_url(
701
 
            self.chroot_server.get_url())
702
 
        super(SmartTCPServer_for_testing, self).start_server()
703
 
 
704
 
    def stop_server(self):
705
 
        try:
706
 
            super(SmartTCPServer_for_testing, self).stop_server()
707
 
        finally:
708
 
            self.chroot_server.stop_server()
709
 
 
710
 
    def get_backing_transport(self, backing_transport_server):
711
 
        """Get a backing transport from a server we are decorating."""
712
 
        return transport.get_transport_from_url(
713
 
            backing_transport_server.get_url())
714
 
 
715
 
    def get_url(self):
716
 
        url = self.server.get_url()
717
 
        return url[:-1] + self.client_path_extra
718
 
 
719
 
    def get_bogus_url(self):
720
 
        """Return a URL which will fail to connect"""
721
 
        return 'bzr://127.0.0.1:1/'
722
 
 
723
 
 
724
 
class ReadonlySmartTCPServer_for_testing(SmartTCPServer_for_testing):
725
 
    """Get a readonly server for testing."""
726
 
 
727
 
    def get_backing_transport(self, backing_transport_server):
728
 
        """Get a backing transport from a server we are decorating."""
729
 
        url = 'readonly+' + backing_transport_server.get_url()
730
 
        return transport.get_transport_from_url(url)
731
 
 
732
 
 
733
 
class SmartTCPServer_for_testing_v2_only(SmartTCPServer_for_testing):
734
 
    """A variation of SmartTCPServer_for_testing that limits the client to
735
 
    using RPCs in protocol v2 (i.e. bzr <= 1.5).
736
 
    """
737
 
 
738
 
    def get_url(self):
739
 
        url = super(SmartTCPServer_for_testing_v2_only, self).get_url()
740
 
        url = 'bzr-v2://' + url[len('bzr://'):]
741
 
        return url
742
 
 
743
 
 
744
 
class ReadonlySmartTCPServer_for_testing_v2_only(
745
 
        SmartTCPServer_for_testing_v2_only):
746
 
    """Get a readonly server for testing."""
747
 
 
748
 
    def get_backing_transport(self, backing_transport_server):
749
 
        """Get a backing transport from a server we are decorating."""
750
 
        url = 'readonly+' + backing_transport_server.get_url()
751
 
        return transport.get_transport_from_url(url)