/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to breezy/bzr/smart/medium.py

  • Committer: Jelmer Vernooij
  • Date: 2018-07-08 15:47:10 UTC
  • mto: This revision was merged to the branch mainline in revision 7036.
  • Revision ID: jelmer@jelmer.uk-20180708154710-zebexq602tcer8hv
Fix more merge tests.

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006-2010 Canonical Ltd
 
1
# Copyright (C) 2006-2011 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
21
21
 
22
22
Media carry the bytes of the requests somehow (e.g. via TCP, wrapped in HTTP, or
23
23
over SSH), and pass them to and from the protocol logic.  See the overview in
24
 
bzrlib/transport/smart/__init__.py.
 
24
breezy/transport/smart/__init__.py.
25
25
"""
26
26
 
 
27
from __future__ import absolute_import
 
28
 
 
29
import errno
 
30
import io
27
31
import os
28
32
import sys
29
 
import urllib
 
33
import time
30
34
 
31
 
from bzrlib.lazy_import import lazy_import
 
35
import breezy
 
36
from ...lazy_import import lazy_import
32
37
lazy_import(globals(), """
33
 
import atexit
 
38
import select
34
39
import socket
35
40
import thread
36
41
import weakref
37
42
 
38
 
from bzrlib import (
 
43
from breezy import (
39
44
    debug,
40
 
    errors,
41
 
    symbol_versioning,
42
45
    trace,
 
46
    transport,
43
47
    ui,
44
48
    urlutils,
45
49
    )
46
 
from bzrlib.smart import client, protocol, request, vfs
47
 
from bzrlib.transport import ssh
 
50
from breezy.i18n import gettext
 
51
from breezy.bzr.smart import client, protocol, request, signals, vfs
 
52
from breezy.transport import ssh
48
53
""")
49
 
from bzrlib import osutils
 
54
from ... import (
 
55
    errors,
 
56
    osutils,
 
57
    )
50
58
 
51
59
# Throughout this module buffer size parameters are either limited to be at
52
60
# most _MAX_READ_SIZE, or are ignored and _MAX_READ_SIZE is used instead.
54
62
# from non-sockets as well.
55
63
_MAX_READ_SIZE = osutils.MAX_SOCKET_CHUNK
56
64
 
 
65
 
 
66
class HpssVfsRequestNotAllowed(errors.BzrError):
 
67
 
 
68
    _fmt = ("VFS requests over the smart server are not allowed. Encountered: "
 
69
            "%(method)s, %(arguments)s.")
 
70
 
 
71
    def __init__(self, method, arguments):
 
72
        self.method = method
 
73
        self.arguments = arguments
 
74
 
 
75
 
57
76
def _get_protocol_factory_for_bytes(bytes):
58
77
    """Determine the right protocol factory for 'bytes'.
59
78
 
95
114
    :returns: a tuple of two strs: (line, excess)
96
115
    """
97
116
    newline_pos = -1
98
 
    bytes = ''
 
117
    bytes = b''
99
118
    while newline_pos == -1:
100
119
        new_bytes = read_bytes_func(1)
101
120
        bytes += new_bytes
102
 
        if new_bytes == '':
 
121
        if new_bytes == b'':
103
122
            # Ran out of bytes before receiving a complete line.
104
 
            return bytes, ''
105
 
        newline_pos = bytes.find('\n')
 
123
            return bytes, b''
 
124
        newline_pos = bytes.find(b'\n')
106
125
    line = bytes[:newline_pos+1]
107
126
    excess = bytes[newline_pos+1:]
108
127
    return line, excess
114
133
    def __init__(self):
115
134
        self._push_back_buffer = None
116
135
 
117
 
    def _push_back(self, bytes):
 
136
    def _push_back(self, data):
118
137
        """Return unused bytes to the medium, because they belong to the next
119
138
        request(s).
120
139
 
121
140
        This sets the _push_back_buffer to the given bytes.
122
141
        """
 
142
        if not isinstance(data, bytes):
 
143
            raise TypeError(data)
123
144
        if self._push_back_buffer is not None:
124
145
            raise AssertionError(
125
146
                "_push_back called when self._push_back_buffer is %r"
126
147
                % (self._push_back_buffer,))
127
 
        if bytes == '':
 
148
        if data == b'':
128
149
            return
129
 
        self._push_back_buffer = bytes
 
150
        self._push_back_buffer = data
130
151
 
131
152
    def _get_push_back_buffer(self):
132
 
        if self._push_back_buffer == '':
 
153
        if self._push_back_buffer == b'':
133
154
            raise AssertionError(
134
155
                '%s._push_back_buffer should never be the empty string, '
135
156
                'which can be confused with EOF' % (self,))
176
197
        ui.ui_factory.report_transport_activity(self, bytes, direction)
177
198
 
178
199
 
 
200
_bad_file_descriptor = (errno.EBADF,)
 
201
if sys.platform == 'win32':
 
202
    # Given on Windows if you pass a closed socket to select.select. Probably
 
203
    # also given if you pass a file handle to select.
 
204
    WSAENOTSOCK = 10038
 
205
    _bad_file_descriptor += (WSAENOTSOCK,)
 
206
 
 
207
 
179
208
class SmartServerStreamMedium(SmartMedium):
180
209
    """Handles smart commands coming over a stream.
181
210
 
194
223
        the stream.  See also the _push_back method.
195
224
    """
196
225
 
197
 
    def __init__(self, backing_transport, root_client_path='/'):
 
226
    _timer = time.time
 
227
 
 
228
    def __init__(self, backing_transport, root_client_path='/', timeout=None):
198
229
        """Construct new server.
199
230
 
200
231
        :param backing_transport: Transport for the directory served.
203
234
        self.backing_transport = backing_transport
204
235
        self.root_client_path = root_client_path
205
236
        self.finished = False
 
237
        if timeout is None:
 
238
            raise AssertionError('You must supply a timeout.')
 
239
        self._client_timeout = timeout
 
240
        self._client_poll_timeout = min(timeout / 10.0, 1.0)
206
241
        SmartMedium.__init__(self)
207
242
 
208
243
    def serve(self):
214
249
            while not self.finished:
215
250
                server_protocol = self._build_protocol()
216
251
                self._serve_one_request(server_protocol)
217
 
        except Exception, e:
 
252
        except errors.ConnectionTimeout as e:
 
253
            trace.note('%s' % (e,))
 
254
            trace.log_exception_quietly()
 
255
            self._disconnect_client()
 
256
            # We reported it, no reason to make a big fuss.
 
257
            return
 
258
        except Exception as e:
218
259
            stderr.write("%s terminating on exception %s\n" % (self, e))
219
260
            raise
 
261
        self._disconnect_client()
 
262
 
 
263
    def _stop_gracefully(self):
 
264
        """When we finish this message, stop looking for more."""
 
265
        trace.mutter('Stopping %s' % (self,))
 
266
        self.finished = True
 
267
 
 
268
    def _disconnect_client(self):
 
269
        """Close the current connection. We stopped due to a timeout/etc."""
 
270
        # The default implementation is a no-op, because that is all we used to
 
271
        # do when disconnecting from a client. I suppose we never had the
 
272
        # *server* initiate a disconnect, before
 
273
 
 
274
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
 
275
        """Wait for more bytes to be read, but timeout if none available.
 
276
 
 
277
        This allows us to detect idle connections, and stop trying to read from
 
278
        them, without setting the socket itself to non-blocking. This also
 
279
        allows us to specify when we watch for idle timeouts.
 
280
 
 
281
        :return: Did we timeout? (True if we timed out, False if there is data
 
282
            to be read)
 
283
        """
 
284
        raise NotImplementedError(self._wait_for_bytes_with_timeout)
220
285
 
221
286
    def _build_protocol(self):
222
287
        """Identifies the version of the incoming request, and returns an
227
292
 
228
293
        :returns: a SmartServerRequestProtocol.
229
294
        """
 
295
        self._wait_for_bytes_with_timeout(self._client_timeout)
 
296
        if self.finished:
 
297
            # We're stopping, so don't try to do any more work
 
298
            return None
230
299
        bytes = self._get_line()
231
300
        protocol_factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
232
301
        protocol = protocol_factory(
234
303
        protocol.accept_bytes(unused_bytes)
235
304
        return protocol
236
305
 
 
306
    def _wait_on_descriptor(self, fd, timeout_seconds):
 
307
        """select() on a file descriptor, waiting for nonblocking read()
 
308
 
 
309
        This will raise a ConnectionTimeout exception if we do not get a
 
310
        readable handle before timeout_seconds.
 
311
        :return: None
 
312
        """
 
313
        t_end = self._timer() + timeout_seconds
 
314
        poll_timeout = min(timeout_seconds, self._client_poll_timeout)
 
315
        rs = xs = None
 
316
        while not rs and not xs and self._timer() < t_end:
 
317
            if self.finished:
 
318
                return
 
319
            try:
 
320
                rs, _, xs = select.select([fd], [], [fd], poll_timeout)
 
321
            except (select.error, socket.error) as e:
 
322
                err = getattr(e, 'errno', None)
 
323
                if err is None and getattr(e, 'args', None) is not None:
 
324
                    # select.error doesn't have 'errno', it just has args[0]
 
325
                    err = e.args[0]
 
326
                if err in _bad_file_descriptor:
 
327
                    return # Not a socket indicates read() will fail
 
328
                elif err == errno.EINTR:
 
329
                    # Interrupted, keep looping.
 
330
                    continue
 
331
                raise
 
332
            except ValueError:
 
333
                return  # Socket may already be closed
 
334
        if rs or xs:
 
335
            return
 
336
        raise errors.ConnectionTimeout('disconnecting client after %.1f seconds'
 
337
                                       % (timeout_seconds,))
 
338
 
237
339
    def _serve_one_request(self, protocol):
238
340
        """Read one request from input, process, send back a response.
239
341
 
240
342
        :param protocol: a SmartServerRequestProtocol.
241
343
        """
 
344
        if protocol is None:
 
345
            return
242
346
        try:
243
347
            self._serve_one_request_unguarded(protocol)
244
348
        except KeyboardInterrupt:
245
349
            raise
246
 
        except Exception, e:
 
350
        except Exception as e:
247
351
            self.terminate_due_to_error()
248
352
 
249
353
    def terminate_due_to_error(self):
260
364
 
261
365
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
262
366
 
263
 
    def __init__(self, sock, backing_transport, root_client_path='/'):
 
367
    def __init__(self, sock, backing_transport, root_client_path='/',
 
368
                 timeout=None):
264
369
        """Constructor.
265
370
 
266
371
        :param sock: the socket the server will read from.  It will be put
267
372
            into blocking mode.
268
373
        """
269
374
        SmartServerStreamMedium.__init__(
270
 
            self, backing_transport, root_client_path=root_client_path)
 
375
            self, backing_transport, root_client_path=root_client_path,
 
376
            timeout=timeout)
271
377
        sock.setblocking(True)
272
378
        self.socket = sock
 
379
        # Get the getpeername now, as we might be closed later when we care.
 
380
        try:
 
381
            self._client_info = sock.getpeername()
 
382
        except socket.error:
 
383
            self._client_info = '<unknown>'
 
384
 
 
385
    def __str__(self):
 
386
        return '%s(client=%s)' % (self.__class__.__name__, self._client_info)
 
387
 
 
388
    def __repr__(self):
 
389
        return '%s.%s(client=%s)' % (self.__module__, self.__class__.__name__,
 
390
            self._client_info)
273
391
 
274
392
    def _serve_one_request_unguarded(self, protocol):
275
393
        while protocol.next_read_size():
277
395
            # than MAX_SOCKET_CHUNK ready, the socket will just return a
278
396
            # short read immediately rather than block.
279
397
            bytes = self.read_bytes(osutils.MAX_SOCKET_CHUNK)
280
 
            if bytes == '':
 
398
            if bytes == b'':
281
399
                self.finished = True
282
400
                return
283
401
            protocol.accept_bytes(bytes)
284
402
 
285
403
        self._push_back(protocol.unused_data)
286
404
 
 
405
    def _disconnect_client(self):
 
406
        """Close the current connection. We stopped due to a timeout/etc."""
 
407
        self.socket.close()
 
408
 
 
409
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
 
410
        """Wait for more bytes to be read, but timeout if none available.
 
411
 
 
412
        This allows us to detect idle connections, and stop trying to read from
 
413
        them, without setting the socket itself to non-blocking. This also
 
414
        allows us to specify when we watch for idle timeouts.
 
415
 
 
416
        :return: None, this will raise ConnectionTimeout if we time out before
 
417
            data is available.
 
418
        """
 
419
        return self._wait_on_descriptor(self.socket, timeout_seconds)
 
420
 
287
421
    def _read_bytes(self, desired_count):
288
422
        return osutils.read_bytes_from_socket(
289
423
            self.socket, self._report_activity)
306
440
 
307
441
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
308
442
 
309
 
    def __init__(self, in_file, out_file, backing_transport):
 
443
    def __init__(self, in_file, out_file, backing_transport, timeout=None):
310
444
        """Construct new server.
311
445
 
312
446
        :param in_file: Python file from which requests can be read.
313
447
        :param out_file: Python file to write responses.
314
448
        :param backing_transport: Transport for the directory served.
315
449
        """
316
 
        SmartServerStreamMedium.__init__(self, backing_transport)
 
450
        SmartServerStreamMedium.__init__(self, backing_transport,
 
451
            timeout=timeout)
317
452
        if sys.platform == 'win32':
318
453
            # force binary mode for files
319
454
            import msvcrt
324
459
        self._in = in_file
325
460
        self._out = out_file
326
461
 
 
462
    def serve(self):
 
463
        """See SmartServerStreamMedium.serve"""
 
464
        # This is the regular serve, except it adds signal trapping for soft
 
465
        # shutdown.
 
466
        stop_gracefully = self._stop_gracefully
 
467
        signals.register_on_hangup(id(self), stop_gracefully)
 
468
        try:
 
469
            return super(SmartServerPipeStreamMedium, self).serve()
 
470
        finally:
 
471
            signals.unregister_on_hangup(id(self))
 
472
 
327
473
    def _serve_one_request_unguarded(self, protocol):
328
474
        while True:
329
475
            # We need to be careful not to read past the end of the current
335
481
                self._out.flush()
336
482
                return
337
483
            bytes = self.read_bytes(bytes_to_read)
338
 
            if bytes == '':
 
484
            if bytes == b'':
339
485
                # Connection has been closed.
340
486
                self.finished = True
341
487
                self._out.flush()
342
488
                return
343
489
            protocol.accept_bytes(bytes)
344
490
 
 
491
    def _disconnect_client(self):
 
492
        self._in.close()
 
493
        self._out.flush()
 
494
        self._out.close()
 
495
 
 
496
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
 
497
        """Wait for more bytes to be read, but timeout if none available.
 
498
 
 
499
        This allows us to detect idle connections, and stop trying to read from
 
500
        them, without setting the socket itself to non-blocking. This also
 
501
        allows us to specify when we watch for idle timeouts.
 
502
 
 
503
        :return: None, this will raise ConnectionTimeout if we time out before
 
504
            data is available.
 
505
        """
 
506
        if (getattr(self._in, 'fileno', None) is None
 
507
            or sys.platform == 'win32'):
 
508
            # You can't select() file descriptors on Windows.
 
509
            return
 
510
        try:
 
511
            return self._wait_on_descriptor(self._in, timeout_seconds)
 
512
        except io.UnsupportedOperation:
 
513
            return
 
514
 
345
515
    def _read_bytes(self, desired_count):
346
516
        return self._in.read(desired_count)
347
517
 
475
645
 
476
646
    def read_line(self):
477
647
        line = self._read_line()
478
 
        if not line.endswith('\n'):
 
648
        if not line.endswith(b'\n'):
479
649
            # end of file encountered reading from server
480
650
            raise errors.ConnectionReset(
481
651
                "Unexpected end of message. Please check connectivity "
491
661
        return self._medium._get_line()
492
662
 
493
663
 
 
664
class _VfsRefuser(object):
 
665
    """An object that refuses all VFS requests.
 
666
 
 
667
    """
 
668
 
 
669
    def __init__(self):
 
670
        client._SmartClient.hooks.install_named_hook(
 
671
            'call', self.check_vfs, 'vfs refuser')
 
672
 
 
673
    def check_vfs(self, params):
 
674
        try:
 
675
            request_method = request.request_handlers.get(params.method)
 
676
        except KeyError:
 
677
            # A method we don't know about doesn't count as a VFS method.
 
678
            return
 
679
        if issubclass(request_method, vfs.VfsRequest):
 
680
            raise HpssVfsRequestNotAllowed(params.method, params.args)
 
681
 
 
682
 
494
683
class _DebugCounter(object):
495
684
    """An object that counts the HPSS calls made to each client medium.
496
685
 
497
 
    When a medium is garbage-collected, or failing that when atexit functions
498
 
    are run, the total number of calls made on that medium are reported via
499
 
    trace.note.
 
686
    When a medium is garbage-collected, or failing that when
 
687
    breezy.global_state exits, the total number of calls made on that medium
 
688
    are reported via trace.note.
500
689
    """
501
690
 
502
691
    def __init__(self):
503
692
        self.counts = weakref.WeakKeyDictionary()
504
693
        client._SmartClient.hooks.install_named_hook(
505
694
            'call', self.increment_call_count, 'hpss call counter')
506
 
        atexit.register(self.flush_all)
 
695
        breezy.get_global_state().cleanups.add_cleanup(self.flush_all)
507
696
 
508
697
    def track(self, medium):
509
698
        """Start tracking calls made to a medium.
543
732
        value['count'] = 0
544
733
        value['vfs_count'] = 0
545
734
        if count != 0:
546
 
            trace.note('HPSS calls: %d (%d vfs) %s',
547
 
                       count, vfs_count, medium_repr)
 
735
            trace.note(gettext('HPSS calls: {0} ({1} vfs) {2}').format(
 
736
                       count, vfs_count, medium_repr))
548
737
 
549
738
    def flush_all(self):
550
739
        for ref in list(self.counts.keys()):
551
740
            self.done(ref)
552
741
 
553
742
_debug_counter = None
 
743
_vfs_refuser = None
554
744
 
555
745
 
556
746
class SmartClientMedium(SmartMedium):
573
763
            if _debug_counter is None:
574
764
                _debug_counter = _DebugCounter()
575
765
            _debug_counter.track(self)
 
766
        if 'hpss_client_no_vfs' in debug.debug_flags:
 
767
            global _vfs_refuser
 
768
            if _vfs_refuser is None:
 
769
                _vfs_refuser = _VfsRefuser()
576
770
 
577
771
    def _is_remote_before(self, version_tuple):
578
772
        """Is it possible the remote side supports RPCs for a given version?
631
825
                client_protocol = protocol.SmartClientRequestProtocolOne(medium_request)
632
826
                client_protocol.query_version()
633
827
                self._done_hello = True
634
 
            except errors.SmartProtocolError, e:
 
828
            except errors.SmartProtocolError as e:
635
829
                # Cache the error, just like we would cache a successful
636
830
                # result.
637
831
                self._protocol_version_error = e
669
863
        """
670
864
        medium_base = urlutils.join(self.base, '/')
671
865
        rel_url = urlutils.relative_url(medium_base, transport.base)
672
 
        return urllib.unquote(rel_url)
 
866
        return urlutils.unquote(rel_url)
673
867
 
674
868
 
675
869
class SmartClientStreamMedium(SmartClientMedium):
710
904
        """
711
905
        return SmartClientStreamMediumRequest(self)
712
906
 
 
907
    def reset(self):
 
908
        """We have been disconnected, reset current state.
 
909
 
 
910
        This resets things like _current_request and connected state.
 
911
        """
 
912
        self.disconnect()
 
913
        self._current_request = None
 
914
 
713
915
 
714
916
class SmartSimplePipesClientMedium(SmartClientStreamMedium):
715
917
    """A client medium using simple pipes.
716
918
 
717
919
    This client does not manage the pipes: it assumes they will always be open.
718
 
 
719
 
    Note that if readable_pipe.read might raise IOError or OSError with errno
720
 
    of EINTR, it must be safe to retry the read.  Plain CPython fileobjects
721
 
    (such as used for sys.stdin) are safe.
722
920
    """
723
921
 
724
922
    def __init__(self, readable_pipe, writeable_pipe, base):
726
924
        self._readable_pipe = readable_pipe
727
925
        self._writeable_pipe = writeable_pipe
728
926
 
729
 
    def _accept_bytes(self, bytes):
 
927
    def _accept_bytes(self, data):
730
928
        """See SmartClientStreamMedium.accept_bytes."""
731
 
        self._writeable_pipe.write(bytes)
732
 
        self._report_activity(len(bytes), 'write')
 
929
        try:
 
930
            self._writeable_pipe.write(data)
 
931
        except IOError as e:
 
932
            if e.errno in (errno.EINVAL, errno.EPIPE):
 
933
                raise errors.ConnectionReset(
 
934
                    "Error trying to write to subprocess", e)
 
935
            raise
 
936
        self._report_activity(len(data), 'write')
733
937
 
734
938
    def _flush(self):
735
939
        """See SmartClientStreamMedium._flush()."""
 
940
        # Note: If flush were to fail, we'd like to raise ConnectionReset, etc.
 
941
        #       However, testing shows that even when the child process is
 
942
        #       gone, this doesn't error.
736
943
        self._writeable_pipe.flush()
737
944
 
738
945
    def _read_bytes(self, count):
739
946
        """See SmartClientStreamMedium._read_bytes."""
740
 
        bytes = osutils.until_no_eintr(self._readable_pipe.read, count)
741
 
        self._report_activity(len(bytes), 'read')
742
 
        return bytes
 
947
        bytes_to_read = min(count, _MAX_READ_SIZE)
 
948
        data = self._readable_pipe.read(bytes_to_read)
 
949
        self._report_activity(len(data), 'read')
 
950
        return data
 
951
 
 
952
 
 
953
class SSHParams(object):
 
954
    """A set of parameters for starting a remote bzr via SSH."""
 
955
 
 
956
    def __init__(self, host, port=None, username=None, password=None,
 
957
            bzr_remote_path='bzr'):
 
958
        self.host = host
 
959
        self.port = port
 
960
        self.username = username
 
961
        self.password = password
 
962
        self.bzr_remote_path = bzr_remote_path
743
963
 
744
964
 
745
965
class SmartSSHClientMedium(SmartClientStreamMedium):
746
 
    """A client medium using SSH."""
747
 
 
748
 
    def __init__(self, host, port=None, username=None, password=None,
749
 
            base=None, vendor=None, bzr_remote_path=None):
 
966
    """A client medium using SSH.
 
967
 
 
968
    It delegates IO to a SmartSimplePipesClientMedium or
 
969
    SmartClientAlreadyConnectedSocketMedium (depending on platform).
 
970
    """
 
971
 
 
972
    def __init__(self, base, ssh_params, vendor=None):
750
973
        """Creates a client that will connect on the first use.
751
974
 
 
975
        :param ssh_params: A SSHParams instance.
752
976
        :param vendor: An optional override for the ssh vendor to use. See
753
 
            bzrlib.transport.ssh for details on ssh vendors.
 
977
            breezy.transport.ssh for details on ssh vendors.
754
978
        """
755
 
        self._connected = False
756
 
        self._host = host
757
 
        self._password = password
758
 
        self._port = port
759
 
        self._username = username
 
979
        self._real_medium = None
 
980
        self._ssh_params = ssh_params
760
981
        # for the benefit of progress making a short description of this
761
982
        # transport
762
983
        self._scheme = 'bzr+ssh'
764
985
        # _DebugCounter so we have to store all the values used in our repr
765
986
        # method before calling the super init.
766
987
        SmartClientStreamMedium.__init__(self, base)
767
 
        self._read_from = None
 
988
        self._vendor = vendor
768
989
        self._ssh_connection = None
769
 
        self._vendor = vendor
770
 
        self._write_to = None
771
 
        self._bzr_remote_path = bzr_remote_path
772
990
 
773
991
    def __repr__(self):
774
 
        if self._port is None:
 
992
        if self._ssh_params.port is None:
775
993
            maybe_port = ''
776
994
        else:
777
 
            maybe_port = ':%s' % self._port
778
 
        return "%s(%s://%s@%s%s/)" % (
 
995
            maybe_port = ':%s' % self._ssh_params.port
 
996
        if self._ssh_params.username is None:
 
997
            maybe_user = ''
 
998
        else:
 
999
            maybe_user = '%s@' % self._ssh_params.username
 
1000
        return "%s(%s://%s%s%s/)" % (
779
1001
            self.__class__.__name__,
780
1002
            self._scheme,
781
 
            self._username,
782
 
            self._host,
 
1003
            maybe_user,
 
1004
            self._ssh_params.host,
783
1005
            maybe_port)
784
1006
 
785
1007
    def _accept_bytes(self, bytes):
786
1008
        """See SmartClientStreamMedium.accept_bytes."""
787
1009
        self._ensure_connection()
788
 
        self._write_to.write(bytes)
789
 
        self._report_activity(len(bytes), 'write')
 
1010
        self._real_medium.accept_bytes(bytes)
790
1011
 
791
1012
    def disconnect(self):
792
1013
        """See SmartClientMedium.disconnect()."""
793
 
        if not self._connected:
794
 
            return
795
 
        self._read_from.close()
796
 
        self._write_to.close()
797
 
        self._ssh_connection.close()
798
 
        self._connected = False
 
1014
        if self._real_medium is not None:
 
1015
            self._real_medium.disconnect()
 
1016
            self._real_medium = None
 
1017
        if self._ssh_connection is not None:
 
1018
            self._ssh_connection.close()
 
1019
            self._ssh_connection = None
799
1020
 
800
1021
    def _ensure_connection(self):
801
1022
        """Connect this medium if not already connected."""
802
 
        if self._connected:
 
1023
        if self._real_medium is not None:
803
1024
            return
804
1025
        if self._vendor is None:
805
1026
            vendor = ssh._get_ssh_vendor()
806
1027
        else:
807
1028
            vendor = self._vendor
808
 
        self._ssh_connection = vendor.connect_ssh(self._username,
809
 
                self._password, self._host, self._port,
810
 
                command=[self._bzr_remote_path, 'serve', '--inet',
 
1029
        self._ssh_connection = vendor.connect_ssh(self._ssh_params.username,
 
1030
                self._ssh_params.password, self._ssh_params.host,
 
1031
                self._ssh_params.port,
 
1032
                command=[self._ssh_params.bzr_remote_path, 'serve', '--inet',
811
1033
                         '--directory=/', '--allow-writes'])
812
 
        self._read_from, self._write_to = \
813
 
            self._ssh_connection.get_filelike_channels()
814
 
        self._connected = True
 
1034
        io_kind, io_object = self._ssh_connection.get_sock_or_pipes()
 
1035
        if io_kind == 'socket':
 
1036
            self._real_medium = SmartClientAlreadyConnectedSocketMedium(
 
1037
                self.base, io_object)
 
1038
        elif io_kind == 'pipes':
 
1039
            read_from, write_to = io_object
 
1040
            self._real_medium = SmartSimplePipesClientMedium(
 
1041
                read_from, write_to, self.base)
 
1042
        else:
 
1043
            raise AssertionError(
 
1044
                "Unexpected io_kind %r from %r"
 
1045
                % (io_kind, self._ssh_connection))
 
1046
        for hook in transport.Transport.hooks["post_connect"]:
 
1047
            hook(self)
815
1048
 
816
1049
    def _flush(self):
817
1050
        """See SmartClientStreamMedium._flush()."""
818
 
        self._write_to.flush()
 
1051
        self._real_medium._flush()
819
1052
 
820
1053
    def _read_bytes(self, count):
821
1054
        """See SmartClientStreamMedium.read_bytes."""
822
 
        if not self._connected:
 
1055
        if self._real_medium is None:
823
1056
            raise errors.MediumNotConnected(self)
824
 
        bytes_to_read = min(count, _MAX_READ_SIZE)
825
 
        bytes = self._read_from.read(bytes_to_read)
826
 
        self._report_activity(len(bytes), 'read')
827
 
        return bytes
 
1057
        return self._real_medium.read_bytes(count)
828
1058
 
829
1059
 
830
1060
# Port 4155 is the default port for bzr://, registered with IANA.
832
1062
BZR_DEFAULT_PORT = 4155
833
1063
 
834
1064
 
835
 
class SmartTCPClientMedium(SmartClientStreamMedium):
836
 
    """A client medium using TCP."""
837
 
 
838
 
    def __init__(self, host, port, base):
839
 
        """Creates a client that will connect on the first use."""
 
1065
class SmartClientSocketMedium(SmartClientStreamMedium):
 
1066
    """A client medium using a socket.
 
1067
 
 
1068
    This class isn't usable directly.  Use one of its subclasses instead.
 
1069
    """
 
1070
 
 
1071
    def __init__(self, base):
840
1072
        SmartClientStreamMedium.__init__(self, base)
 
1073
        self._socket = None
841
1074
        self._connected = False
842
 
        self._host = host
843
 
        self._port = port
844
 
        self._socket = None
845
1075
 
846
1076
    def _accept_bytes(self, bytes):
847
1077
        """See SmartClientMedium.accept_bytes."""
848
1078
        self._ensure_connection()
849
1079
        osutils.send_all(self._socket, bytes, self._report_activity)
850
1080
 
 
1081
    def _ensure_connection(self):
 
1082
        """Connect this medium if not already connected."""
 
1083
        raise NotImplementedError(self._ensure_connection)
 
1084
 
 
1085
    def _flush(self):
 
1086
        """See SmartClientStreamMedium._flush().
 
1087
 
 
1088
        For sockets we do no flushing. For TCP sockets we may want to turn off
 
1089
        TCP_NODELAY and add a means to do a flush, but that can be done in the
 
1090
        future.
 
1091
        """
 
1092
 
 
1093
    def _read_bytes(self, count):
 
1094
        """See SmartClientMedium.read_bytes."""
 
1095
        if not self._connected:
 
1096
            raise errors.MediumNotConnected(self)
 
1097
        return osutils.read_bytes_from_socket(
 
1098
            self._socket, self._report_activity)
 
1099
 
851
1100
    def disconnect(self):
852
1101
        """See SmartClientMedium.disconnect()."""
853
1102
        if not self._connected:
856
1105
        self._socket = None
857
1106
        self._connected = False
858
1107
 
 
1108
 
 
1109
class SmartTCPClientMedium(SmartClientSocketMedium):
 
1110
    """A client medium that creates a TCP connection."""
 
1111
 
 
1112
    def __init__(self, host, port, base):
 
1113
        """Creates a client that will connect on the first use."""
 
1114
        SmartClientSocketMedium.__init__(self, base)
 
1115
        self._host = host
 
1116
        self._port = port
 
1117
 
859
1118
    def _ensure_connection(self):
860
1119
        """Connect this medium if not already connected."""
861
1120
        if self._connected:
867
1126
        try:
868
1127
            sockaddrs = socket.getaddrinfo(self._host, port, socket.AF_UNSPEC,
869
1128
                socket.SOCK_STREAM, 0, 0)
870
 
        except socket.gaierror, (err_num, err_msg):
 
1129
        except socket.gaierror as xxx_todo_changeme:
 
1130
            (err_num, err_msg) = xxx_todo_changeme.args
871
1131
            raise errors.ConnectionError("failed to lookup %s:%d: %s" %
872
1132
                    (self._host, port, err_msg))
873
1133
        # Initialize err in case there are no addresses returned:
874
 
        err = socket.error("no address found for %s" % self._host)
 
1134
        last_err = socket.error("no address found for %s" % self._host)
875
1135
        for (family, socktype, proto, canonname, sockaddr) in sockaddrs:
876
1136
            try:
877
1137
                self._socket = socket.socket(family, socktype, proto)
878
1138
                self._socket.setsockopt(socket.IPPROTO_TCP,
879
1139
                                        socket.TCP_NODELAY, 1)
880
1140
                self._socket.connect(sockaddr)
881
 
            except socket.error, err:
 
1141
            except socket.error as err:
882
1142
                if self._socket is not None:
883
1143
                    self._socket.close()
884
1144
                self._socket = None
 
1145
                last_err = err
885
1146
                continue
886
1147
            break
887
1148
        if self._socket is None:
888
1149
            # socket errors either have a (string) or (errno, string) as their
889
1150
            # args.
890
 
            if type(err.args) is str:
891
 
                err_msg = err.args
 
1151
            if isinstance(last_err.args, str):
 
1152
                err_msg = last_err.args
892
1153
            else:
893
 
                err_msg = err.args[1]
 
1154
                err_msg = last_err.args[1]
894
1155
            raise errors.ConnectionError("failed to connect to %s:%d: %s" %
895
1156
                    (self._host, port, err_msg))
896
1157
        self._connected = True
897
 
 
898
 
    def _flush(self):
899
 
        """See SmartClientStreamMedium._flush().
900
 
 
901
 
        For TCP we do no flushing. We may want to turn off TCP_NODELAY and
902
 
        add a means to do a flush, but that can be done in the future.
903
 
        """
904
 
 
905
 
    def _read_bytes(self, count):
906
 
        """See SmartClientMedium.read_bytes."""
907
 
        if not self._connected:
908
 
            raise errors.MediumNotConnected(self)
909
 
        return osutils.read_bytes_from_socket(
910
 
            self._socket, self._report_activity)
 
1158
        for hook in transport.Transport.hooks["post_connect"]:
 
1159
            hook(self)
 
1160
 
 
1161
 
 
1162
class SmartClientAlreadyConnectedSocketMedium(SmartClientSocketMedium):
 
1163
    """A client medium for an already connected socket.
 
1164
    
 
1165
    Note that this class will assume it "owns" the socket, so it will close it
 
1166
    when its disconnect method is called.
 
1167
    """
 
1168
 
 
1169
    def __init__(self, base, sock):
 
1170
        SmartClientSocketMedium.__init__(self, base)
 
1171
        self._socket = sock
 
1172
        self._connected = True
 
1173
 
 
1174
    def _ensure_connection(self):
 
1175
        # Already connected, by definition!  So nothing to do.
 
1176
        pass
911
1177
 
912
1178
 
913
1179
class SmartClientStreamMediumRequest(SmartClientMediumRequest):
948
1214
        This invokes self._medium._flush to ensure all bytes are transmitted.
949
1215
        """
950
1216
        self._medium._flush()
951
 
 
952