/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to breezy/tests/test_server.py

  • Committer: Jelmer Vernooij
  • Date: 2020-05-24 00:39:50 UTC
  • mto: This revision was merged to the branch mainline in revision 7504.
  • Revision ID: jelmer@jelmer.uk-20200524003950-bbc545r76vc5yajg
Add github action.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2010, 2011 Canonical Ltd
 
2
#
 
3
# This program is free software; you can redistribute it and/or modify
 
4
# it under the terms of the GNU General Public License as published by
 
5
# the Free Software Foundation; either version 2 of the License, or
 
6
# (at your option) any later version.
 
7
#
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU General Public License for more details.
 
12
#
 
13
# You should have received a copy of the GNU General Public License
 
14
# along with this program; if not, write to the Free Software
 
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 
16
 
 
17
import errno
 
18
import socket
 
19
import socketserver
 
20
import sys
 
21
import threading
 
22
 
 
23
 
 
24
from breezy import (
 
25
    cethread,
 
26
    errors,
 
27
    osutils,
 
28
    transport,
 
29
    urlutils,
 
30
    )
 
31
from breezy.transport import (
 
32
    chroot,
 
33
    pathfilter,
 
34
    )
 
35
from breezy.bzr.smart import (
 
36
    medium,
 
37
    server,
 
38
    )
 
39
 
 
40
 
 
41
def debug_threads():
 
42
    # FIXME: There is a dependency loop between breezy.tests and
 
43
    # breezy.tests.test_server that needs to be fixed. In the mean time
 
44
    # defining this function is enough for our needs. -- vila 20100611
 
45
    from breezy import tests
 
46
    return 'threads' in tests.selftest_debug_flags
 
47
 
 
48
 
 
49
class TestServer(transport.Server):
 
50
    """A Transport Server dedicated to tests.
 
51
 
 
52
    The TestServer interface provides a server for a given transport. We use
 
53
    these servers as loopback testing tools. For any given transport the
 
54
    Servers it provides must either allow writing, or serve the contents
 
55
    of osutils.getcwd() at the time start_server is called.
 
56
 
 
57
    Note that these are real servers - they must implement all the things
 
58
    that we want bzr transports to take advantage of.
 
59
    """
 
60
 
 
61
    def get_url(self):
 
62
        """Return a url for this server.
 
63
 
 
64
        If the transport does not represent a disk directory (i.e. it is
 
65
        a database like svn, or a memory only transport, it should return
 
66
        a connection to a newly established resource for this Server.
 
67
        Otherwise it should return a url that will provide access to the path
 
68
        that was osutils.getcwd() when start_server() was called.
 
69
 
 
70
        Subsequent calls will return the same resource.
 
71
        """
 
72
        raise NotImplementedError
 
73
 
 
74
    def get_bogus_url(self):
 
75
        """Return a url for this protocol, that will fail to connect.
 
76
 
 
77
        This may raise NotImplementedError to indicate that this server cannot
 
78
        provide bogus urls.
 
79
        """
 
80
        raise NotImplementedError
 
81
 
 
82
 
 
83
class LocalURLServer(TestServer):
 
84
    """A pretend server for local transports, using file:// urls.
 
85
 
 
86
    Of course no actual server is required to access the local filesystem, so
 
87
    this just exists to tell the test code how to get to it.
 
88
    """
 
89
 
 
90
    def start_server(self):
 
91
        pass
 
92
 
 
93
    def get_url(self):
 
94
        """See Transport.Server.get_url."""
 
95
        return urlutils.local_path_to_url('')
 
96
 
 
97
 
 
98
class DecoratorServer(TestServer):
 
99
    """Server for the TransportDecorator for testing with.
 
100
 
 
101
    To use this when subclassing TransportDecorator, override override the
 
102
    get_decorator_class method.
 
103
    """
 
104
 
 
105
    def start_server(self, server=None):
 
106
        """See breezy.transport.Server.start_server.
 
107
 
 
108
        :server: decorate the urls given by server. If not provided a
 
109
        LocalServer is created.
 
110
        """
 
111
        if server is not None:
 
112
            self._made_server = False
 
113
            self._server = server
 
114
        else:
 
115
            self._made_server = True
 
116
            self._server = LocalURLServer()
 
117
            self._server.start_server()
 
118
 
 
119
    def stop_server(self):
 
120
        if self._made_server:
 
121
            self._server.stop_server()
 
