/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_server.py

  • Committer: Martin Pool
  • Date: 2011-08-24 09:34:35 UTC
  • mto: (6015.33.12 2.4)
  • mto: This revision was merged to the branch mainline in revision 6233.
  • Revision ID: mbp@canonical.com-20110824093435-h4tckaau084ywpcv
Correction to 'bzr serve' syntax in admin guide (thanks i41)

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2005, 2006, 2007, 2008, 2010 Canonical Ltd
 
1
# Copyright (C) 2010, 2011 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
14
14
# along with this program; if not, write to the Free Software
15
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16
16
 
 
17
import errno
 
18
import socket
 
19
import SocketServer
 
20
import sys
 
21
import threading
 
22
 
 
23
 
17
24
from bzrlib import (
 
25
    cethread,
 
26
    osutils,
18
27
    transport,
19
28
    urlutils,
20
29
    )
22
31
    chroot,
23
32
    pathfilter,
24
33
    )
25
 
from bzrlib.smart import server
 
34
from bzrlib.smart import (
 
35
    medium,
 
36
    server,
 
37
    )
 
38
 
 
39
 
 
40
def debug_threads():
 
41
    # FIXME: There is a dependency loop between bzrlib.tests and
 
42
    # bzrlib.tests.test_server that needs to be fixed. In the mean time
 
43
    # defining this function is enough for our needs. -- vila 20100611
 
44
    from bzrlib import tests
 
45
    return 'threads' in tests.selftest_debug_flags
26
46
 
27
47
 
28
48
class TestServer(transport.Server):
223
243
        raise NotImplementedError
224
244
 
225
245
 
226
 
class SmartTCPServer_for_testing(server.SmartTCPServer):
 
246
class TestThread(cethread.CatchingExceptionThread):
 
247
 
 
248
    def join(self, timeout=5):
 
249
        """Overrides to use a default timeout.
 
250
 
 
251
        The default timeout is set to 5 and should expire only when a thread
 
252
        serving a client connection is hung.
 
253
        """
 
254
        super(TestThread, self).join(timeout)
 
255
        if timeout and self.isAlive():
 
256
            # The timeout expired without joining the thread, the thread is
 
257
            # therefore stucked and that's a failure as far as the test is
 
258
            # concerned. We used to hang here.
 
259
 
 
260
            # FIXME: we need to kill the thread, but as far as the test is
 
261
            # concerned, raising an assertion is too strong. On most of the
 
262
            # platforms, this doesn't occur, so just mentioning the problem is
 
263
            # enough for now -- vila 2010824
 
264
            sys.stderr.write('thread %s hung\n' % (self.name,))
 
265
            #raise AssertionError('thread %s hung' % (self.name,))
 
266
 
 
267
 
 
268
class TestingTCPServerMixin:
 
269
    """Mixin to support running SocketServer.TCPServer in a thread.
 
270
 
 
271
    Tests are connecting from the main thread, the server has to be run in a
 
272
    separate thread.
 
273
    """
 
274
 
 
275
    def __init__(self):
 
276
        self.started = threading.Event()
 
277
        self.serving = None
 
278
        self.stopped = threading.Event()
 
279
        # We collect the resources used by the clients so we can release them
 
280
        # when shutting down
 
281
        self.clients = []
 
282
        self.ignored_exceptions = None
 
283
 
 
284
    def server_bind(self):
 
285
        self.socket.bind(self.server_address)
 
286
        self.server_address = self.socket.getsockname()
 
287
 
 
288
    def serve(self):
 
289
        self.serving = True
 
290
        # We are listening and ready to accept connections
 
291
        self.started.set()
 
292
        try:
 
293
            while self.serving:
 
294
                # Really a connection but the python framework is generic and
 
295
                # call them requests
 
296
                self.handle_request()
 
297
            # Let's close the listening socket
 
298
            self.server_close()
 
299
        finally:
 
300
            self.stopped.set()
 
301
 
 
302
    def handle_request(self):
 
303
        """Handle one request.
 
304
 
 
305
        The python version swallows some socket exceptions and we don't use
 
306
        timeout, so we override it to better control the server behavior.
 
307
        """
 
