/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_smart_transport.py

  • Committer: Vincent Ladeuil
  • Date: 2011-11-24 15:48:29 UTC
  • mfrom: (6289 +trunk)
  • mto: This revision was merged to the branch mainline in revision 6337.
  • Revision ID: v.ladeuil+lp@free.fr-20111124154829-avowjpsxdl8yp2vz
merge trunk resolving conflicts

Show diffs side-by-side

added added

removed removed

Lines of Context:
18
18
 
19
19
# all of this deals with byte strings so this is safe
20
20
from cStringIO import StringIO
 
21
import doctest
21
22
import os
22
23
import socket
 
24
import sys
23
25
import threading
 
26
import time
 
27
 
 
28
from testtools.matchers import DocTestMatches
24
29
 
25
30
import bzrlib
26
31
from bzrlib import (
28
33
        errors,
29
34
        osutils,
30
35
        tests,
31
 
        transport,
 
36
        transport as _mod_transport,
32
37
        urlutils,
33
38
        )
34
39
from bzrlib.smart import (
37
42
        message,
38
43
        protocol,
39
44
        request as _mod_request,
40
 
        server,
 
45
        server as _mod_server,
41
46
        vfs,
42
47
)
43
48
from bzrlib.tests import (
 
49
    features,
44
50
    test_smart,
45
51
    test_server,
46
52
    )
53
59
        )
54
60
 
55
61
 
 
62
def portable_socket_pair():
 
63
    """Return a pair of TCP sockets connected to each other.
 
64
 
 
65
    Unlike socket.socketpair, this should work on Windows.
 
66
    """
 
67
    listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
68
    listen_sock.bind(('127.0.0.1', 0))
 
69
    listen_sock.listen(1)
 
70
    client_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
71
    client_sock.connect(listen_sock.getsockname())
 
72
    server_sock, addr = listen_sock.accept()
 
73
    listen_sock.close()
 
74
    return server_sock, client_sock
 
75
 
 
76
 
56
77
class StringIOSSHVendor(object):
57
78
    """A SSH vendor that uses StringIO to buffer writes and answer reads."""
58
79
 
82
103
        return 'pipes', (self.vendor.read_from, self.vendor.write_to)
83
104
 
84
105
 
85
 
class _InvalidHostnameFeature(tests.Feature):
 
106
class _InvalidHostnameFeature(features.Feature):
86
107
    """Does 'non_existent.invalid' fail to resolve?
87
108
 
88
109
    RFC 2606 states that .invalid is reserved for invalid domain names, and
338
359
            ],
339
360
            vendor.calls)
340
361
 
 
362
    def test_ssh_client_repr(self):
 
363
        client_medium = medium.SmartSSHClientMedium(
 
364
            'base', medium.SSHParams("example.com", "4242", "username"))
 
365
        self.assertEquals(
 
366
            "SmartSSHClientMedium(bzr+ssh://username@example.com:4242/)",
 
367
            repr(client_medium))
 
368
 
 
369
    def test_ssh_client_repr_no_port(self):
 
370
        client_medium = medium.SmartSSHClientMedium(
 
371
            'base', medium.SSHParams("example.com", None, "username"))
 
372
        self.assertEquals(
 
373
            "SmartSSHClientMedium(bzr+ssh://username@example.com/)",
 
374
            repr(client_medium))
 
375
 
 
376
    def test_ssh_client_repr_no_username(self):
 
377
        client_medium = medium.SmartSSHClientMedium(
 
378
            'base', medium.SSHParams("example.com", None, None))
 
379
        self.assertEquals(
 
380
            "SmartSSHClientMedium(bzr+ssh://example.com/)",
 
381
            repr(client_medium))
 
382
 
341
383
    def test_ssh_client_ignores_disconnect_when_not_connected(self):
342
384
        # Doing a disconnect on a new (and thus unconnected) SSH medium
343
385
        # does not fail.  It's ok to disconnect an unconnected medium.
617
659
        super(TestSmartServerStreamMedium, self).setUp()
618
660
        self.overrideEnv('BZR_NO_SMART_VFS', None)
619
661
 
620
 
    def portable_socket_pair(self):
621
 
        """Return a pair of TCP sockets connected to each other.
622
 
 
623
 
        Unlike socket.socketpair, this should work on Windows.
624
 
        """
625
 
        listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
626
 
        listen_sock.bind(('127.0.0.1', 0))
627
 
        listen_sock.listen(1)
628
 
        client_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
629
 
        client_sock.connect(listen_sock.getsockname())
630
 
        server_sock, addr = listen_sock.accept()
631
 
        listen_sock.close()
632
 
        return server_sock, client_sock
 
662
    def create_pipe_medium(self, to_server, from_server, transport,
 
663
                           timeout=4.0):
 
664
        """Create a new SmartServerPipeStreamMedium."""
 
665
        return medium.SmartServerPipeStreamMedium(to_server, from_server,
 
666
            transport, timeout=timeout)
 
667
 
 
668
    def create_pipe_context(self, to_server_bytes, transport):
 
669
        """Create a SmartServerSocketStreamMedium.
 