122
 
 
123
    def get_decorator_class(self):
 
124
        """Return the class of the decorators we should be constructing."""
 
125
        raise NotImplementedError(self.get_decorator_class)
 
126
 
 
127
    def get_url_prefix(self):
 
128
        """What URL prefix does this decorator produce?"""
 
129
        return self.get_decorator_class()._get_url_prefix()
 
130
 
 
131
    def get_bogus_url(self):
 
132
        """See breezy.transport.Server.get_bogus_url."""
 
133
        return self.get_url_prefix() + self._server.get_bogus_url()
 
134
 
 
135
    def get_url(self):
 
136
        """See breezy.transport.Server.get_url."""
 
137
        return self.get_url_prefix() + self._server.get_url()
 
138
 
 
139
 
 
140
class BrokenRenameServer(DecoratorServer):
 
141
    """Server for the BrokenRenameTransportDecorator for testing with."""
 
142
 
 
143
    def get_decorator_class(self):
 
144
        from breezy.transport import brokenrename
 
145
        return brokenrename.BrokenRenameTransportDecorator
 
146
 
 
147
 
 
148
class FakeNFSServer(DecoratorServer):
 
149
    """Server for the FakeNFSTransportDecorator for testing with."""
 
150
 
 
151
    def get_decorator_class(self):
 
152
        from breezy.transport import fakenfs
 
153
        return fakenfs.FakeNFSTransportDecorator
 
154
 
 
155
 
 
156
class FakeVFATServer(DecoratorServer):
 
157
    """A server that suggests connections through FakeVFATTransportDecorator
 
158
 
 
159
    For use in testing.
 
160
    """
 
161
 
 
162
    def get_decorator_class(self):
 
163
        from breezy.transport import fakevfat
 
164
        return fakevfat.FakeVFATTransportDecorator
 
165
 
 
166
 
 
167
class LogDecoratorServer(DecoratorServer):
 
168
    """Server for testing."""
 
169
 
 
170
    def get_decorator_class(self):
 
171
        from breezy.transport import log
 
172
        return log.TransportLogDecorator
 
173
 
 
174
 
 
175
class NoSmartTransportServer(DecoratorServer):
 
176
    """Server for the NoSmartTransportDecorator for testing with."""
 
177
 
 
178
    def get_decorator_class(self):
 
179
        from breezy.transport import nosmart
 
180
        return nosmart.NoSmartTransportDecorator
 
181
 
 
182
 
 
183
class ReadonlyServer(DecoratorServer):
 
184
    """Server for the ReadonlyTransportDecorator for testing with."""
 
185
 
 
186
    def get_decorator_class(self):
 
187
        from breezy.transport import readonly
 
188
        return readonly.ReadonlyTransportDecorator
 
189
 
 
190
 
 
191
class TraceServer(DecoratorServer):
 
192
    """Server for the TransportTraceDecorator for testing with."""
 
193
 
 
194
    def get_decorator_class(self):
 
195
        from breezy.transport import trace
 
196
        return trace.TransportTraceDecorator
 
197
 
 
198
 
 
199
class UnlistableServer(DecoratorServer):
 
200
    """Server for the UnlistableTransportDecorator for testing with."""
 
201
 
 
202
    def get_decorator_class(self):
 
203
        from breezy.transport import unlistable
 
204
        return unlistable.UnlistableTransportDecorator
 
205
 
 
206
 
 
207
class TestingPathFilteringServer(pathfilter.PathFilteringServer):
 
208
 
 
209
    def __init__(self):
 
210
        """TestingPathFilteringServer is not usable until start_server
 
211
        is called."""
 
212
 
 
213
    def start_server(self, backing_server=None):
 
214
        """Setup the Chroot on backing_server."""
 
215
        if backing_server is not None:
 
216
            self.backing_transport = transport.get_transport_from_url(
 
217
                backing_server.get_url())
 
218
        else:
 
219
            self.backing_transport = transport.get_transport_from_path('.')
 
220
        self.backing_transport.clone('added-by-filter').ensure_base()
 
221
        self.filter_func = lambda x: 'added-by-filter/' + x
 
222
        super(TestingPathFilteringServer, self).start_server()
 
223
 
 
224
    def get_bogus_url(self):
 
225
        raise NotImplementedError
 
226
 
 
227
 
 
228
class TestingChrootServer(chroot.ChrootServer):
 
229
 
 
230
    def __init__(self):
 