308
        request, client_address = self.get_request()
 
309
        if self.verify_request(request, client_address):
 
310
            try:
 
311
                self.process_request(request, client_address)
 
312
            except:
 
313
                self.handle_error(request, client_address)
 
314
                self.close_request(request)
 
315
 
 
316
    def get_request(self):
 
317
        return self.socket.accept()
 
318
 
 
319
    def verify_request(self, request, client_address):
 
320
        """Verify the request.
 
321
 
 
322
        Return True if we should proceed with this request, False if we should
 
323
        not even touch a single byte in the socket ! This is useful when we
 
324
        stop the server with a dummy last connection.
 
325
        """
 
326
        return self.serving
 
327
 
 
328
    def handle_error(self, request, client_address):
 
329
        # Stop serving and re-raise the last exception seen
 
330
        self.serving = False
 
331
        # The following can be used for debugging purposes, it will display the
 
332
        # exception and the traceback just when it occurs instead of waiting
 
333
        # for the thread to be joined.
 
334
 
 
335
        # SocketServer.BaseServer.handle_error(self, request, client_address)
 
336
        raise
 
337
 
 
338
    def ignored_exceptions_during_shutdown(self, e):
 
339
        if sys.platform == 'win32':
 
340
            accepted_errnos = [errno.EBADF,
 
341
                               errno.EPIPE,
 
342
                               errno.WSAEBADF,
 
343
                               errno.WSAECONNRESET,
 
344
                               errno.WSAENOTCONN,
 
345
                               errno.WSAESHUTDOWN,
 
346
                               ]
 
347
        else:
 
348
            accepted_errnos = [errno.EBADF,
 
349
                               errno.ECONNRESET,
 
350
                               errno.ENOTCONN,
 
351
                               errno.EPIPE,
 
352
                               ]
 
353
        if isinstance(e, socket.error) and e[0] in accepted_errnos:
 
354
            return True
 
355
        return False
 
356
 
 
357
    # The following methods are called by the main thread
 
358
 
 
359
    def stop_client_connections(self):
 
360
        while self.clients:
 
361
            c = self.clients.pop()
 
362
            self.shutdown_client(c)
 
363
 
 
364
    def shutdown_socket(self, sock):
 
365
        """Properly shutdown a socket.
 
366
 
 
367
        This should be called only when no other thread is trying to use the
 
368
        socket.
 
369
        """
 
370
        try:
 
371
            sock.shutdown(socket.SHUT_RDWR)
 
372
            sock.close()
 
373
        except Exception, e:
 
374
            if self.ignored_exceptions(e):
 
375
                pass
 
376
            else:
 
377
                raise
 
378
 
 
379
    # The following methods are called by the main thread
 
380
 
 
381
    def set_ignored_exceptions(self, thread, ignored_exceptions):
 
382
        self.ignored_exceptions = ignored_exceptions
 
383
        thread.set_ignored_exceptions(self.ignored_exceptions)
 
384
 
 
385
    def _pending_exception(self, thread):
 
386
        """Raise server uncaught exception.
 
387
 
 
388
        Daughter classes can override this if they use daughter threads.
 
389
        """
 
390
        thread.pending_exception()
 
391
 
 
392
 
 
393
class TestingTCPServer(TestingTCPServerMixin, SocketServer.TCPServer):
 
394
 
 
395
    def __init__(self, server_address, request_handler_class):
 
396
        TestingTCPServerMixin.__init__(self)
 
397
        SocketServer.TCPServer.__init__(self, server_address,
 
398
                                        request_handler_class)
 
399
 
 
400
    def get_request(self):
 
401
        """Get the request and client address from the socket."""
 
402
        sock, addr = TestingTCPServerMixin.get_request(self)
 
403
        self.clients.append((sock, addr))
 
404
        return sock, addr
 
405
 
 
406
    # The following methods are called by the main thread
 
407
 
 
408
    def shutdown_client(self, client):
 
409
        sock, addr = client
 
410
        self.shutdown_socket(sock)
 
411
 
 
412
 
 
413
class TestingThreadingTCPServer(TestingTCPServerMixin,
 
414
                                SocketServer.ThreadingTCPServer):
 