670
 
 
671
        This differes from create_pipe_medium, in that we initialize the
 
672
        request that is sent to the server, and return the StringIO class that
 
673
        will hold the response.
 
674
        """
 
675
        to_server = StringIO(to_server_bytes)
 
676
        from_server = StringIO()
 
677
        m = self.create_pipe_medium(to_server, from_server, transport)
 
678
        return m, from_server
 
679
 
 
680
    def create_socket_medium(self, server_sock, transport, timeout=4.0):
 
681
        """Initialize a new medium.SmartServerSocketStreamMedium."""
 
682
        return medium.SmartServerSocketStreamMedium(server_sock, transport,
 
683
            timeout=timeout)
 
684
 
 
685
    def create_socket_context(self, transport, timeout=4.0):
 
686
        """Create a new SmartServerSocketStreamMedium with default context.
 
687
 
 
688
        This will call portable_socket_pair and pass the server side to
 
689
        create_socket_medium along with transport.
 
690
        It then returns the client_sock and the server.
 
691
        """
 
692
        server_sock, client_sock = portable_socket_pair()
 
693
        server = self.create_socket_medium(server_sock, transport,
 
694
                                           timeout=timeout)
 
695
        return server, client_sock
633
696
 
634
697
    def test_smart_query_version(self):
635
698
        """Feed a canned query version to a server"""
636
699
        # wire-to-wire, using the whole stack
637
 
        to_server = StringIO('hello\n')
638
 
        from_server = StringIO()
639
700
        transport = local.LocalTransport(urlutils.local_path_to_url('/'))
640
 
        server = medium.SmartServerPipeStreamMedium(
641
 
            to_server, from_server, transport)
 
701
        server, from_server = self.create_pipe_context('hello\n', transport)
642
702
        smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
643
703
                from_server.write)
644
704
        server._serve_one_request(smart_protocol)
648
708
    def test_response_to_canned_get(self):
649
709
        transport = memory.MemoryTransport('memory:///')
650
710
        transport.put_bytes('testfile', 'contents\nof\nfile\n')
651
 
        to_server = StringIO('get\001./testfile\n')
652
 
        from_server = StringIO()
653
 
        server = medium.SmartServerPipeStreamMedium(
654
 
            to_server, from_server, transport)
 
711
        server, from_server = self.create_pipe_context('get\001./testfile\n',
 
712
            transport)
655
713
        smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
656
714
                from_server.write)
657
715
        server._serve_one_request(smart_protocol)
668
726
        # VFS requests use filenames, not raw UTF-8.
669
727
        hpss_path = urlutils.escape(utf8_filename)
670
728
        transport.put_bytes(utf8_filename, 'contents\nof\nfile\n')
671
 
        to_server = StringIO('get\001' + hpss_path + '\n')
672
 
        from_server = StringIO()
673
 
        server = medium.SmartServerPipeStreamMedium(
674
 
            to_server, from_server, transport)
 
729
        server, from_server = self.create_pipe_context(
 
730
                'get\001' + hpss_path + '\n', transport)
675
731
        smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
676
732
                from_server.write)
677
733
        server._serve_one_request(smart_protocol)
683
739
 
684
740
    def test_pipe_like_stream_with_bulk_data(self):
685
741
        sample_request_bytes = 'command\n9\nbulk datadone\n'
686
 
        to_server = StringIO(sample_request_bytes)
687
 
        from_server = StringIO()
688
 
        server = medium.SmartServerPipeStreamMedium(
689
 
            to_server, from_server, None)
 
742
        server, from_server = self.create_pipe_context(
 
743
            sample_request_bytes, None)
690
744
        sample_protocol = SampleRequest(expected_bytes=sample_request_bytes)
691
745
        server._serve_one_request(sample_protocol)
692
746
        self.assertEqual('', from_server.getvalue())
695
749
 
696
750
    def test_socket_stream_with_bulk_data(self):
697
751
        sample_request_bytes = 'command\n9\nbulk datadone\n'
698
 
        server_sock, client_sock = self.portable_socket_pair()
699
 
        server = medium.SmartServerSocketStreamMedium(
700
 
            server_sock, None)
 
752
        server, client_sock = self.create_socket_context(None)
701
753
        sample_protocol = SampleRequest(expected_bytes=sample_request_bytes)
702
754
        client_sock.sendall(sample_request_bytes)
703
755
        server._serve_one_request(sample_protocol)
704
 
        server_sock.close()
 
756
        server._disconnect_client()
705
757
        self.assertEqual('', client_sock.recv(1))
706
758
        self.assertEqual(sample_request_bytes, sample_protocol.accepted_bytes)
707
759
        self.assertFalse(server.finished)
708
760
 
709
761
    def test_pipe_like_stream_shutdown_detection(self):
710
 
        to_server = StringIO('')
711
 
        from_server = StringIO()
712
 
        server = medium.SmartServerPipeStreamMedium(to_server, from_server, None)
 
762
        server, _ = self.create_pipe_context('', None)
713
763
        server._serve_one_request(SampleRequest('x'))
714
764
        self.assertTrue(server.finished)
715
765
 
716
766
    def test_socket_stream_shutdown_detection(self):
717
 
        server_sock, client_sock = self.portable_socket_pair()
 
767
        server, client_sock = self.create_socket_context(None)
718
768
        client_sock.close()
719
 
        server = medium.SmartServerSocketStreamMedium(
720
 
            server_sock, None)
721
769
        server._serve_one_request(SampleRequest('x'))
722
770
        self.assertTrue(server.finished)
723
771
 
734
782
        rest_of_request_bytes = 'lo\n'
735
783
        expected_response = (
736
784
            protocol.RESPONSE_VERSION_TWO + 'success\nok\x012\n')
737
 
        server_sock, client_sock = self.portable_socket_pair()
738
 
        server = medium.SmartServerSocketStreamMedium(
739
 
            server_sock, None)
 
785
        server, client_sock = self.create_socket_context(None)
740
786
        client_sock.sendall(incomplete_request_bytes)
741
787
        server_protocol = server._build_protocol()
742
788
        client_sock.sendall(rest_of_request_bytes)
743
789
        server._serve_one_request(server_protocol)
744
 
        server_sock.close()
 
790
        server._disconnect_client()
745
791
        self.assertEqual(expected_response, osutils.recv_all(client_sock, 50),
746
792
                         "Not a version 2 response to 'hello' request.")
747
793
        self.assertEqual('', client_sock.recv(1))
766
812
        to_server_w = os.fdopen(to_server_w, 'w', 0)
767
813
        from_server_r = os.fdopen(from_server_r, 'r', 0)
768
814
        from_server = os.fdopen(from_server, 'w', 0)
769
 
        server = medium.SmartServerPipeStreamMedium(
770
 
            to_server, from_server, None)
 
815
        server = self.create_pipe_medium(to_server, from_server, None)
771
816
        # Like test_socket_stream_incomplete_request, write an incomplete
772
817
        # request (that does not end in '\n') and build a protocol from it.
773
818
        to_server_w.write(incomplete_request_bytes)
788
833
        # _serve_one_request should still process both of them as if they had
789
834
        # been received separately.
790
835
        sample_request_bytes = 'command\n'
791
 
        to_server = StringIO(sample_request_bytes * 2)
792
 
        from_server = StringIO()
793
 
        server = medium.SmartServerPipeStreamMedium(
794
 
            to_server, from_server, None)
 
836
        server, from_server = self.create_pipe_context(
 
837
            sample_request_bytes * 2, None)
795
838
        first_protocol = SampleRequest(expected_bytes=sample_request_bytes)
796
839
        server._serve_one_request(first_protocol)
797
840
        self.assertEqual(0, first_protocol.next_read_size())
810
853
        # _serve_one_request should still process both of them as if they had
811
854
        # been received separately.
812
855
        sample_request_bytes = 'command\n'
813
 
        server_sock, client_sock = self.portable_socket_pair()
814
 
        server = medium.SmartServerSocketStreamMedium(
815
 
            server_sock, None)
 
856
        server, client_sock = self.create_socket_context(None)
816
857
        first_protocol = SampleRequest(expected_bytes=sample_request_bytes)
817
858
        # Put two whole requests on the wire.
818
859
        client_sock.sendall(sample_request_bytes * 2)
825
866
        stream_still_open = server._serve_one_request(second_protocol)
826
867
        self.assertEqual(sample_request_bytes, second_protocol.accepted_bytes)
827
868
        self.assertFalse(server.finished)
828
 
        server_sock.close()
 
869
        server._disconnect_client()
829
870
        self.assertEqual('', client_sock.recv(1))
830
871
 
831
872
    def test_pipe_like_stream_error_handling(self):
838
879
        def close():
839
880
            self.closed = True
840
881
        from_server.close = close
841
 
        server = medium.SmartServerPipeStreamMedium(
 
882
        server = self.create_pipe_medium(
842
883
            to_server, from_server, None)
843
884
        fake_protocol = ErrorRaisingProtocol(Exception('boom'))
844
885
        server._serve_one_request(fake_protocol)
847
888
        self.assertTrue(server.finished)
848
889
 
849
890
    def test_socket_stream_error_handling(self):
850
 
        server_sock, client_sock = self.portable_socket_pair()
851
 
        server = medium.SmartServerSocketStreamMedium(
852
 
            server_sock, None)
 
891
        server, client_sock = self.create_socket_context(None)
853
892
        fake_protocol = ErrorRaisingProtocol(Exception('boom'))
854
893
        server._serve_one_request(fake_protocol)
855
894
        # recv should not block, because the other end of the socket has been
858
897
        self.assertTrue(server.finished)
859
898
 
860
899
    def test_pipe_like_stream_keyboard_interrupt_handling(self):
861
 
        to_server = StringIO('')
862
 
        from_server = StringIO()
863
 
        server = medium.SmartServerPipeStreamMedium(
864
 
            to_server, from_server, None)
 
900
        server, from_server = self.create_pipe_context('', None)
865
901
        fake_protocol = ErrorRaisingProtocol(KeyboardInterrupt('boom'))
866
902
        self.assertRaises(
867
903
            KeyboardInterrupt, server._serve_one_request, fake_protocol)
868
904
        self.assertEqual('', from_server.getvalue())
869
905
 
870
906
    def test_socket_stream_keyboard_interrupt_handling(self):
871
 
        server_sock, client_sock = self.portable_socket_pair()
872
 
        server = medium.SmartServerSocketStreamMedium(
873
 
            server_sock, None)
 
907
        server, client_sock = self.create_socket_context(None)
874
908
        fake_protocol = ErrorRaisingProtocol(KeyboardInterrupt('boom'))
875
909
        self.assertRaises(
876
910
            KeyboardInterrupt, server._serve_one_request, fake_protocol)
877
 
        server_sock.close()
 
911
        server._disconnect_client()
878
912
        self.assertEqual('', client_sock.recv(1))
879
913
 
880
914
    def build_protocol_pipe_like(self, bytes):
881
 
        to_server = StringIO(bytes)
882
 
        from_server = StringIO()
883
 
        server = medium.SmartServerPipeStreamMedium(
884
 
            to_server, from_server, None)
 
915
        server, _ = self.create_pipe_context(bytes, None)
885
916
        return server._build_protocol()
886
917
 
887
918
    def build_protocol_socket(self, bytes):
888
 
        server_sock, client_sock = self.portable_socket_pair()
889
 
        server = medium.SmartServerSocketStreamMedium(
890
 
            server_sock, None)
 
919
        server, client_sock = self.create_socket_context(None)
891
920
        client_sock.sendall(bytes)
892
921
        client_sock.close()
893
922
        return server._build_protocol()
933
962
        server_protocol = self.build_protocol_socket('bzr request 2\n')
934
963
        self.assertProtocolTwo(server_protocol)
935
964
 
 
965
    def test__build_protocol_returns_if_stopping(self):
 
966
        # _build_protocol should notice that we are stopping, and return
 
967
        # without waiting for bytes from the client.
 
968
        server, client_sock = self.create_socket_context(None)
 
969
        server._stop_gracefully()
 
970
        self.assertIs(None, server._build_protocol())
 
971
 
 
972
    def test_socket_set_timeout(self):
 
973
        server, _ = self.create_socket_context(None, timeout=1.23)
 
974
        self.assertEqual(1.23, server._client_timeout)
 
975
 
 
976
    def test_pipe_set_timeout(self):
 
977
        server = self.create_pipe_medium(None, None, None,
 
978
            timeout=1.23)
 
979
        self.assertEqual(1.23, server._client_timeout)
 
980
 
 
981
    def test_socket_wait_for_bytes_with_timeout_with_data(self):
 
982
        server, client_sock = self.create_socket_context(None)
 
983
        client_sock.sendall('data\n')
 
984
        # This should not block or consume any actual content
 
985
        self.assertFalse(server._wait_for_bytes_with_timeout(0.1))
 
986
        data = server.read_bytes(5)
 
987
        self.assertEqual('data\n', data)
 
988
 
 
989
    def test_socket_wait_for_bytes_with_timeout_no_data(self):
 
990
        server, client_sock = self.create_socket_context(None)
 
991
        # This should timeout quickly, reporting that there wasn't any data
 
992
        self.assertRaises(errors.ConnectionTimeout,
 
993
                          server._wait_for_bytes_with_timeout, 0.01)
 
994
        client_sock.close()
 
995
        data = server.read_bytes(1)
 
996
        self.assertEqual('', data)
 
997
 
 
998
    def test_socket_wait_for_bytes_with_timeout_closed(self):
 
999
        server, client_sock = self.create_socket_context(None)
 
1000
        # With the socket closed, this should return right away.
 
1001
        # It seems select.select() returns that you *can* read on the socket,
 
1002
        # even though it closed. Presumably as a way to tell it is closed?
 
1003
        # Testing shows that without sock.close() this times-out failing the
 
1004
        # test, but with it, it returns False immediately.
 
1005
        client_sock.close()
 
1006
        self.assertFalse(server._wait_for_bytes_with_timeout(10))
 
1007
        data = server.read_bytes(1)
 
1008
        self.assertEqual('', data)
 
1009
 
 
1010
    def test_socket_wait_for_bytes_with_shutdown(self):
 
1011
        server, client_sock = self.create_socket_context(None)
 
1012
        t = time.time()
 
1013
        # Override the _timer functionality, so that time never increments,
 
1014
        # this way, we can be sure we stopped because of the flag, and not
 
1015
        # because of a timeout, etc.
 
1016
        server._timer = lambda: t
 
1017
        server._client_poll_timeout = 0.1
 
1018
        server._stop_gracefully()
 
1019
        server._wait_for_bytes_with_timeout(1.0)
 
1020
 
 
1021
    def test_socket_serve_timeout_closes_socket(self):
 
1022
        server, client_sock = self.create_socket_context(None, timeout=0.1)
 
1023
        # This should timeout quickly, and then close the connection so that
 
1024
        # client_sock recv doesn't block.
 
1025
        server.serve()
 
1026
        self.assertEqual('', client_sock.recv(1))
 
1027
 
 
1028
    def test_pipe_wait_for_bytes_with_timeout_with_data(self):
 
1029
        # We intentionally use a real pipe here, so that we can 'select' on it.
 
1030
        # You can't select() on a StringIO
 
1031
        (r_server, w_client) = os.pipe()
 
1032
        self.addCleanup(os.close, w_client)
 
1033
        with os.fdopen(r_server, 'rb') as rf_server:
 
1034
            server = self.create_pipe_medium(
 
1035
                rf_server, None, None)
 
1036
            os.write(w_client, 'data\n')
 
1037
            # This should not block or consume any actual content
 
1038
            server._wait_for_bytes_with_timeout(0.1)
 
1039
            data = server.read_bytes(5)
 
1040
            self.assertEqual('data\n', data)
 
1041
 
 
1042
    def test_pipe_wait_for_bytes_with_timeout_no_data(self):
 
1043
        # We intentionally use a real pipe here, so that we can 'select' on it.
 
1044
        # You can't select() on a StringIO
 
1045
        (r_server, w_client) = os.pipe()
 
1046
        # We can't add an os.close cleanup here, because we need to control
 
1047
        # when the file handle gets closed ourselves.
 
1048
        with os.fdopen(r_server, 'rb') as rf_server:
 
1049
            server = self.create_pipe_medium(
 
1050
                rf_server, None, None)
 
1051
            if sys.platform == 'win32':
 
1052
                # Windows cannot select() on a pipe, so we just always return
 
1053
                server._wait_for_bytes_with_timeout(0.01)
 
1054
            else:
 
1055
                self.assertRaises(errors.ConnectionTimeout,
 
1056
                                  server._wait_for_bytes_with_timeout, 0.01)
 
1057
            os.close(w_client)
 
1058
            data = server.read_bytes(5)
 
1059
            self.assertEqual('', data)
 
1060
 
 
1061
    def test_pipe_wait_for_bytes_no_fileno(self):
 
1062
        server, _ = self.create_pipe_context('', None)
 
1063
        # Our file doesn't support polling, so we should always just return
 
1064
        # 'you have data to consume.
 
1065
        server._wait_for_bytes_with_timeout(0.01)
 
1066
 
936
1067
 
937
1068
class TestGetProtocolFactoryForBytes(tests.TestCase):
938
1069
    """_get_protocol_factory_for_bytes identifies the protocol factory a server
