/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to breezy/bzr/smart/medium.py

  • Committer: Martin
  • Date: 2018-11-16 19:10:17 UTC
  • mto: This revision was merged to the branch mainline in revision 7177.
  • Revision ID: gzlist@googlemail.com-20181116191017-kyedz1qck0ovon3h
Remove lazy_regexp reset in bt.test_source

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006-2010 Canonical Ltd
 
1
# Copyright (C) 2006-2011 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
21
21
 
22
22
Media carry the bytes of the requests somehow (e.g. via TCP, wrapped in HTTP, or
23
23
over SSH), and pass them to and from the protocol logic.  See the overview in
24
 
bzrlib/transport/smart/__init__.py.
 
24
breezy/transport/smart/__init__.py.
25
25
"""
26
26
 
 
27
from __future__ import absolute_import
 
28
 
 
29
import errno
 
30
import io
27
31
import os
28
32
import sys
29
 
import urllib
 
33
import time
30
34
 
31
 
from bzrlib.lazy_import import lazy_import
 
35
import breezy
 
36
from ...lazy_import import lazy_import
32
37
lazy_import(globals(), """
33
 
import atexit
 
38
import select
34
39
import socket
35
40
import thread
36
41
import weakref
37
42
 
38
 
from bzrlib import (
 
43
from breezy import (
39
44
    debug,
40
 
    errors,
41
 
    symbol_versioning,
42
45
    trace,
 
46
    transport,
43
47
    ui,
44
48
    urlutils,
45
49
    )
46
 
from bzrlib.smart import client, protocol, request, vfs
47
 
from bzrlib.transport import ssh
 
50
from breezy.i18n import gettext
 
51
from breezy.bzr.smart import client, protocol, request, signals, vfs
 
52
from breezy.transport import ssh
48
53
""")
49
 
from bzrlib import osutils
 
54
from ... import (
 
55
    errors,
 
56
    osutils,
 
57
    )
50
58
 
51
59
# Throughout this module buffer size parameters are either limited to be at
52
60
# most _MAX_READ_SIZE, or are ignored and _MAX_READ_SIZE is used instead.
54
62
# from non-sockets as well.
55
63
_MAX_READ_SIZE = osutils.MAX_SOCKET_CHUNK
56
64
 
 
65
 
 
66
class HpssVfsRequestNotAllowed(errors.BzrError):
 
67
 
 
68
    _fmt = ("VFS requests over the smart server are not allowed. Encountered: "
 
69
            "%(method)s, %(arguments)s.")
 
70
 
 
71
    def __init__(self, method, arguments):
 
72
        self.method = method
 
73
        self.arguments = arguments
 
74
 
 
75
 
57
76
def _get_protocol_factory_for_bytes(bytes):
58
77
    """Determine the right protocol factory for 'bytes'.
59
78
 
95
114
    :returns: a tuple of two strs: (line, excess)
96
115
    """
97
116
    newline_pos = -1
98
 
    bytes = ''
 
117
    bytes = b''
99
118
    while newline_pos == -1:
100
119
        new_bytes = read_bytes_func(1)
101
120
        bytes += new_bytes
102
 
        if new_bytes == '':
 
121
        if new_bytes == b'':
103
122
            # Ran out of bytes before receiving a complete line.
104
 
            return bytes, ''
105
 
        newline_pos = bytes.find('\n')
106
 
    line = bytes[:newline_pos+1]
107
 
    excess = bytes[newline_pos+1:]
 
123
            return bytes, b''
 
124
        newline_pos = bytes.find(b'\n')
 
125
    line = bytes[:newline_pos + 1]
 
126
    excess = bytes[newline_pos + 1:]
108
127
    return line, excess
109
128
 
110
129
 
114
133
    def __init__(self):
115
134
        self._push_back_buffer = None
116
135
 
117
 
    def _push_back(self, bytes):
 
136
    def _push_back(self, data):
118
137
        """Return unused bytes to the medium, because they belong to the next
119
138
        request(s).
120
139
 
121
140
        This sets the _push_back_buffer to the given bytes.
122
141
        """
 
142
        if not isinstance(data, bytes):
 
143
            raise TypeError(data)
123
144
        if self._push_back_buffer is not None:
124
145
            raise AssertionError(
125
146
                "_push_back called when self._push_back_buffer is %r"
126
147
                % (self._push_back_buffer,))
127
 
        if bytes == '':
 
148
        if data == b'':
128
149
            return
129
 
        self._push_back_buffer = bytes
 
150
        self._push_back_buffer = data
130
151
 
131
152
    def _get_push_back_buffer(self):
132
 
        if self._push_back_buffer == '':
 
153
        if self._push_back_buffer == b'':
133
154
            raise AssertionError(
134
155
                '%s._push_back_buffer should never be the empty string, '
135
156
                'which can be confused with EOF' % (self,))
176
197
        ui.ui_factory.report_transport_activity(self, bytes, direction)
177
198
 
178
199
 
 
200
_bad_file_descriptor = (errno.EBADF,)
 
201
if sys.platform == 'win32':
 
202
    # Given on Windows if you pass a closed socket to select.select. Probably
 
203
    # also given if you pass a file handle to select.
 
204
    WSAENOTSOCK = 10038
 
205
    _bad_file_descriptor += (WSAENOTSOCK,)
 
206
 
 
207
 
179
208
class SmartServerStreamMedium(SmartMedium):
180
209
    """Handles smart commands coming over a stream.
181
210
 
194
223
        the stream.  See also the _push_back method.
195
224
    """
196
225
 
197
 
    def __init__(self, backing_transport, root_client_path='/'):
 
226
    _timer = time.time
 
227
 
 
228
    def __init__(self, backing_transport, root_client_path='/', timeout=None):
198
229
        """Construct new server.
199
230
 
200
231
        :param backing_transport: Transport for the directory served.
203
234
        self.backing_transport = backing_transport
204
235
        self.root_client_path = root_client_path
205
236
        self.finished = False
 
237
        if timeout is None:
 
238
            raise AssertionError('You must supply a timeout.')
 
239
        self._client_timeout = timeout
 
240
        self._client_poll_timeout = min(timeout / 10.0, 1.0)
206
241
        SmartMedium.__init__(self)
207
242
 
208
243
    def serve(self):
214
249
            while not self.finished:
215
250
                server_protocol = self._build_protocol()
216
251
                self._serve_one_request(server_protocol)
217
 
        except Exception, e:
 
252
        except errors.ConnectionTimeout as e:
 
253
            trace.note('%s' % (e,))
 
254
            trace.log_exception_quietly()
 
255
            self._disconnect_client()
 
256
            # We reported it, no reason to make a big fuss.
 
257
            return
 
258
        except Exception as e:
218
259
            stderr.write("%s terminating on exception %s\n" % (self, e))
219
260
            raise
 
261
        self._disconnect_client()
 
262
 
 
263
    def _stop_gracefully(self):
 
264
        """When we finish this message, stop looking for more."""
 
265
        trace.mutter('Stopping %s' % (self,))
 
266
        self.finished = True
 
267
 
 
268
    def _disconnect_client(self):
 
269
        """Close the current connection. We stopped due to a timeout/etc."""
 
270
        # The default implementation is a no-op, because that is all we used to
 
271
        # do when disconnecting from a client. I suppose we never had the
 
272
        # *server* initiate a disconnect, before
 
273
 
 
274
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
 
275
        """Wait for more bytes to be read, but timeout if none available.
 
276
 
 
277
        This allows us to detect idle connections, and stop trying to read from
 
278
        them, without setting the socket itself to non-blocking. This also
 
279
        allows us to specify when we watch for idle timeouts.
 
280
 
 
281
        :return: Did we timeout? (True if we timed out, False if there is data
 
282
            to be read)
 
283
        """
 
284
        raise NotImplementedError(self._wait_for_bytes_with_timeout)
220
285
 
221
286
    def _build_protocol(self):
222
287
        """Identifies the version of the incoming request, and returns an
227
292
 
228
293
        :returns: a SmartServerRequestProtocol.
229
294
        """
 
295
        self._wait_for_bytes_with_timeout(self._client_timeout)
 
296
        if self.finished:
 
297
            # We're stopping, so don't try to do any more work
 
298
            return None
230
299
        bytes = self._get_line()
231
300
        protocol_factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
232
301
        protocol = protocol_factory(
234
303
        protocol.accept_bytes(unused_bytes)
235
304
        return protocol
236
305
 
 
306
    def _wait_on_descriptor(self, fd, timeout_seconds):
 
307
        """select() on a file descriptor, waiting for nonblocking read()
 
308
 
 
309
        This will raise a ConnectionTimeout exception if we do not get a
 
310
        readable handle before timeout_seconds.
 
311
        :return: None
 
312
        """
 
313
        t_end = self._timer() + timeout_seconds
 
314
        poll_timeout = min(timeout_seconds, self._client_poll_timeout)
 
315
        rs = xs = None
 
316
        while not rs and not xs and self._timer() < t_end:
 
317
            if self.finished:
 
318
                return
 
319
            try:
 
320
                rs, _, xs = select.select([fd], [], [fd], poll_timeout)
 
321
            except (select.error, socket.error) as e:
 
322
                err = getattr(e, 'errno', None)
 
323
                if err is None and getattr(e, 'args', None) is not None:
 
324
                    # select.error doesn't have 'errno', it just has args[0]
 
325
                    err = e.args[0]
 
326
                if err in _bad_file_descriptor:
 
327
                    return  # Not a socket indicates read() will fail
 
328
                elif err == errno.EINTR:
 
329
                    # Interrupted, keep looping.
 
330
                    continue
 
331
                raise
 
332
            except ValueError:
 
333
                return  # Socket may already be closed
 
334
        if rs or xs:
 
335
            return
 
336
        raise errors.ConnectionTimeout('disconnecting client after %.1f seconds'
 
337
                                       % (timeout_seconds,))
 
338
 
237
339
    def _serve_one_request(self, protocol):
238
340
        """Read one request from input, process, send back a response.
239
341
 
240
342
        :param protocol: a SmartServerRequestProtocol.
241
343
        """
 
344
        if protocol is None:
 
345
            return
242
346
        try:
243
347
            self._serve_one_request_unguarded(protocol)
244
348
        except KeyboardInterrupt:
245
349
            raise
246
 
        except Exception, e:
 
350
        except Exception as e:
247
351
            self.terminate_due_to_error()
248
352
 
249
353
    def terminate_due_to_error(self):
260
364
 
261
365
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
262
366
 
263
 
    def __init__(self, sock, backing_transport, root_client_path='/'):
 
367
    def __init__(self, sock, backing_transport, root_client_path='/',
 
368
                 timeout=None):
264
369
        """Constructor.
265
370
 
266
371
        :param sock: the socket the server will read from.  It will be put
267
372
            into blocking mode.
268
373
        """
269
374
        SmartServerStreamMedium.__init__(
270
 
            self, backing_transport, root_client_path=root_client_path)
 
375
            self, backing_transport, root_client_path=root_client_path,
 
376
            timeout=timeout)
271
377
        sock.setblocking(True)
272
378
        self.socket = sock
 
379
        # Get the getpeername now, as we might be closed later when we care.
 
380
        try:
 
381
            self._client_info = sock.getpeername()
 
382
        except socket.error:
 
383
            self._client_info = '<unknown>'
 
384
 
 
385
    def __str__(self):
 
386
        return '%s(client=%s)' % (self.__class__.__name__, self._client_info)
 
387
 
 
388
    def __repr__(self):
 
389
        return '%s.%s(client=%s)' % (self.__module__, self.__class__.__name__,
 
390
                                     self._client_info)
273
391
 
274
392
    def _serve_one_request_unguarded(self, protocol):
275
393
        while protocol.next_read_size():
277
395
            # than MAX_SOCKET_CHUNK ready, the socket will just return a
278
396
            # short read immediately rather than block.
279
397
            bytes = self.read_bytes(osutils.MAX_SOCKET_CHUNK)
280
 
            if bytes == '':
 
398
            if bytes == b'':
281
399
                self.finished = True
282
400
                return
283
401
            protocol.accept_bytes(bytes)
284
402
 
285
403
        self._push_back(protocol.unused_data)
286
404
 
 
405
    def _disconnect_client(self):
 
406
        """Close the current connection. We stopped due to a timeout/etc."""
 
407
        self.socket.close()
 
408
 
 
409
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
 
410
        """Wait for more bytes to be read, but timeout if none available.
 
411
 
 
412
        This allows us to detect idle connections, and stop trying to read from
 
413
        them, without setting the socket itself to non-blocking. This also
 
414
        allows us to specify when we watch for idle timeouts.
 
415
 
 
416
        :return: None, this will raise ConnectionTimeout if we time out before
 
417
            data is available.
 
418
        """
 
419
        return self._wait_on_descriptor(self.socket, timeout_seconds)
 
420
 
287
421
    def _read_bytes(self, desired_count):
288
422
        return osutils.read_bytes_from_socket(
289
423
            self.socket, self._report_activity)
306
440
 
307
441
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
308
442
 
309
 
    def __init__(self, in_file, out_file, backing_transport):
 
443
    def __init__(self, in_file, out_file, backing_transport, timeout=None):
310
444
        """Construct new server.
311
445
 
312
446
        :param in_file: Python file from which requests can be read.
313
447
        :param out_file: Python file to write responses.
314
448
        :param backing_transport: Transport for the directory served.
315
449
        """
316
 
        SmartServerStreamMedium.__init__(self, backing_transport)
 
450
        SmartServerStreamMedium.__init__(self, backing_transport,
 
451
                                         timeout=timeout)
317
452
        if sys.platform == 'win32':
318
453
            # force binary mode for files
319
454
            import msvcrt
324
459
        self._in = in_file
325
460
        self._out = out_file
326
461
 
 
462
    def serve(self):
 
463
        """See SmartServerStreamMedium.serve"""
 
464
        # This is the regular serve, except it adds signal trapping for soft
 
465
        # shutdown.
 
466
        stop_gracefully = self._stop_gracefully
 
467
        signals.register_on_hangup(id(self), stop_gracefully)
 
468
        try:
 
469
            return super(SmartServerPipeStreamMedium, self).serve()
 
470
        finally:
 
471
            signals.unregister_on_hangup(id(self))
 
472
 
327
473
    def _serve_one_request_unguarded(self, protocol):
328
474
        while True:
329
475
            # We need to be careful not to read past the end of the current
335
481
                self._out.flush()
336
482
                return
337
483
            bytes = self.read_bytes(bytes_to_read)
338
 
            if bytes == '':
 
484
            if bytes == b'':
339
485
                # Connection has been closed.
340
486
                self.finished = True
341
487
                self._out.flush()
342
488
                return
343
489
            protocol.accept_bytes(bytes)
344
490
 
 
491
    def _disconnect_client(self):
 
492
        self._in.close()
 
493
        self._out.flush()
 
494
        self._out.close()
 
495
 
 
496
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
 
497
        """Wait for more bytes to be read, but timeout if none available.
 
498
 
 
499
        This allows us to detect idle connections, and stop trying to read from
 
500
        them, without setting the socket itself to non-blocking. This also
 
501
        allows us to specify when we watch for idle timeouts.
 
502
 
 
503
        :return: None, this will raise ConnectionTimeout if we time out before
 
504
            data is available.
 
505
        """
 
506
        if (getattr(self._in, 'fileno', None) is None
 
507
                or sys.platform == 'win32'):
 
508
            # You can't select() file descriptors on Windows.
 
509
            return
 
510
        try:
 
511
            return self._wait_on_descriptor(self._in, timeout_seconds)
 
512
        except io.UnsupportedOperation:
 
513
            return
 
514
 
345
515
    def _read_bytes(self, desired_count):
346
516
        return self._in.read(desired_count)
347
517
 
475
645
 
476
646
    def read_line(self):
477
647
        line = self._read_line()
478
 
        if not line.endswith('\n'):
 
648
        if not line.endswith(b'\n'):
479
649
            # end of file encountered reading from server
480
650
            raise errors.ConnectionReset(
481
651
                "Unexpected end of message. Please check connectivity "
491
661
        return self._medium._get_line()
492
662
 
493
663
 
 
664
class _VfsRefuser(object):
 
665
    """An object that refuses all VFS requests.
 
666
 
 
667
    """
 
668
 
 
669
    def __init__(self):
 
670
        client._SmartClient.hooks.install_named_hook(
 
671
            'call', self.check_vfs, 'vfs refuser')
 
672
 
 
673
    def check_vfs(self, params):
 
674
        try:
 
675
            request_method = request.request_handlers.get(params.method)
 
676
        except KeyError:
 
677
            # A method we don't know about doesn't count as a VFS method.
 
678
            return
 
679
        if issubclass(request_method, vfs.VfsRequest):
 
680
            raise HpssVfsRequestNotAllowed(params.method, params.args)
 
681
 
 
682
 
494
683
class _DebugCounter(object):
495
684
    """An object that counts the HPSS calls made to each client medium.
496
685
 
497
 
    When a medium is garbage-collected, or failing that when atexit functions
498
 
    are run, the total number of calls made on that medium are reported via
499
 
    trace.note.
 
686
    When a medium is garbage-collected, or failing that when
 
687
    breezy.global_state exits, the total number of calls made on that medium
 
688
    are reported via trace.note.
500
689
    """
501
690
 
502
691
    def __init__(self):
503
692
        self.counts = weakref.WeakKeyDictionary()
504
693
        client._SmartClient.hooks.install_named_hook(
505
694
            'call', self.increment_call_count, 'hpss call counter')
506
 
        atexit.register(self.flush_all)
 
695
        breezy.get_global_state().cleanups.add_cleanup(self.flush_all)
507
696
 
508
697
    def track(self, medium):
509
698
        """Start tracking calls made to a medium.
543
732
        value['count'] = 0
544
733
        value['vfs_count'] = 0
545
734
        if count != 0:
546
 
            trace.note('HPSS calls: %d (%d vfs) %s',
547
 
                       count, vfs_count, medium_repr)
 
735
            trace.note(gettext('HPSS calls: {0} ({1} vfs) {2}').format(
 
736
                       count, vfs_count, medium_repr))
548
737
 
549
738
    def flush_all(self):
550
739
        for ref in list(self.counts.keys()):
551
740
            self.done(ref)
552
741
 
 
742
 
553
743
_debug_counter = None
 
744
_vfs_refuser = None
554
745
 
555
746
 
556
747
class SmartClientMedium(SmartMedium):
573
764
            if _debug_counter is None:
574
765
                _debug_counter = _DebugCounter()
575
766
            _debug_counter.track(self)
 
767
        if 'hpss_client_no_vfs' in debug.debug_flags:
 
768
            global _vfs_refuser
 
769
            if _vfs_refuser is None:
 
770
                _vfs_refuser = _VfsRefuser()
576
771
 
577
772
    def _is_remote_before(self, version_tuple):
578
773
        """Is it possible the remote side supports RPCs for a given version?
602
797
        :seealso: _is_remote_before
603
798
        """
604
799
        if (self._remote_version_is_before is not None and
605
 
            version_tuple > self._remote_version_is_before):
 
800
                version_tuple > self._remote_version_is_before):
606
801
            # We have been told that the remote side is older than some version
607
802
            # which is newer than a previously supplied older-than version.
608
803
            # This indicates that some smart verb call is not guarded
609
804
            # appropriately (it should simply not have been tried).
610
805
            trace.mutter(
611
806
                "_remember_remote_is_before(%r) called, but "
612
 
                "_remember_remote_is_before(%r) was called previously."
613
 
                , version_tuple, self._remote_version_is_before)
 
807
                "_remember_remote_is_before(%r) was called previously.", version_tuple, self._remote_version_is_before)
614
808
            if 'hpss' in debug.debug_flags:
615
809
                ui.ui_factory.show_warning(
616
810
                    "_remember_remote_is_before(%r) called, but "
628
822
                medium_request = self.get_request()
629
823
                # Send a 'hello' request in protocol version one, for maximum
630
824
                # backwards compatibility.
631
 
                client_protocol = protocol.SmartClientRequestProtocolOne(medium_request)
 
825
                client_protocol = protocol.SmartClientRequestProtocolOne(
 
826
                    medium_request)
632
827
                client_protocol.query_version()
633
828
                self._done_hello = True
634
 
            except errors.SmartProtocolError, e:
 
829
            except errors.SmartProtocolError as e:
635
830
                # Cache the error, just like we would cache a successful
636
831
                # result.
637
832
                self._protocol_version_error = e
669
864
        """
670
865
        medium_base = urlutils.join(self.base, '/')
671
866
        rel_url = urlutils.relative_url(medium_base, transport.base)
672
 
        return urllib.unquote(rel_url)
 
867
        return urlutils.unquote(rel_url)
673
868
 
674
869
 
675
870
class SmartClientStreamMedium(SmartClientMedium):
710
905
        """
711
906
        return SmartClientStreamMediumRequest(self)
712
907
 
 
908
    def reset(self):
 
909
        """We have been disconnected, reset current state.
 
910
 
 
911
        This resets things like _current_request and connected state.
 
912
        """
 
913
        self.disconnect()
 
914
        self._current_request = None
 
915
 
713
916
 
714
917
class SmartSimplePipesClientMedium(SmartClientStreamMedium):
715
918
    """A client medium using simple pipes.
716
919
 
717
920
    This client does not manage the pipes: it assumes they will always be open.
718
 
 
719
 
    Note that if readable_pipe.read might raise IOError or OSError with errno
720
 
    of EINTR, it must be safe to retry the read.  Plain CPython fileobjects
721
 
    (such as used for sys.stdin) are safe.
722
921
    """
723
922
 
724
923
    def __init__(self, readable_pipe, writeable_pipe, base):
726
925
        self._readable_pipe = readable_pipe
727
926
        self._writeable_pipe = writeable_pipe
728
927
 
729
 
    def _accept_bytes(self, bytes):
 
928
    def _accept_bytes(self, data):
730
929
        """See SmartClientStreamMedium.accept_bytes."""
731
 
        self._writeable_pipe.write(bytes)
732
 
        self._report_activity(len(bytes), 'write')
 
930
        try:
 
931
            self._writeable_pipe.write(data)
 
932
        except IOError as e:
 
933
            if e.errno in (errno.EINVAL, errno.EPIPE):
 
934
                raise errors.ConnectionReset(
 
935
                    "Error trying to write to subprocess", e)
 
936
            raise
 
937
        self._report_activity(len(data), 'write')
733
938
 
734
939
    def _flush(self):
735
940
        """See SmartClientStreamMedium._flush()."""
 
941
        # Note: If flush were to fail, we'd like to raise ConnectionReset, etc.
 
942
        #       However, testing shows that even when the child process is
 
943
        #       gone, this doesn't error.
736
944
        self._writeable_pipe.flush()
737
945
 
738
946
    def _read_bytes(self, count):
739
947
        """See SmartClientStreamMedium._read_bytes."""
740
 
        bytes = osutils.until_no_eintr(self._readable_pipe.read, count)
741
 
        self._report_activity(len(bytes), 'read')
742
 
        return bytes
 
948
        bytes_to_read = min(count, _MAX_READ_SIZE)
 
949
        data = self._readable_pipe.read(bytes_to_read)
 
950
        self._report_activity(len(data), 'read')
 
951
        return data
 
952
 
 
953
 
 
954
class SSHParams(object):
 
955
    """A set of parameters for starting a remote bzr via SSH."""
 
956
 
 
957
    def __init__(self, host, port=None, username=None, password=None,
 
958
                 bzr_remote_path='bzr'):
 
959
        self.host = host
 
960
        self.port = port
 
961
        self.username = username
 
962
        self.password = password
 
963
        self.bzr_remote_path = bzr_remote_path
743
964
 
744
965
 
745
966
class SmartSSHClientMedium(SmartClientStreamMedium):
746
 
    """A client medium using SSH."""
747
 
 
748
 
    def __init__(self, host, port=None, username=None, password=None,
749
 
            base=None, vendor=None, bzr_remote_path=None):
 
967
    """A client medium using SSH.
 
968
 
 
969
    It delegates IO to a SmartSimplePipesClientMedium or
 
970
    SmartClientAlreadyConnectedSocketMedium (depending on platform).
 
971
    """
 
972
 
 
973
    def __init__(self, base, ssh_params, vendor=None):
750
974
        """Creates a client that will connect on the first use.
751
975
 
 
976
        :param ssh_params: A SSHParams instance.
752
977
        :param vendor: An optional override for the ssh vendor to use. See
753
 
            bzrlib.transport.ssh for details on ssh vendors.
 
978
            breezy.transport.ssh for details on ssh vendors.
754
979
        """
755
 
        self._connected = False
756
 
        self._host = host
757
 
        self._password = password
758
 
        self._port = port
759
 
        self._username = username
 
980
        self._real_medium = None
 
981
        self._ssh_params = ssh_params
760
982
        # for the benefit of progress making a short description of this
761
983
        # transport
762
984
        self._scheme = 'bzr+ssh'
764
986
        # _DebugCounter so we have to store all the values used in our repr
765
987
        # method before calling the super init.
766
988
        SmartClientStreamMedium.__init__(self, base)
767
 
        self._read_from = None
 
989
        self._vendor = vendor
768
990
        self._ssh_connection = None
769
 
        self._vendor = vendor
770
 
        self._write_to = None
771
 
        self._bzr_remote_path = bzr_remote_path
772
991
 
773
992
    def __repr__(self):
774
 
        if self._port is None:
 
993
        if self._ssh_params.port is None:
775
994
            maybe_port = ''
776
995
        else:
777
 
            maybe_port = ':%s' % self._port
778
 
        return "%s(%s://%s@%s%s/)" % (
 
996
            maybe_port = ':%s' % self._ssh_params.port
 
997
        if self._ssh_params.username is None:
 
998
            maybe_user = ''
 
999
        else:
 
1000
            maybe_user = '%s@' % self._ssh_params.username
 
1001
        return "%s(%s://%s%s%s/)" % (
779
1002
            self.__class__.__name__,
780
1003
            self._scheme,
781
 
            self._username,
782
 
            self._host,
 
1004
            maybe_user,
 
1005
            self._ssh_params.host,
783
1006
            maybe_port)
784
1007
 
785
1008
    def _accept_bytes(self, bytes):
786
1009
        """See SmartClientStreamMedium.accept_bytes."""
787
1010
        self._ensure_connection()
788
 
        self._write_to.write(bytes)
789
 
        self._report_activity(len(bytes), 'write')
 
1011
        self._real_medium.accept_bytes(bytes)
790
1012
 
791
1013
    def disconnect(self):
792
1014
        """See SmartClientMedium.disconnect()."""
793
 
        if not self._connected:
794
 
            return
795
 
        self._read_from.close()
796
 
        self._write_to.close()
797
 
        self._ssh_connection.close()
798
 
        self._connected = False
 
1015
        if self._real_medium is not None:
 
1016
            self._real_medium.disconnect()
 
1017
            self._real_medium = None
 
1018
        if self._ssh_connection is not None:
 
1019
            self._ssh_connection.close()
 
1020
            self._ssh_connection = None
799
1021
 
800
1022
    def _ensure_connection(self):
801
1023
        """Connect this medium if not already connected."""
802
 
        if self._connected:
 
1024
        if self._real_medium is not None:
803
1025
            return
804
1026
        if self._vendor is None:
805
1027
            vendor = ssh._get_ssh_vendor()
806
1028
        else:
807
1029
            vendor = self._vendor
808
 
        self._ssh_connection = vendor.connect_ssh(self._username,
809
 
                self._password, self._host, self._port,
810
 
                command=[self._bzr_remote_path, 'serve', '--inet',
811
 
                         '--directory=/', '--allow-writes'])
812
 
        self._read_from, self._write_to = \
813
 
            self._ssh_connection.get_filelike_channels()
814
 
        self._connected = True
 
1030
        self._ssh_connection = vendor.connect_ssh(self._ssh_params.username,
 
1031
                                                  self._ssh_params.password, self._ssh_params.host,
 
1032
                                                  self._ssh_params.port,
 
1033
                                                  command=[self._ssh_params.bzr_remote_path, 'serve', '--inet',
 
1034
                                                           '--directory=/', '--allow-writes'])
 
1035
        io_kind, io_object = self._ssh_connection.get_sock_or_pipes()
 
1036
        if io_kind == 'socket':
 
1037
            self._real_medium = SmartClientAlreadyConnectedSocketMedium(
 
1038
                self.base, io_object)
 
1039
        elif io_kind == 'pipes':
 
1040
            read_from, write_to = io_object
 
1041
            self._real_medium = SmartSimplePipesClientMedium(
 
1042
                read_from, write_to, self.base)
 
1043
        else:
 
1044
            raise AssertionError(
 
1045
                "Unexpected io_kind %r from %r"
 
1046
                % (io_kind, self._ssh_connection))
 
1047
        for hook in transport.Transport.hooks["post_connect"]:
 
1048
            hook(self)
815
1049
 
816
1050
    def _flush(self):
817
1051
        """See SmartClientStreamMedium._flush()."""
818
 
        self._write_to.flush()
 
1052
        self._real_medium._flush()
819
1053
 
820
1054
    def _read_bytes(self, count):
821
1055
        """See SmartClientStreamMedium.read_bytes."""
822
 
        if not self._connected:
 
1056
        if self._real_medium is None:
823
1057
            raise errors.MediumNotConnected(self)
824
 
        bytes_to_read = min(count, _MAX_READ_SIZE)
825
 
        bytes = self._read_from.read(bytes_to_read)
826
 
        self._report_activity(len(bytes), 'read')
827
 
        return bytes
 
1058
        return self._real_medium.read_bytes(count)
828
1059
 
829
1060
 
830
1061
# Port 4155 is the default port for bzr://, registered with IANA.
832
1063
BZR_DEFAULT_PORT = 4155
833
1064
 
834
1065
 
835
 
class SmartTCPClientMedium(SmartClientStreamMedium):
836
 
    """A client medium using TCP."""
837
 
 
838
 
    def __init__(self, host, port, base):
839
 
        """Creates a client that will connect on the first use."""
 
1066
class SmartClientSocketMedium(SmartClientStreamMedium):
 
1067
    """A client medium using a socket.
 
1068
 
 
1069
    This class isn't usable directly.  Use one of its subclasses instead.
 
1070
    """
 
1071
 
 
1072
    def __init__(self, base):
840
1073
        SmartClientStreamMedium.__init__(self, base)
 
1074
        self._socket = None
841
1075
        self._connected = False
842
 
        self._host = host
843
 
        self._port = port
844
 
        self._socket = None
845
1076
 
846
1077
    def _accept_bytes(self, bytes):
847
1078
        """See SmartClientMedium.accept_bytes."""
848
1079
        self._ensure_connection()
849
1080
        osutils.send_all(self._socket, bytes, self._report_activity)
850
1081
 
 
1082
    def _ensure_connection(self):
 
1083
        """Connect this medium if not already connected."""
 
1084
        raise NotImplementedError(self._ensure_connection)
 
1085
 
 
1086
    def _flush(self):
 
1087
        """See SmartClientStreamMedium._flush().
 
1088
 
 
1089
        For sockets we do no flushing. For TCP sockets we may want to turn off
 
1090
        TCP_NODELAY and add a means to do a flush, but that can be done in the
 
1091
        future.
 
1092
        """
 
1093
 
 
1094
    def _read_bytes(self, count):
 
1095
        """See SmartClientMedium.read_bytes."""
 
1096
        if not self._connected:
 
1097
            raise errors.MediumNotConnected(self)
 
1098
        return osutils.read_bytes_from_socket(
 
1099
            self._socket, self._report_activity)
 
1100
 
851
1101
    def disconnect(self):
852
1102
        """See SmartClientMedium.disconnect()."""
853
1103
        if not self._connected:
856
1106
        self._socket = None
857
1107
        self._connected = False
858
1108
 
 
1109
 
 
1110
class SmartTCPClientMedium(SmartClientSocketMedium):
 
1111
    """A client medium that creates a TCP connection."""
 
1112
 
 
1113
    def __init__(self, host, port, base):
 
1114
        """Creates a client that will connect on the first use."""
 
1115
        SmartClientSocketMedium.__init__(self, base)
 
1116
        self._host = host
 
1117
        self._port = port
 
1118
 
859
1119
    def _ensure_connection(self):
860
1120
        """Connect this medium if not already connected."""
861
1121
        if self._connected:
866
1126
            port = int(self._port)
867
1127
        try:
868
1128
            sockaddrs = socket.getaddrinfo(self._host, port, socket.AF_UNSPEC,
869
 
                socket.SOCK_STREAM, 0, 0)
870
 
        except socket.gaierror, (err_num, err_msg):
 
1129
                                           socket.SOCK_STREAM, 0, 0)
 
1130
        except socket.gaierror as xxx_todo_changeme:
 
1131
            (err_num, err_msg) = xxx_todo_changeme.args
871
1132
            raise errors.ConnectionError("failed to lookup %s:%d: %s" %
872
 
                    (self._host, port, err_msg))
 
1133
                                         (self._host, port, err_msg))
873
1134
        # Initialize err in case there are no addresses returned:
874
 
        err = socket.error("no address found for %s" % self._host)
 
1135
        last_err = socket.error("no address found for %s" % self._host)
875
1136
        for (family, socktype, proto, canonname, sockaddr) in sockaddrs:
876
1137
            try:
877
1138
                self._socket = socket.socket(family, socktype, proto)
878
1139
                self._socket.setsockopt(socket.IPPROTO_TCP,
879
1140
                                        socket.TCP_NODELAY, 1)
880
1141
                self._socket.connect(sockaddr)
881
 
            except socket.error, err:
 
1142
            except socket.error as err:
882
1143
                if self._socket is not None:
883
1144
                    self._socket.close()
884
1145
                self._socket = None
 
1146
                last_err = err
885
1147
                continue
886
1148
            break
887
1149
        if self._socket is None:
888
1150
            # socket errors either have a (string) or (errno, string) as their
889
1151
            # args.
890
 
            if type(err.args) is str:
891
 
                err_msg = err.args
 
1152
            if isinstance(last_err.args, str):
 
1153
                err_msg = last_err.args
892
1154
            else:
893
 
                err_msg = err.args[1]
 
1155
                err_msg = last_err.args[1]
894
1156
            raise errors.ConnectionError("failed to connect to %s:%d: %s" %
895
 
                    (self._host, port, err_msg))
896
 
        self._connected = True
897
 
 
898
 
    def _flush(self):
899
 
        """See SmartClientStreamMedium._flush().
900
 
 
901
 
        For TCP we do no flushing. We may want to turn off TCP_NODELAY and
902
 
        add a means to do a flush, but that can be done in the future.
903
 
        """
904
 
 
905
 
    def _read_bytes(self, count):
906
 
        """See SmartClientMedium.read_bytes."""
907
 
        if not self._connected:
908
 
            raise errors.MediumNotConnected(self)
909
 
        return osutils.read_bytes_from_socket(
910
 
            self._socket, self._report_activity)
 
1157
                                         (self._host, port, err_msg))
 
1158
        self._connected = True
 
1159
        for hook in transport.Transport.hooks["post_connect"]:
 
1160
            hook(self)
 
1161
 
 
1162
 
 
1163
class SmartClientAlreadyConnectedSocketMedium(SmartClientSocketMedium):
 
1164
    """A client medium for an already connected socket.
 
1165
 
 
1166
    Note that this class will assume it "owns" the socket, so it will close it
 
1167
    when its disconnect method is called.
 
1168
    """
 
1169
 
 
1170
    def __init__(self, base, sock):
 
1171
        SmartClientSocketMedium.__init__(self, base)
 
1172
        self._socket = sock
 
1173
        self._connected = True
 
1174
 
 
1175
    def _ensure_connection(self):
 
1176
        # Already connected, by definition!  So nothing to do.
 
1177
        pass
911
1178
 
912
1179
 
913
1180
class SmartClientStreamMediumRequest(SmartClientMediumRequest):
948
1215
        This invokes self._medium._flush to ensure all bytes are transmitted.
949
1216
        """
950
1217
        self._medium._flush()
951
 
 
952