231
        """TestingChrootServer is not usable until start_server is called."""
 
232
        super(TestingChrootServer, self).__init__(None)
 
233
 
 
234
    def start_server(self, backing_server=None):
 
235
        """Setup the Chroot on backing_server."""
 
236
        if backing_server is not None:
 
237
            self.backing_transport = transport.get_transport_from_url(
 
238
                backing_server.get_url())
 
239
        else:
 
240
            self.backing_transport = transport.get_transport_from_path('.')
 
241
        super(TestingChrootServer, self).start_server()
 
242
 
 
243
    def get_bogus_url(self):
 
244
        raise NotImplementedError
 
245
 
 
246
 
 
247
class TestThread(cethread.CatchingExceptionThread):
 
248
 
 
249
    if not getattr(cethread.CatchingExceptionThread, 'is_alive', None):
 
250
        def is_alive(self):
 
251
            return self.isAlive()
 
252
 
 
253
    def join(self, timeout=5):
 
254
        """Overrides to use a default timeout.
 
255
 
 
256
        The default timeout is set to 5 and should expire only when a thread
 
257
        serving a client connection is hung.
 
258
        """
 
259
        super(TestThread, self).join(timeout)
 
260
        if timeout and self.is_alive():
 
261
            # The timeout expired without joining the thread, the thread is
 
262
            # therefore stucked and that's a failure as far as the test is
 
263
            # concerned. We used to hang here.
 
264
 
 
265
            # FIXME: we need to kill the thread, but as far as the test is
 
266
            # concerned, raising an assertion is too strong. On most of the
 
267
            # platforms, this doesn't occur, so just mentioning the problem is
 
268
            # enough for now -- vila 2010824
 
269
            sys.stderr.write('thread %s hung\n' % (self.name,))
 
270
            # raise AssertionError('thread %s hung' % (self.name,))
 
271
 
 
272
 
 
273
class TestingTCPServerMixin(object):
 
274
    """Mixin to support running socketserver.TCPServer in a thread.
 
275
 
 
276
    Tests are connecting from the main thread, the server has to be run in a
 
277
    separate thread.
 
278
    """
 
279
 
 
280
    def __init__(self):
 
281
        self.started = threading.Event()
 
282
        self.serving = None
 
283
        self.stopped = threading.Event()
 
284
        # We collect the resources used by the clients so we can release them
 
285
        # when shutting down
 
286
        self.clients = []
 
287
        self.ignored_exceptions = None
 
288
 
 
289
    def server_bind(self):
 
290
        self.socket.bind(self.server_address)
 
291
        self.server_address = self.socket.getsockname()
 
292
 
 
293
    def serve(self):
 
294
        self.serving = True
 
295
        # We are listening and ready to accept connections
 
296
        self.started.set()
 
297
        try:
 
298
            while self.serving:
 
299
                # Really a connection but the python framework is generic and
 
300
                # call them requests
 
301
                self.handle_request()
 
302
            # Let's close the listening socket
 
303
            self.server_close()
 
304
        finally:
 
305
            self.stopped.set()
 
306
 
 
307
    def handle_request(self):
 
308
        """Handle one request.
 
309
 
 
310
        The python version swallows some socket exceptions and we don't use
 
311
        timeout, so we override it to better control the server behavior.
 
312
        """
 
313
        request, client_address = self.get_request()
 
314
        if self.verify_request(request, client_address):
 
315
            try:
 
316
                self.process_request(request, client_address)
 
317
            except BaseException:
 
318
                self.handle_error(request, client_address)
 
319
        else:
 
320
            self.close_request(request)
 
321
 
 
322
    def get_request(self):
 
323
        return self.socket.accept()
 
324
 
 
325
    def verify_request(self, request, client_address):
 
326
        """Verify the request.
 
327
 
 
328
        Return True if we should proceed with this request, False if we should
 
329
        not even touch a single byte in the socket ! This is useful when we
 
330
        stop the server with a dummy last connection.
 
331
        """
 
332
        return self.serving
 
333
 
 
334
    def handle_error(self, request, client_address):
 
335
        # Stop serving and re-raise the last exception seen
 
336
        self.serving = False
 
337
        # The following can be used for debugging purposes, it will display the
 
338
        # exception and the traceback just when it occurs instead of waiting
 
339
        # for the thread to be joined.
 
340
        # socketserver.BaseServer.handle_error(self, request, client_address)
 
341
 
 
342
        # We call close_request manually, because we are going to raise an
 