968
1099
 
969
1100
class TestSmartTCPServer(tests.TestCase):
970
1101
 
 
1102
    def make_server(self):
 
1103
        """Create a SmartTCPServer that we can exercise.
 
1104
 
 
1105
        Note: we don't use SmartTCPServer_for_testing because the testing
 
1106
        version overrides lots of functionality like 'serve', and we want to
 
1107
        test the raw service.
 
1108
 
 
1109
        This will start the server in another thread, and wait for it to
 
1110
        indicate it has finished starting up.
 
1111
 
 
1112
        :return: (server, server_thread)
 
1113
        """
 
1114
        t = _mod_transport.get_transport_from_url('memory:///')
 
1115
        server = _mod_server.SmartTCPServer(t, client_timeout=4.0)
 
1116
        server._ACCEPT_TIMEOUT = 0.1
 
1117
        # We don't use 'localhost' because that might be an IPv6 address.
 
1118
        server.start_server('127.0.0.1', 0)
 
1119
        server_thread = threading.Thread(target=server.serve,
 
1120
                                         args=(self.id(),))
 
1121
        server_thread.start()
 
1122
        # Ensure this gets called at some point
 
1123
        self.addCleanup(server._stop_gracefully)
 
1124
        server._started.wait()
 
1125
        return server, server_thread
 