415
 
 
416
    def __init__(self, server_address, request_handler_class):
 
417
        TestingTCPServerMixin.__init__(self)
 
418
        SocketServer.ThreadingTCPServer.__init__(self, server_address,
 
419
                                                 request_handler_class)
 
420
 
 
421
    def get_request (self):
 
422
        """Get the request and client address from the socket."""
 
423
        sock, addr = TestingTCPServerMixin.get_request(self)
 
424
        # The thread is not create yet, it will be updated in process_request
 
425
        self.clients.append((sock, addr, None))
 
426
        return sock, addr
 
427
 
 
428
    def process_request_thread(self, started, stopped, request, client_address):
 
429
        started.set()
 
430
        SocketServer.ThreadingTCPServer.process_request_thread(
 
431
            self, request, client_address)
 
432
        self.close_request(request)
 
433
        stopped.set()
 
434
 
 
435
    def process_request(self, request, client_address):
 
436
        """Start a new thread to process the request."""
 
437
        started = threading.Event()
 
438
        stopped = threading.Event()
 
439
        t = TestThread(
 
440
            sync_event=stopped,
 
441
            name='%s -> %s' % (client_address, self.server_address),
 
442
            target = self.process_request_thread,
 
443
            args = (started, stopped, request, client_address))
 
444
        # Update the client description
 
445
        self.clients.pop()
 
446
        self.clients.append((request, client_address, t))
 
447
        # Propagate the exception handler since we must use the same one as
 
448
        # TestingTCPServer for connections running in their own threads.
 
449
        t.set_ignored_exceptions(self.ignored_exceptions)
 
450
        t.start()
 
451
        started.wait()
 
452
        if debug_threads():
 
453
            sys.stderr.write('Client thread %s started\n' % (t.name,))
 
454
        # If an exception occured during the thread start, it will get raised.
 
455
        # In rare cases, an exception raised during the request processing may
 
456
        # also get caught here (see http://pad.lv/869366)
 
457
        t.pending_exception()
 
458
 
 
459
    # The following methods are called by the main thread
 
460
 
 
461
    def shutdown_client(self, client):
 
462
        sock, addr, connection_thread = client
 
463
        self.shutdown_socket(sock)
 
464
        if connection_thread is not None:
 
465
            # The thread has been created only if the request is processed but
 
466
            # after the connection is inited. This could happen during server
 
467
            # shutdown. If an exception occurred in the thread it will be
 
468
            # re-raised
 
469
            if debug_threads():
 
470
                sys.stderr.write('Client thread %s will be joined\n'
 
471
                                 % (connection_thread.name,))
 
472
            connection_thread.join()
 
473
 
 
474
    def set_ignored_exceptions(self, thread, ignored_exceptions):
 
475
        TestingTCPServerMixin.set_ignored_exceptions(self, thread,
 
476
                                                     ignored_exceptions)
 
477
        for sock, addr, connection_thread in self.clients:
 
478
            if connection_thread is not None:
 
479
                connection_thread.set_ignored_exceptions(
 
480
                    self.ignored_exceptions)
 
481
 
 
482
    def _pending_exception(self, thread):
 
483
        for sock, addr, connection_thread in self.clients:
 
484
            if connection_thread is not None:
 
485
                connection_thread.pending_exception()
 
486
        TestingTCPServerMixin._pending_exception(self, thread)
 
487
 
 
488
 
 
489
class TestingTCPServerInAThread(transport.Server):
 
490
    """A server in a thread that re-raise thread exceptions."""
 
491
 
 
492
    def __init__(self, server_address, server_class, request_handler_class):
 
493
        self.server_class = server_class
 
494
        self.request_handler_class = request_handler_class
 
495
        self.host, self.port = server_address
 
496
        self.server = None
 
497
        self._server_thread = None
 
498
 
 
499
    def __repr__(self):
 
500
        return "%s(%s:%s)" % (self.__class__.__name__, self.host, self.port)
 
501
 
 
502
    def create_server(self):
 
503
        return self.server_class((self.host, self.port),
 
504
                                 self.request_handler_class)
 