343
        # exception. The socketserver implementation calls:
 
344
        #   handle_error(...)
 
345
        #   close_request(...)
 
346
        # But because we raise the exception, close_request will never be
 
347
        # triggered. This helps client not block waiting for a response when
 
348
        # the server gets an exception.
 
349
        self.close_request(request)
 
350
        raise
 
351
 
 
352
    def ignored_exceptions_during_shutdown(self, e):
 
353
        if sys.platform == 'win32':
 
354
            accepted_errnos = [errno.EBADF,
 
355
                               errno.EPIPE,
 
356
                               errno.WSAEBADF,
 
357
                               errno.WSAECONNRESET,
 
358
                               errno.WSAENOTCONN,
 
359
                               errno.WSAESHUTDOWN,
 
360
                               ]
 
361
        else:
 
362
            accepted_errnos = [errno.EBADF,
 
363
                               errno.ECONNRESET,
 
364
                               errno.ENOTCONN,
 
365
                               errno.EPIPE,
 
366
                               ]
 
367
        if isinstance(e, socket.error) and e.errno in accepted_errnos:
 
368
            return True
 
369
        return False
 
370
 
 
371
    # The following methods are called by the main thread
 
372
 
 
373
    def stop_client_connections(self):
 
374
        while self.clients:
 
375
            c = self.clients.pop()
 
376
            self.shutdown_client(c)
 
377
 
 
378
    def shutdown_socket(self, sock):
 
379
        """Properly shutdown a socket.
 
380
 
 
381
        This should be called only when no other thread is trying to use the
 
382
        socket.
 
383
        """
 
384
        try:
 
385
            sock.shutdown(socket.SHUT_RDWR)
 
386
            sock.close()
 
387
        except Exception as e:
 
388
            if self.ignored_exceptions(e):
 
389
                pass
 
390
            else:
 
391
                raise
 
392
 
 
393
    # The following methods are called by the main thread
 
394
 
 
395
    def set_ignored_exceptions(self, thread, ignored_exceptions):
 
396
        self.ignored_exceptions = ignored_exceptions
 
397
        thread.set_ignored_exceptions(self.ignored_exceptions)
 
398
 
 
399
    def _pending_exception(self, thread):
 
400
        """Raise server uncaught exception.
 
401
 
 
402
        Daughter classes can override this if they use daughter threads.
 
403
        """
 
404
        thread.pending_exception()
 
405
 
 
406
 
 
407
class TestingTCPServer(TestingTCPServerMixin, socketserver.TCPServer):
 
408
 
 
409
    def __init__(self, server_address, request_handler_class):
 
410
        TestingTCPServerMixin.__init__(self)
 
411
        socketserver.TCPServer.__init__(self, server_address,
 
412
                                        request_handler_class)
 
413
 
 
414
    def get_request(self):
 
415
        """Get the request and client address from the socket."""
 
416
        sock, addr = TestingTCPServerMixin.get_request(self)
 
417
        self.clients.append((sock, addr))
 
418
        return sock, addr
 
419
 
 
420
    # The following methods are called by the main thread
 
421
 
 
422
    def shutdown_client(self, client):
 
423
        sock, addr = client
 
424
        self.shutdown_socket(sock)
 
425
 
 
426
 
 
427
class TestingThreadingTCPServer(TestingTCPServerMixin,
 
428
                                socketserver.ThreadingTCPServer):
 
429
 
 
430
    def __init__(self, server_address, request_handler_class):
 
431
        TestingTCPServerMixin.__init__(self)
 
432
        socketserver.ThreadingTCPServer.__init__(self, server_address,
 
433
                                                 request_handler_class)
 
434
 
 
435
    def get_request(self):
 
436
        """Get the request and client address from the socket."""
 
437
        sock, addr = TestingTCPServerMixin.get_request(self)
 
438
        # The thread is not created yet, it will be updated in process_request
 
439
        self.clients.append((sock, addr, None))
 
440
        return sock, addr
 
441
 
 
442
    def process_request_thread(self, started, detached, stopped,
 
443
                               request, client_address):
 
444
        started.set()
 
445
        # We will be on our own once the server tells us we're detached
 
446
        detached.wait()
 
447
        socketserver.ThreadingTCPServer.process_request_thread(
 
448
            self, request, client_address)
 
449
        self.close_request(request)
 
450
        stopped.set()
 
451
 
 
452
    def process_request(self, request, client_address):
 