1126
 
 
1127
    def ensure_client_disconnected(self, client_sock):
 
1128
        """Ensure that a socket is closed, discarding all errors."""
 
1129
        try:
 
1130
            client_sock.close()
 
1131
        except Exception:
 
1132
            pass
 
1133
 
 
1134
    def connect_to_server(self, server):
 
1135
        """Create a client socket that can talk to the server."""
 
1136
        client_sock = socket.socket()
 
1137
        server_info = server._server_socket.getsockname()
 
1138
        client_sock.connect(server_info)
 
1139
        self.addCleanup(self.ensure_client_disconnected, client_sock)
 
1140
        return client_sock
 
1141
 
 
1142
    def connect_to_server_and_hangup(self, server):
 
1143
        """Connect to the server, and then hang up.
 
1144
        That way it doesn't sit waiting for 'accept()' to timeout.
 
1145
        """
 
1146
        # If the server has already signaled that the socket is closed, we
 
1147
        # don't need to try to connect to it. Not being set, though, the server
 
1148
        # might still close the socket while we try to connect to it. So we
 
1149
        # still have to catch the exception.
 
1150
        if server._stopped.isSet():
 
1151
            return
 
1152
        try:
 
1153
            client_sock = self.connect_to_server(server)
 
1154
            client_sock.close()
 