505
 
 
506
    def start_server(self):
 
507
        self.server = self.create_server()
 
508
        self._server_thread = TestThread(
 
509
            sync_event=self.server.started,
 
510
            target=self.run_server)
 
511
        self._server_thread.start()
 
512
        # Wait for the server thread to start (i.e. release the lock)
 
513
        self.server.started.wait()
 
514
        # Get the real address, especially the port
 
515
        self.host, self.port = self.server.server_address
 
516
        self._server_thread.name = self.server.server_address
 
517
        if debug_threads():
 
518
            sys.stderr.write('Server thread %s started\n'
 
519
                             % (self._server_thread.name,))
 
520
        # If an exception occured during the server start, it will get raised,
 
521
        # otherwise, the server is blocked on its accept() call.
 
522
        self._server_thread.pending_exception()
 
523
        # From now on, we'll use a different event to ensure the server can set
 
524
        # its exception
 
525
        self._server_thread.set_sync_event(self.server.stopped)
 
526
 
 
527
    def run_server(self):
 
528
        self.server.serve()
 
529
 
 
530
    def stop_server(self):
 
531
        if self.server is None:
 
532
            return
 
533
        try:
 
534
            # The server has been started successfully, shut it down now.  As
 
535
            # soon as we stop serving, no more connection are accepted except
 
536
            # one to get out of the blocking listen.
 
537
            self.set_ignored_exceptions(
 
538
                self.server.ignored_exceptions_during_shutdown)
 
539
            self.server.serving = False
 
540
            if debug_threads():
 
541
                sys.stderr.write('Server thread %s will be joined\n'
 
542
                                 % (self._server_thread.name,))
 
543
            # The server is listening for a last connection, let's give it:
 
544
            last_conn = None
 
545
            try:
 
546
                last_conn = osutils.connect_socket((self.host, self.port))
 
547
            except socket.error, e:
 
548
                # But ignore connection errors as the point is to unblock the
 
549
                # server thread, it may happen that it's not blocked or even
 
550
                # not started.
 
551
                pass
 
552
            # We start shutting down the clients while the server itself is
 
553
            # shutting down.
 
554
            self.server.stop_client_connections()
 
555
            # Now we wait for the thread running self.server.serve() to finish
 
556
            self.server.stopped.wait()
 
557
            if last_conn is not None:
 
558
                # Close the last connection without trying to use it. The
 
559
                # server will not process a single byte on that socket to avoid
 
560
                # complications (SSL starts with a handshake for example).
 
561
                last_conn.close()
 
562
            # Check for any exception that could have occurred in the server
 
563
            # thread
 
564
            try:
 
565
                self._server_thread.join()
 
566
            except Exception, e:
 
567
                if self.server.ignored_exceptions(e):
 
568
                    pass
 
569
                else:
 
570
                    raise
 
571
        finally:
 
572
            # Make sure we can be called twice safely, note that this means
 
573
            # that we will raise a single exception even if several occurred in
 
574
            # the various threads involved.
 
575
            self.server = None
 
576
 
 
577
    def set_ignored_exceptions(self, ignored_exceptions):
 
578
        """Install an exception handler for the server."""
 
579
        self.server.set_ignored_exceptions(self._server_thread,
 
580
                                           ignored_exceptions)
 
581
 
 
582
    def pending_exception(self):
 
583
        """Raise uncaught exception in the server."""
 
584
        self.server._pending_exception(self._server_thread)
 
585
 
 
586
 
 
587
class TestingSmartConnectionHandler(SocketServer.BaseRequestHandler,
 
588
                                    medium.SmartServerSocketStreamMedium):
 
589
 
 
590
    def __init__(self, request, client_address, server):
 
591
        medium.SmartServerSocketStreamMedium.__init__(
 
592
            self, request, server.backing_transport,
 
593
            server.root_client_path)
 
594
        request.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
 
595
        SocketServer.BaseRequestHandler.__init__(self, request, client_address,
 
596
                                                 server)
 
597
 
 
598
    def handle(self):
 
599
        while not self.finished:
 
600
            server_protocol = self._build_protocol()
 