453
        """Start a new thread to process the request."""
 
454
        started = threading.Event()
 
455
        detached = threading.Event()
 
456
        stopped = threading.Event()
 
457
        t = TestThread(
 
458
            sync_event=stopped,
 
459
            name='%s -> %s' % (client_address, self.server_address),
 
460
            target=self.process_request_thread,
 
461
            args=(started, detached, stopped, request, client_address))
 
462
        # Update the client description
 
463
        self.clients.pop()
 
464
        self.clients.append((request, client_address, t))
 
465
        # Propagate the exception handler since we must use the same one as
 
466
        # TestingTCPServer for connections running in their own threads.
 
467
        t.set_ignored_exceptions(self.ignored_exceptions)
 
468
        t.start()
 
469
        started.wait()
 
470
        # If an exception occured during the thread start, it will get raised.
 
471
        t.pending_exception()
 
472
        if debug_threads():
 
473
            sys.stderr.write('Client thread %s started\n' % (t.name,))
 
474
        # Tell the thread, it's now on its own for exception handling.
 
475
        detached.set()
 
476
 
 
477
    # The following methods are called by the main thread
 
478
 
 
479
    def shutdown_client(self, client):
 
480
        sock, addr, connection_thread = client
 
481
        self.shutdown_socket(sock)
 
482
        if connection_thread is not None:
 
483
            # The thread has been created only if the request is processed but
 
484
            # after the connection is inited. This could happen during server
 
485
            # shutdown. If an exception occurred in the thread it will be
 
486
            # re-raised
 
487
            if debug_threads():
 
488
                sys.stderr.write('Client thread %s will be joined\n'
 
489
                                 % (connection_thread.name,))
 
490
            connection_thread.join()
 
491
 
 
492
    def set_ignored_exceptions(self, thread, ignored_exceptions):
 
493
        TestingTCPServerMixin.set_ignored_exceptions(self, thread,
 
494
                                                     ignored_exceptions)
 
495
        for sock, addr, connection_thread in self.clients:
 
496
            if connection_thread is not None:
 
497
                connection_thread.set_ignored_exceptions(
 
498
                    self.ignored_exceptions)
 
499
 
 
500
    def _pending_exception(self, thread):
 
501
        for sock, addr, connection_thread in self.clients:
 
502
            if connection_thread is not None:
 
503
                connection_thread.pending_exception()
 
504
        TestingTCPServerMixin._pending_exception(self, thread)
 
505
 
 
506
 
 
507
class TestingTCPServerInAThread(transport.Server):
 
508
    """A server in a thread that re-raise thread exceptions."""
 
509
 
 
510
    def __init__(self, server_address, server_class, request_handler_class):
 
511
        self.server_class = server_class
 
512
        self.request_handler_class = request_handler_class
 
513
        self.host, self.port = server_address
 
514
        self.server = None
 
515
        self._server_thread = None
 
516
 
 
517
    def __repr__(self):
 
518
        return "%s(%s:%s)" % (self.__class__.__name__, self.host, self.port)
 
519
 
 
520
    def create_server(self):
 
521
        return self.server_class((self.host, self.port),
 
522
                                 self.request_handler_class)
 
523
 
 
524
    def start_server(self):
 
525
        self.server = self.create_server()
 
526
        self._server_thread = TestThread(
 
527
            sync_event=self.server.started,
 
528
            target=self.run_server)
 
529
        self._server_thread.start()
 
530
        # Wait for the server thread to start (i.e. release the lock)
 
531
        self.server.started.wait()
 
532
        # Get the real address, especially the port
 
533
        self.host, self.port = self.server.server_address
 
534
        self._server_thread.name = self.server.server_address
 
535
        if debug_threads():
 
536
            sys.stderr.write('Server thread %s started\n'
 
537
                             % (self._server_thread.name,))
 
538
        # If an exception occured during the server start, it will get raised,
 
539
        # otherwise, the server is blocked on its accept() call.
 
540
        self._server_thread.pending_exception()
 
541
        # From now on, we'll use a different event to ensure the server can set
 
542
        # its exception
 
543
        self._server_thread.set_sync_event(self.server.stopped)
 
544
 
 
545
    def run_server(self):
 
546
        self.server.serve()
 
547
 
 
548
    def stop_server(self):
 
549
        if self.server is None:
 
550
            return
 
551
        try:
 
552
            # The server has been started successfully, shut it down now.  As
 