1155
        except socket.error, e:
 
1156
            # If the server has hung up already, that is fine.
 
1157
            pass
 
1158
 
 
1159
    def say_hello(self, client_sock):
 
1160
        """Send the 'hello' smart RPC, and expect the response."""
 
1161
        client_sock.send('hello\n')
 
1162
        self.assertEqual('ok\x012\n', client_sock.recv(5))
 
1163
 
 
1164
    def shutdown_server_cleanly(self, server, server_thread):
 
1165
        server._stop_gracefully()
 
1166
        self.connect_to_server_and_hangup(server)
 
1167
        server._stopped.wait()
 
1168
        server._fully_stopped.wait()
 
1169
        server_thread.join()
 
1170
 
971
1171
    def test_get_error_unexpected(self):
972
1172
        """Error reported by server with no specific representation"""
973
1173
        self.overrideEnv('BZR_NO_SMART_VFS', None)
991
1191
                                t.get, 'something')
992
1192
        self.assertContainsRe(str(err), 'some random exception')
993
1193
 
 
1194
    def test_propagates_timeout(self):
 
1195
        server = _mod_server.SmartTCPServer(None, client_timeout=1.23)
 
1196
        server_sock, client_sock = portable_socket_pair()
 
1197
        handler = server._make_handler(server_sock)
 