601
            self._serve_one_request(server_protocol)
 
602
 
 
603
 
 
604
class TestingSmartServer(TestingThreadingTCPServer, server.SmartTCPServer):
 
605
 
 
606
    def __init__(self, server_address, request_handler_class,
 
607
                 backing_transport, root_client_path):
 
608
        TestingThreadingTCPServer.__init__(self, server_address,
 
609
                                           request_handler_class)
 
610
        server.SmartTCPServer.__init__(self, backing_transport,
 
611
                                       root_client_path)
 
612
    def serve(self):
 
613
        self.run_server_started_hooks()
 
614
        try:
 
615
            TestingThreadingTCPServer.serve(self)
 
616
        finally:
 
617
            self.run_server_stopped_hooks()
 
618
 
 
619
    def get_url(self):
 
620
        """Return the url of the server"""
 
621
        return "bzr://%s:%d/" % self.server_address
 
622
 
 
623
 
 
624
class SmartTCPServer_for_testing(TestingTCPServerInAThread):
227
625
    """Server suitable for use by transport tests.
228
626
 
229
627
    This server is backed by the process's cwd.
230
628
    """
231
 
 
232
629
    def __init__(self, thread_name_suffix=''):
233
 
        super(SmartTCPServer_for_testing, self).__init__(None)
234
630
        self.client_path_extra = None
235
631
        self.thread_name_suffix = thread_name_suffix
236
 
 
237
 
    def get_backing_transport(self, backing_transport_server):
238
 
        """Get a backing transport from a server we are decorating."""
239
 
        return transport.get_transport(backing_transport_server.get_url())
 
632
        self.host = '127.0.0.1'
 
633
        self.port = 0
 
634
        super(SmartTCPServer_for_testing, self).__init__(
 
635
                (self.host, self.port),
 
636
                TestingSmartServer,
 
637
                TestingSmartConnectionHandler)
 
638
 
 
639
    def create_server(self):
 
640
        return self.server_class((self.host, self.port),
 
641
                                 self.request_handler_class,
 
642
                                 self.backing_transport,
 
643
                                 self.root_client_path)
 
644
 
240
645
 
241
646
    def start_server(self, backing_transport_server=None,
242
 
              client_path_extra='/extra/'):
 
647
                     client_path_extra='/extra/'):
243
648
        """Set up server for testing.
244
649
 
245
650
        :param backing_transport_server: backing server to use.  If not
254
659
        """
255
660
        if not client_path_extra.startswith('/'):
256
661
            raise ValueError(client_path_extra)
 
662
        self.root_client_path = self.client_path_extra = client_path_extra
257
663
        from bzrlib.transport.chroot import ChrootServer
258
664
        if backing_transport_server is None:
259
665
            backing_transport_server = LocalURLServer()
262
668
        self.chroot_server.start_server()
263
669
        self.backing_transport = transport.get_transport(
264
670
            self.chroot_server.get_url())
265
 
        self.root_client_path = self.client_path_extra = client_path_extra
266
 
        self.start_background_thread(self.thread_name_suffix)
 
671
        super(SmartTCPServer_for_testing, self).start_server()
267
672
 
268
673
    def stop_server(self):
269
 
        self.stop_background_thread()
270
 
        self.chroot_server.stop_server()
 
674
        try:
 
675
            super(SmartTCPServer_for_testing, self).stop_server()
 
676
        finally:
 
677
            self.chroot_server.stop_server()
 
678
 
 
679
    def get_backing_transport(self, backing_transport_server):
 
680
        """Get a backing transport from a server we are decorating."""
 
681
        return transport.get_transport(backing_transport_server.get_url())
271
682
 
272
683
    def get_url(self):
273
 
        url = super(SmartTCPServer_for_testing, self).get_url()
 
684
        url = self.server.get_url()
274
685
        return url[:-1] + self.client_path_extra
275
686
 
276
687
    def get_bogus_url(self):
306
717
        """Get a backing transport from a server we are decorating."""
307
718
        url = 'readonly+' + backing_transport_server.get_url()
308
719
        return transport.get_transport(url)
309
 
 
310
 
 
311
 
 
312