553
            # soon as we stop serving, no more connection are accepted except
 
554
            # one to get out of the blocking listen.
 
555
            self.set_ignored_exceptions(
 
556
                self.server.ignored_exceptions_during_shutdown)
 
557
            self.server.serving = False
 
558
            if debug_threads():
 
559
                sys.stderr.write('Server thread %s will be joined\n'
 
560
                                 % (self._server_thread.name,))
 
561
            # The server is listening for a last connection, let's give it:
 
562
            last_conn = None
 
563
            try:
 
564
                last_conn = osutils.connect_socket((self.host, self.port))
 
565
            except socket.error:
 
566
                # But ignore connection errors as the point is to unblock the
 
567
                # server thread, it may happen that it's not blocked or even
 
568
                # not started.
 
569
                pass
 
570
            # We start shutting down the clients while the server itself is
 
571
            # shutting down.
 
572
            self.server.stop_client_connections()
 
573
            # Now we wait for the thread running self.server.serve() to finish
 
574
            self.server.stopped.wait()
 
575
            if last_conn is not None:
 
576
                # Close the last connection without trying to use it. The
 
577
                # server will not process a single byte on that socket to avoid
 
578
                # complications (SSL starts with a handshake for example).
 
579
                last_conn.close()
 
580
            # Check for any exception that could have occurred in the server
 
581
            # thread
 
582
            try:
 
583
                self._server_thread.join()
 
584
            except Exception as e:
 
585
                if self.server.ignored_exceptions(e):
 
586
                    pass
 
587
                else:
 
588
                    raise
 
589
        finally:
 
590
            # Make sure we can be called twice safely, note that this means
 
591
            # that we will raise a single exception even if several occurred in
 
592
            # the various threads involved.
 
593
            self.server = None
 
594
 
 
595
    def set_ignored_exceptions(self, ignored_exceptions):
 
596
        """Install an exception handler for the server."""
 
597
        self.server.set_ignored_exceptions(self._server_thread,
 
598
                                           ignored_exceptions)
 
599
 
 
600
    def pending_exception(self):
 
601
        """Raise uncaught exception in the server."""
 
602
        self.server._pending_exception(self._server_thread)
 
603
 
 
604
 
 
605
class TestingSmartConnectionHandler(socketserver.BaseRequestHandler,
 
606
                                    medium.SmartServerSocketStreamMedium):
 
607
 
 
608
    def __init__(self, request, client_address, server):
 
609
        medium.SmartServerSocketStreamMedium.__init__(
 
610
            self, request, server.backing_transport,
 
611
            server.root_client_path,
 
612
            timeout=_DEFAULT_TESTING_CLIENT_TIMEOUT)
 
613
        request.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
 
614
        socketserver.BaseRequestHandler.__init__(self, request, client_address,
 
615
                                                 server)
 
616
 
 
617
    def handle(self):
 
618
        try:
 
619
            while not self.finished:
 
620
                server_protocol = self._build_protocol()
 
621
                self._serve_one_request(server_protocol)
 
622
        except errors.ConnectionTimeout:
 
623
            # idle connections aren't considered a failure of the server
 
624
            return
 
625
 
 
626
 
 
627
_DEFAULT_TESTING_CLIENT_TIMEOUT = 60.0
 
628
 
 
629
 
 
630
class TestingSmartServer(TestingThreadingTCPServer, server.SmartTCPServer):
 
631
 
 
632
    def __init__(self, server_address, request_handler_class,
 
633
                 backing_transport, root_client_path):
 
634
        TestingThreadingTCPServer.__init__(self, server_address,
 
635
                                           request_handler_class)
 
636
        server.SmartTCPServer.__init__(
 
637
            self, backing_transport,
 
638
            root_client_path, client_timeout=_DEFAULT_TESTING_CLIENT_TIMEOUT)
 
639
 
 
640
    def serve(self):
 
641
        self.run_server_started_hooks()
 
642
        try:
 
643
            TestingThreadingTCPServer.serve(self)
 
644
        finally:
 
645
            self.run_server_stopped_hooks()
 
646
 
 
647
    def get_url(self):
 
648
        """Return the url of the server"""
 
649
        return "bzr://%s:%d/" % self.server_address
 
650
 
 
651
 
 
652
class SmartTCPServer_for_testing(TestingTCPServerInAThread):
 
653
    """Server suitable for use by transport tests.
 
654
 
 
655
    This server is backed by the process's cwd.
 
656
    """
 