1198
        self.assertEqual(1.23, handler._client_timeout)
 
1199
 
 
1200
    def test_serve_conn_tracks_connections(self):
 
1201
        server = _mod_server.SmartTCPServer(None, client_timeout=4.0)
 
1202
        server_sock, client_sock = portable_socket_pair()
 
1203
        server.serve_conn(server_sock, '-%s' % (self.id(),))
 
1204
        self.assertEqual(1, len(server._active_connections))
 
1205
        # We still want to talk on the connection. Polling should indicate it
 
1206
        # is still active.
 
1207
        server._poll_active_connections()
 
1208
        self.assertEqual(1, len(server._active_connections))
 
1209
        # Closing the socket will end the active thread, and polling will
 
1210
        # notice and remove it from the active set.
 
1211
        client_sock.close()
 
1212
        server._poll_active_connections(0.1)
 
1213
        self.assertEqual(0, len(server._active_connections))
 
1214
 
 
1215
    def test_serve_closes_out_finished_connections(self):
 
1216
        server, server_thread = self.make_server()
 
1217
        # The server is started, connect to it.
 
1218
        client_sock = self.connect_to_server(server)
 
1219
        # We send and receive on the connection, so that we know the
 
1220
        # server-side has seen the connect, and started handling the
 
1221
        # results.
 
1222
        self.say_hello(client_sock)
 
1223
        self.assertEqual(1, len(server._active_connections))
 
1224
        # Grab a handle to the thread that is processing our request
 
1225
        _, server_side_thread = server._active_connections[0]
 
1226
        # Close the connection, ask the server to stop, and wait for the
 
1227
        # server to stop, as well as the thread that was servicing the
 
1228
        # client request.
 
1229
        client_sock.close()
 
1230
        # Wait for the server-side request thread to notice we are closed.
 
1231
        server_side_thread.join()
 
1232
        # Stop the server, it should notice the connection has finished.
 
1233
        self.shutdown_server_cleanly(server, server_thread)
 
1234
        # The server should have noticed that all clients are gone before
 
1235
        # exiting.
 
1236
        self.assertEqual(0, len(server._active_connections))
 
1237
 
 
1238
    def test_serve_reaps_finished_connections(self):
 
1239
        server, server_thread = self.make_server()
 
1240
        client_sock1 = self.connect_to_server(server)
 
1241
        # We send and receive on the connection, so that we know the
 
1242
        # server-side has seen the connect, and started handling the
 
1243
        # results.
 
1244
        self.say_hello(client_sock1)
 
1245
        server_handler1, server_side_thread1 = server._active_connections[0]
 
1246
        client_sock1.close()
 
1247
        server_side_thread1.join()
 
1248
        # By waiting until the first connection is fully done, the server
 
1249
        # should notice after another connection that the first has finished.
 
1250
        client_sock2 = self.connect_to_server(server)
 
1251
        self.say_hello(client_sock2)
 
1252
        server_handler2, server_side_thread2 = server._active_connections[-1]
 
1253
        # There is a race condition. We know that client_sock2 has been
 
1254
        # registered, but not that _poll_active_connections has been called. We
 
1255
        # know that it will be called before the server will accept a new
 
1256
        # connection, however. So connect one more time, and assert that we
 
1257
        # either have 1 or 2 active connections (never 3), and that the 'first'
 
1258
        # connection is not connection 1
 
1259
        client_sock3 = self.connect_to_server(server)
 
1260
        self.say_hello(client_sock3)
 
1261
        # Copy the list, so we don't have it mutating behind our back
 
1262
        conns = list(server._active_connections)
 
1263
        self.assertEqual(2, len(conns))
 
1264
        self.assertNotEqual((server_handler1, server_side_thread1), conns[0])
 
1265
        self.assertEqual((server_handler2, server_side_thread2), conns[0])
 
1266
        client_sock2.close()
 
1267
        client_sock3.close()
 
1268
        self.shutdown_server_cleanly(server, server_thread)
 
1269
 
 
1270
    def test_graceful_shutdown_waits_for_clients_to_stop(self):
        """A graceful stop waits for in-flight client requests to finish.

        A client issues a large 'get' and deliberately leaves the response
        body unread. After _stop_gracefully(), the server must refuse new
        connections (its _stopped event is set). It must not reach
        _fully_stopped until the client has read the body and closed its
        socket. Finally the INFO lines written to the log are checked.
        """
        server, server_thread = self.make_server()
        # We need something big enough that it won't fit in a single recv. So
        # the server thread gets blocked writing content to the client until we
        # finish reading on the client.
        server.backing_transport.put_bytes('bigfile',
            'a'*1024*1024)
        client_sock = self.connect_to_server(server)
        self.say_hello(client_sock)
        # Keep the per-connection thread so it can be joined once the client
        # has finished with it.
        _, server_side_thread = server._active_connections[0]
        # Start the RPC, but don't finish reading the response
        client_medium = medium.SmartClientAlreadyConnectedSocketMedium(
            'base', client_sock)
        client_client = client._SmartClient(client_medium)
        resp, response_handler = client_client.call_expecting_body('get',
            'bigfile')
        self.assertEqual(('ok',), resp)
        # Ask the server to stop gracefully, and wait for it.
        server._stop_gracefully()
        # NOTE(review): presumably this throwaway connect/hangup wakes the
        # server's accept loop so it notices the stop request - confirm
        # against connect_to_server_and_hangup.
        self.connect_to_server_and_hangup(server)
        server._stopped.wait()
        # It should not be accepting another connection.
        self.assertRaises(socket.error, self.connect_to_server, server)
        # It should also not be fully stopped
        # (the short timed wait gives the server a brief chance to finish
        # wrongly before we assert that it has not.)
        server._fully_stopped.wait(0.01)
        self.assertFalse(server._fully_stopped.isSet())
        # Drain the response and hang up. The connection thread was blocked
        # writing 'bigfile' to us (see the comment above), so only now can it
        # and then the main server thread be joined.
        response_handler.read_body_bytes()
        client_sock.close()
        server_side_thread.join()
        server_thread.join()
        self.assertTrue(server._fully_stopped.isSet())
        log = self.get_log()
        self.assertThat(log, DocTestMatches("""\
    INFO  Requested to stop gracefully
... Stopping SmartServerSocketStreamMedium(client=('127.0.0.1', ...
    INFO  Waiting for 1 client(s) to finish
""", flags=doctest.ELLIPSIS|doctest.REPORT_UDIFF))
 
1307
 
 
1308
    def test_stop_gracefully_tells_handlers_to_stop(self):
        """_stop_gracefully() marks active connection handlers as finished."""
        server, server_thread = self.make_server()
        client_sock = self.connect_to_server(server)
        self.say_hello(client_sock)
        # The handler for the single connection opened above.
        server_handler, server_side_thread = server._active_connections[0]
        self.assertFalse(server_handler.finished)
        server._stop_gracefully()
        # Checked immediately, with no wait: the flag is expected to be set
        # by the _stop_gracefully() call itself.
        self.assertTrue(server_handler.finished)
        # Tear down: close our client, then (presumably) wake the accept loop
        # with a throwaway connection so the server thread can exit.
        client_sock.close()
        self.connect_to_server_and_hangup(server)
        server_thread.join()
1319
 
994
1320
 