657
 
 
658
    def __init__(self, thread_name_suffix=''):
 
659
        self.client_path_extra = None
 
660
        self.thread_name_suffix = thread_name_suffix
 
661
        self.host = '127.0.0.1'
 
662
        self.port = 0
 
663
        super(SmartTCPServer_for_testing, self).__init__(
 
664
            (self.host, self.port),
 
665
            TestingSmartServer,
 
666
            TestingSmartConnectionHandler)
 
667
 
 
668
    def create_server(self):
 
669
        return self.server_class((self.host, self.port),
 
670
                                 self.request_handler_class,
 
671
                                 self.backing_transport,
 
672
                                 self.root_client_path)
 
673
 
 
674
    def start_server(self, backing_transport_server=None,
 
675
                     client_path_extra='/extra/'):
 
676
        """Set up server for testing.
 
677
 
 
678
        :param backing_transport_server: backing server to use.  If not
 
679
            specified, a LocalURLServer at the current working directory will
 
680
            be used.
 
681
        :param client_path_extra: a path segment starting with '/' to append to
 
682
            the root URL for this server.  For instance, a value of '/foo/bar/'
 
683
            will mean the root of the backing transport will be published at a
 
684
            URL like `bzr://127.0.0.1:nnnn/foo/bar/`, rather than
 
685
            `bzr://127.0.0.1:nnnn/`.  Default value is `extra`, so that tests
 
686
            by default will fail unless they do the necessary path translation.
 
687
        """
 
688
        if not client_path_extra.startswith('/'):
 
689
            raise ValueError(client_path_extra)
 
690
        self.root_client_path = self.client_path_extra = client_path_extra
 
691
        from breezy.transport.chroot import ChrootServer
 
692
        if backing_transport_server is None:
 
693
            backing_transport_server = LocalURLServer()
 
694
        self.chroot_server = ChrootServer(
 
695
            self.get_backing_transport(backing_transport_server))
 
696
        self.chroot_server.start_server()
 
697
        self.backing_transport = transport.get_transport_from_url(
 
698
            self.chroot_server.get_url())
 
699
        super(SmartTCPServer_for_testing, self).start_server()
 
700
 
 
701
    def stop_server(self):
 
702
        try:
 
703
            super(SmartTCPServer_for_testing, self).stop_server()
 
704
        finally:
 
705
            self.chroot_server.stop_server()
 
706
 
 
707
    def get_backing_transport(self, backing_transport_server):
 
708
        """Get a backing transport from a server we are decorating."""
 
709
        return transport.get_transport_from_url(
 
710
            backing_transport_server.get_url())
 
711
 
 
712
    def get_url(self):
 
713
        url = self.server.get_url()
 
714
        return url[:-1] + self.client_path_extra
 
715
 
 
716
    def get_bogus_url(self):
 
717
        """Return a URL which will fail to connect"""
 
718
        return 'bzr://127.0.0.1:1/'
 
719
 
 
720
 
 
721
class ReadonlySmartTCPServer_for_testing(SmartTCPServer_for_testing):
 
722
    """Get a readonly server for testing."""
 
723
 
 
724
    def get_backing_transport(self, backing_transport_server):
 
725
        """Get a backing transport from a server we are decorating."""
 
726
        url = 'readonly+' + backing_transport_server.get_url()
 
727
        return transport.get_transport_from_url(url)
 
728
 
 
729
 
 
730
class SmartTCPServer_for_testing_v2_only(SmartTCPServer_for_testing):
 
731
    """A variation of SmartTCPServer_for_testing that limits the client to
 
732
    using RPCs in protocol v2 (i.e. bzr <= 1.5).
 
733
    """
 
734
 
 
735
    def get_url(self):
 
736
        url = super(SmartTCPServer_for_testing_v2_only, self).get_url()
 
737
        url = 'bzr-v2://' + url[len('bzr://'):]
 
738
        return url
 
739
 
 
740
 
 
741
class ReadonlySmartTCPServer_for_testing_v2_only(
 
742
        SmartTCPServer_for_testing_v2_only):
 
743
    """Get a readonly server for testing."""
 
744
 
 
745
    def get_backing_transport(self, backing_transport_server):
 
746
        """Get a backing transport from a server we are decorating."""
 
747
        url = 'readonly+' + backing_transport_server.get_url()
 
748
        return transport.get_transport_from_url(url)