995
1321
class SmartTCPTests(tests.TestCase):
996
1322
    """Tests for connection/end to end behaviour using the TCP server.
1014
1340
            mem_server.start_server()
1015
1341
            self.addCleanup(mem_server.stop_server)
1016
1342
            self.permit_url(mem_server.get_url())
1017
 
            self.backing_transport = transport.get_transport(
 
1343
            self.backing_transport = _mod_transport.get_transport_from_url(
1018
1344
                mem_server.get_url())
1019
1345
        else:
1020
1346
            self.backing_transport = backing_transport
1021
1347
        if readonly:
1022
1348
            self.real_backing_transport = self.backing_transport
1023
 
            self.backing_transport = transport.get_transport(
 
1349
            self.backing_transport = _mod_transport.get_transport_from_url(
1024
1350
                "readonly+" + self.backing_transport.abspath('.'))
1025
 
        self.server = server.SmartTCPServer(self.backing_transport)
 
1351
        self.server = _mod_server.SmartTCPServer(self.backing_transport,
 
1352
                                                 client_timeout=4.0)
1026
1353
        self.server.start_server('127.0.0.1', 0)
1027
1354
        self.server.start_background_thread('-' + self.id())
1028
1355
        self.transport = remote.RemoteTCPTransport(self.server.get_url())
1162
1489
    def test_server_started_hook_memory(self):
1163
1490
        """The server_started hook fires when the server is started."""
1164
1491
        self.hook_calls = []
1165
 
        server.SmartTCPServer.hooks.install_named_hook('server_started',
 
1492
        _mod_server.SmartTCPServer.hooks.install_named_hook('server_started',
1166
1493
            self.capture_server_call, None)
1167
1494
        self.start_server()
1168
1495
        # at this point, the server will be starting a thread up.
1176
1503
    def test_server_started_hook_file(self):
1177
1504
        """The server_started hook fires when the server is started."""
1178
1505
        self.hook_calls = []
1179
 
        server.SmartTCPServer.hooks.install_named_hook('server_started',
 
1506
        _mod_server.SmartTCPServer.hooks.install_named_hook('server_started',
1180
1507
            self.capture_server_call, None)
1181
 
        self.start_server(backing_transport=transport.get_transport("."))
 
1508
        self.start_server(
 
1509
            backing_transport=_mod_transport.get_transport_from_path("."))
1182
1510
        # at this point, the server will be starting a thread up.
1183
1511
        # there is no indicator at the moment, so bodge it by doing a request.
1184
1512
        self.transport.has('.')
1192
1520
    def test_server_stopped_hook_simple_memory(self):
1193
1521
        """The server_stopped hook fires when the server is stopped."""
1194
1522
        self.hook_calls = []
1195
 
        server.SmartTCPServer.hooks.install_named_hook('server_stopped',
 
1523
        _mod_server.SmartTCPServer.hooks.install_named_hook('server_stopped',
1196
1524
            self.capture_server_call, None)
1197
1525
        self.start_server()
1198
1526
        result = [([self.backing_transport.base], self.transport.base)]
1209
1537
    def test_server_stopped_hook_simple_file(self):
1210
1538
        """The server_stopped hook fires when the server is stopped."""
1211
1539
        self.hook_calls = []
1212
 
        server.SmartTCPServer.hooks.install_named_hook('server_stopped',
 
1540
        _mod_server.SmartTCPServer.hooks.install_named_hook('server_stopped',
1213
1541
            self.capture_server_call, None)
1214
 
        self.start_server(backing_transport=transport.get_transport("."))
 
1542
        self.start_server(
 
1543
            backing_transport=_mod_transport.get_transport_from_path("."))
1215
1544
        result = [(
1216
1545
            [self.backing_transport.base, self.backing_transport.external_url()]
1217
1546
            , self.transport.base)]
1353
1682
class RemoteTransportRegistration(tests.TestCase):
1354
1683
 
1355
1684
    def test_registration(self):
1356
 
        t = transport.get_transport('bzr+ssh://example.com/path')
 
1685
        t = _mod_transport.get_transport_from_url('bzr+ssh://example.com/path')
1357
1686
        self.assertIsInstance(t, remote.RemoteSSHTransport)
1358
 
        self.assertEqual('example.com', t._host)
 
1687
        self.assertEqual('example.com', t._parsed_url.host)
1359
1688
 
1360
1689
    def test_bzr_https(self):
1361
1690
        # https://bugs.launchpad.net/bzr/+bug/128456
1362
 
        t = transport.get_transport('bzr+https://example.com/path')
 
1691
        t = _mod_transport.get_transport_from_url('bzr+https://example.com/path')
1363
1692
        self.assertIsInstance(t, remote.RemoteHTTPTransport)
1364
1693
        self.assertStartsWith(
1365
1694
            t._http_transport.base,
2542
2871
        from_server = StringIO()
2543
2872
        transport = memory.MemoryTransport('memory:///')
2544
2873
        server = medium.SmartServerPipeStreamMedium(
2545
 
            to_server, from_server, transport)
 
2874
            to_server, from_server, transport, timeout=4.0)
2546
2875
        proto = server._build_protocol()
2547
2876
        message_handler = proto.message_handler
2548
2877
        server._serve_one_request(proto)
3557
3886
        # still work correctly.
3558
3887
        base_transport = remote.RemoteHTTPTransport('bzr+http://host/%7Ea/b')
3559
3888
        new_transport = base_transport.clone('c')
3560
 
        self.assertEqual('bzr+http://host/~a/b/c/', new_transport.base)
 
3889
        self.assertEqual(base_transport.base + 'c/', new_transport.base)
3561
3890
        self.assertEqual(
3562
3891
            'c/',
3563
3892
            new_transport._client.remote_path_from_transport(new_transport))
3580
3909
        r = t._redirected_to('http://www.example.com/foo',
3581
3910
                             'http://www.example.com/bar')
3582
3911
        self.assertEquals(type(r), type(t))
3583
 
        self.assertEquals('joe', t._user)
3584
 
        self.assertEquals(t._user, r._user)
 
3912
        self.assertEquals('joe', t._parsed_url.user)
 
3913
        self.assertEquals(t._parsed_url.user, r._parsed_url.user)
3585
3914
 
3586
3915
    def test_redirected_to_same_host_different_protocol(self):
3587
3916
        t = remote.RemoteHTTPTransport('bzr+http://joe@www.example.com/foo')