/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to breezy/bzr/smart/medium.py

  • Committer: Jelmer Vernooij
  • Date: 2019-06-03 21:25:01 UTC
  • mto: This revision was merged to the branch mainline in revision 7318.
  • Revision ID: jelmer@jelmer.uk-20190603212501-zgt2czrlc6oqoi7a
Fix tests on python 2.

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006-2010 Canonical Ltd
 
1
# Copyright (C) 2006-2011 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
21
21
 
22
22
Media carry the bytes of the requests somehow (e.g. via TCP, wrapped in HTTP, or
23
23
over SSH), and pass them to and from the protocol logic.  See the overview in
24
 
bzrlib/transport/smart/__init__.py.
 
24
breezy/transport/smart/__init__.py.
25
25
"""
26
26
 
 
27
from __future__ import absolute_import
 
28
 
 
29
import errno
 
30
import io
27
31
import os
28
32
import sys
29
 
import urllib
30
 
 
31
 
from bzrlib.lazy_import import lazy_import
 
33
import time
 
34
 
 
35
try:
 
36
    import _thread
 
37
except ImportError:
 
38
    import thread as _thread
 
39
 
 
40
import breezy
 
41
from ...lazy_import import lazy_import
32
42
lazy_import(globals(), """
33
 
import atexit
 
43
import select
34
44
import socket
35
 
import thread
36
45
import weakref
37
46
 
38
 
from bzrlib import (
 
47
from breezy import (
39
48
    debug,
40
 
    errors,
41
 
    symbol_versioning,
42
49
    trace,
 
50
    transport,
43
51
    ui,
44
52
    urlutils,
45
53
    )
46
 
from bzrlib.smart import client, protocol, request, vfs
47
 
from bzrlib.transport import ssh
 
54
from breezy.i18n import gettext
 
55
from breezy.bzr.smart import client, protocol, request, signals, vfs
 
56
from breezy.transport import ssh
48
57
""")
49
 
from bzrlib import osutils
 
58
from ... import (
 
59
    errors,
 
60
    osutils,
 
61
    )
50
62
 
51
63
# Throughout this module buffer size parameters are either limited to be at
52
64
# most _MAX_READ_SIZE, or are ignored and _MAX_READ_SIZE is used instead.
54
66
# from non-sockets as well.
55
67
_MAX_READ_SIZE = osutils.MAX_SOCKET_CHUNK
56
68
 
 
69
 
 
70
class HpssVfsRequestNotAllowed(errors.BzrError):
 
71
 
 
72
    _fmt = ("VFS requests over the smart server are not allowed. Encountered: "
 
73
            "%(method)s, %(arguments)s.")
 
74
 
 
75
    def __init__(self, method, arguments):
 
76
        self.method = method
 
77
        self.arguments = arguments
 
78
 
 
79
 
57
80
def _get_protocol_factory_for_bytes(bytes):
58
81
    """Determine the right protocol factory for 'bytes'.
59
82
 
95
118
    :returns: a tuple of two strs: (line, excess)
96
119
    """
97
120
    newline_pos = -1
98
 
    bytes = ''
 
121
    bytes = b''
99
122
    while newline_pos == -1:
100
123
        new_bytes = read_bytes_func(1)
101
124
        bytes += new_bytes
102
 
        if new_bytes == '':
 
125
        if new_bytes == b'':
103
126
            # Ran out of bytes before receiving a complete line.
104
 
            return bytes, ''
105
 
        newline_pos = bytes.find('\n')
106
 
    line = bytes[:newline_pos+1]
107
 
    excess = bytes[newline_pos+1:]
 
127
            return bytes, b''
 
128
        newline_pos = bytes.find(b'\n')
 
129
    line = bytes[:newline_pos + 1]
 
130
    excess = bytes[newline_pos + 1:]
108
131
    return line, excess
109
132
 
110
133
 
114
137
    def __init__(self):
115
138
        self._push_back_buffer = None
116
139
 
117
 
    def _push_back(self, bytes):
 
140
    def _push_back(self, data):
118
141
        """Return unused bytes to the medium, because they belong to the next
119
142
        request(s).
120
143
 
121
144
        This sets the _push_back_buffer to the given bytes.
122
145
        """
 
146
        if not isinstance(data, bytes):
 
147
            raise TypeError(data)
123
148
        if self._push_back_buffer is not None:
124
149
            raise AssertionError(
125
150
                "_push_back called when self._push_back_buffer is %r"
126
151
                % (self._push_back_buffer,))
127
 
        if bytes == '':
 
152
        if data == b'':
128
153
            return
129
 
        self._push_back_buffer = bytes
 
154
        self._push_back_buffer = data
130
155
 
131
156
    def _get_push_back_buffer(self):
132
 
        if self._push_back_buffer == '':
 
157
        if self._push_back_buffer == b'':
133
158
            raise AssertionError(
134
159
                '%s._push_back_buffer should never be the empty string, '
135
160
                'which can be confused with EOF' % (self,))
176
201
        ui.ui_factory.report_transport_activity(self, bytes, direction)
177
202
 
178
203
 
 
204
_bad_file_descriptor = (errno.EBADF,)
 
205
if sys.platform == 'win32':
 
206
    # Given on Windows if you pass a closed socket to select.select. Probably
 
207
    # also given if you pass a file handle to select.
 
208
    WSAENOTSOCK = 10038
 
209
    _bad_file_descriptor += (WSAENOTSOCK,)
 
210
 
 
211
 
179
212
class SmartServerStreamMedium(SmartMedium):
180
213
    """Handles smart commands coming over a stream.
181
214
 
194
227
        the stream.  See also the _push_back method.
195
228
    """
196
229
 
197
 
    def __init__(self, backing_transport, root_client_path='/'):
 
230
    _timer = time.time
 
231
 
 
232
    def __init__(self, backing_transport, root_client_path='/', timeout=None):
198
233
        """Construct new server.
199
234
 
200
235
        :param backing_transport: Transport for the directory served.
203
238
        self.backing_transport = backing_transport
204
239
        self.root_client_path = root_client_path
205
240
        self.finished = False
 
241
        if timeout is None:
 
242
            raise AssertionError('You must supply a timeout.')
 
243
        self._client_timeout = timeout
 
244
        self._client_poll_timeout = min(timeout / 10.0, 1.0)
206
245
        SmartMedium.__init__(self)
207
246
 
208
247
    def serve(self):
214
253
            while not self.finished:
215
254
                server_protocol = self._build_protocol()
216
255
                self._serve_one_request(server_protocol)
217
 
        except Exception, e:
 
256
        except errors.ConnectionTimeout as e:
 
257
            trace.note('%s' % (e,))
 
258
            trace.log_exception_quietly()
 
259
            self._disconnect_client()
 
260
            # We reported it, no reason to make a big fuss.
 
261
            return
 
262
        except Exception as e:
218
263
            stderr.write("%s terminating on exception %s\n" % (self, e))
219
264
            raise
 
265
        self._disconnect_client()
 
266
 
 
267
    def _stop_gracefully(self):
 
268
        """When we finish this message, stop looking for more."""
 
269
        trace.mutter('Stopping %s' % (self,))
 
270
        self.finished = True
 
271
 
 
272
    def _disconnect_client(self):
 
273
        """Close the current connection. We stopped due to a timeout/etc."""
 
274
        # The default implementation is a no-op, because that is all we used to
 
275
        # do when disconnecting from a client. I suppose we never had the
 
276
        # *server* initiate a disconnect, before
 
277
 
 
278
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
 
279
        """Wait for more bytes to be read, but timeout if none available.
 
280
 
 
281
        This allows us to detect idle connections, and stop trying to read from
 
282
        them, without setting the socket itself to non-blocking. This also
 
283
        allows us to specify when we watch for idle timeouts.
 
284
 
 
285
        :return: Did we timeout? (True if we timed out, False if there is data
 
286
            to be read)
 
287
        """
 
288
        raise NotImplementedError(self._wait_for_bytes_with_timeout)
220
289
 
221
290
    def _build_protocol(self):
222
291
        """Identifies the version of the incoming request, and returns an
227
296
 
228
297
        :returns: a SmartServerRequestProtocol.
229
298
        """
 
299
        self._wait_for_bytes_with_timeout(self._client_timeout)
 
300
        if self.finished:
 
301
            # We're stopping, so don't try to do any more work
 
302
            return None
230
303
        bytes = self._get_line()
231
304
        protocol_factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
232
305
        protocol = protocol_factory(
234
307
        protocol.accept_bytes(unused_bytes)
235
308
        return protocol
236
309
 
 
310
    def _wait_on_descriptor(self, fd, timeout_seconds):
 
311
        """select() on a file descriptor, waiting for nonblocking read()
 
312
 
 
313
        This will raise a ConnectionTimeout exception if we do not get a
 
314
        readable handle before timeout_seconds.
 
315
        :return: None
 
316
        """
 
317
        t_end = self._timer() + timeout_seconds
 
318
        poll_timeout = min(timeout_seconds, self._client_poll_timeout)
 
319
        rs = xs = None
 
320
        while not rs and not xs and self._timer() < t_end:
 
321
            if self.finished:
 
322
                return
 
323
            try:
 
324
                rs, _, xs = select.select([fd], [], [fd], poll_timeout)
 
325
            except (select.error, socket.error) as e:
 
326
                err = getattr(e, 'errno', None)
 
327
                if err is None and getattr(e, 'args', None) is not None:
 
328
                    # select.error doesn't have 'errno', it just has args[0]
 
329
                    err = e.args[0]
 
330
                if err in _bad_file_descriptor:
 
331
                    return  # Not a socket indicates read() will fail
 
332
                elif err == errno.EINTR:
 
333
                    # Interrupted, keep looping.
 
334
                    continue
 
335
                raise
 
336
            except ValueError:
 
337
                return  # Socket may already be closed
 
338
        if rs or xs:
 
339
            return
 
340
        raise errors.ConnectionTimeout('disconnecting client after %.1f seconds'
 
341
                                       % (timeout_seconds,))
 
342
 
237
343
    def _serve_one_request(self, protocol):
238
344
        """Read one request from input, process, send back a response.
239
345
 
240
346
        :param protocol: a SmartServerRequestProtocol.
241
347
        """
 
348
        if protocol is None:
 
349
            return
242
350
        try:
243
351
            self._serve_one_request_unguarded(protocol)
244
352
        except KeyboardInterrupt:
245
353
            raise
246
 
        except Exception, e:
 
354
        except Exception as e:
247
355
            self.terminate_due_to_error()
248
356
 
249
357
    def terminate_due_to_error(self):
260
368
 
261
369
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
262
370
 
263
 
    def __init__(self, sock, backing_transport, root_client_path='/'):
 
371
    def __init__(self, sock, backing_transport, root_client_path='/',
 
372
                 timeout=None):
264
373
        """Constructor.
265
374
 
266
375
        :param sock: the socket the server will read from.  It will be put
267
376
            into blocking mode.
268
377
        """
269
378
        SmartServerStreamMedium.__init__(
270
 
            self, backing_transport, root_client_path=root_client_path)
 
379
            self, backing_transport, root_client_path=root_client_path,
 
380
            timeout=timeout)
271
381
        sock.setblocking(True)
272
382
        self.socket = sock
 
383
        # Get the getpeername now, as we might be closed later when we care.
 
384
        try:
 
385
            self._client_info = sock.getpeername()
 
386
        except socket.error:
 
387
            self._client_info = '<unknown>'
 
388
 
 
389
    def __str__(self):
 
390
        return '%s(client=%s)' % (self.__class__.__name__, self._client_info)
 
391
 
 
392
    def __repr__(self):
 
393
        return '%s.%s(client=%s)' % (self.__module__, self.__class__.__name__,
 
394
                                     self._client_info)
273
395
 
274
396
    def _serve_one_request_unguarded(self, protocol):
275
397
        while protocol.next_read_size():
277
399
            # than MAX_SOCKET_CHUNK ready, the socket will just return a
278
400
            # short read immediately rather than block.
279
401
            bytes = self.read_bytes(osutils.MAX_SOCKET_CHUNK)
280
 
            if bytes == '':
 
402
            if bytes == b'':
281
403
                self.finished = True
282
404
                return
283
405
            protocol.accept_bytes(bytes)
284
406
 
285
407
        self._push_back(protocol.unused_data)
286
408
 
 
409
    def _disconnect_client(self):
 
410
        """Close the current connection. We stopped due to a timeout/etc."""
 
411
        self.socket.close()
 
412
 
 
413
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
 
414
        """Wait for more bytes to be read, but timeout if none available.
 
415
 
 
416
        This allows us to detect idle connections, and stop trying to read from
 
417
        them, without setting the socket itself to non-blocking. This also
 
418
        allows us to specify when we watch for idle timeouts.
 
419
 
 
420
        :return: None, this will raise ConnectionTimeout if we time out before
 
421
            data is available.
 
422
        """
 
423
        return self._wait_on_descriptor(self.socket, timeout_seconds)
 
424
 
287
425
    def _read_bytes(self, desired_count):
288
426
        return osutils.read_bytes_from_socket(
289
427
            self.socket, self._report_activity)
295
433
        self.finished = True
296
434
 
297
435
    def _write_out(self, bytes):
298
 
        tstart = osutils.timer_func()
 
436
        tstart = osutils.perf_counter()
299
437
        osutils.send_all(self.socket, bytes, self._report_activity)
300
438
        if 'hpss' in debug.debug_flags:
301
 
            thread_id = thread.get_ident()
 
439
            thread_id = _thread.get_ident()
302
440
            trace.mutter('%12s: [%s] %d bytes to the socket in %.3fs'
303
441
                         % ('wrote', thread_id, len(bytes),
304
 
                            osutils.timer_func() - tstart))
 
442
                            osutils.perf_counter() - tstart))
305
443
 
306
444
 
307
445
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
308
446
 
309
 
    def __init__(self, in_file, out_file, backing_transport):
 
447
    def __init__(self, in_file, out_file, backing_transport, timeout=None):
310
448
        """Construct new server.
311
449
 
312
450
        :param in_file: Python file from which requests can be read.
313
451
        :param out_file: Python file to write responses.
314
452
        :param backing_transport: Transport for the directory served.
315
453
        """
316
 
        SmartServerStreamMedium.__init__(self, backing_transport)
 
454
        SmartServerStreamMedium.__init__(self, backing_transport,
 
455
                                         timeout=timeout)
317
456
        if sys.platform == 'win32':
318
457
            # force binary mode for files
319
458
            import msvcrt
324
463
        self._in = in_file
325
464
        self._out = out_file
326
465
 
 
466
    def serve(self):
 
467
        """See SmartServerStreamMedium.serve"""
 
468
        # This is the regular serve, except it adds signal trapping for soft
 
469
        # shutdown.
 
470
        stop_gracefully = self._stop_gracefully
 
471
        signals.register_on_hangup(id(self), stop_gracefully)
 
472
        try:
 
473
            return super(SmartServerPipeStreamMedium, self).serve()
 
474
        finally:
 
475
            signals.unregister_on_hangup(id(self))
 
476
 
327
477
    def _serve_one_request_unguarded(self, protocol):
328
478
        while True:
329
479
            # We need to be careful not to read past the end of the current
335
485
                self._out.flush()
336
486
                return
337
487
            bytes = self.read_bytes(bytes_to_read)
338
 
            if bytes == '':
 
488
            if bytes == b'':
339
489
                # Connection has been closed.
340
490
                self.finished = True
341
491
                self._out.flush()
342
492
                return
343
493
            protocol.accept_bytes(bytes)
344
494
 
 
495
    def _disconnect_client(self):
 
496
        self._in.close()
 
497
        self._out.flush()
 
498
        self._out.close()
 
499
 
 
500
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
 
501
        """Wait for more bytes to be read, but timeout if none available.
 
502
 
 
503
        This allows us to detect idle connections, and stop trying to read from
 
504
        them, without setting the socket itself to non-blocking. This also
 
505
        allows us to specify when we watch for idle timeouts.
 
506
 
 
507
        :return: None, this will raise ConnectionTimeout if we time out before
 
508
            data is available.
 
509
        """
 
510
        if (getattr(self._in, 'fileno', None) is None
 
511
                or sys.platform == 'win32'):
 
512
            # You can't select() file descriptors on Windows.
 
513
            return
 
514
        try:
 
515
            return self._wait_on_descriptor(self._in, timeout_seconds)
 
516
        except io.UnsupportedOperation:
 
517
            return
 
518
 
345
519
    def _read_bytes(self, desired_count):
346
520
        return self._in.read(desired_count)
347
521
 
475
649
 
476
650
    def read_line(self):
477
651
        line = self._read_line()
478
 
        if not line.endswith('\n'):
 
652
        if not line.endswith(b'\n'):
479
653
            # end of file encountered reading from server
480
654
            raise errors.ConnectionReset(
481
655
                "Unexpected end of message. Please check connectivity "
491
665
        return self._medium._get_line()
492
666
 
493
667
 
 
668
class _VfsRefuser(object):
 
669
    """An object that refuses all VFS requests.
 
670
 
 
671
    """
 
672
 
 
673
    def __init__(self):
 
674
        client._SmartClient.hooks.install_named_hook(
 
675
            'call', self.check_vfs, 'vfs refuser')
 
676
 
 
677
    def check_vfs(self, params):
 
678
        try:
 
679
            request_method = request.request_handlers.get(params.method)
 
680
        except KeyError:
 
681
            # A method we don't know about doesn't count as a VFS method.
 
682
            return
 
683
        if issubclass(request_method, vfs.VfsRequest):
 
684
            raise HpssVfsRequestNotAllowed(params.method, params.args)
 
685
 
 
686
 
494
687
class _DebugCounter(object):
495
688
    """An object that counts the HPSS calls made to each client medium.
496
689
 
497
 
    When a medium is garbage-collected, or failing that when atexit functions
498
 
    are run, the total number of calls made on that medium are reported via
499
 
    trace.note.
 
690
    When a medium is garbage-collected, or failing that when
 
691
    breezy.global_state exits, the total number of calls made on that medium
 
692
    are reported via trace.note.
500
693
    """
501
694
 
502
695
    def __init__(self):
503
696
        self.counts = weakref.WeakKeyDictionary()
504
697
        client._SmartClient.hooks.install_named_hook(
505
698
            'call', self.increment_call_count, 'hpss call counter')
506
 
        atexit.register(self.flush_all)
 
699
        breezy.get_global_state().cleanups.add_cleanup(self.flush_all)
507
700
 
508
701
    def track(self, medium):
509
702
        """Start tracking calls made to a medium.
543
736
        value['count'] = 0
544
737
        value['vfs_count'] = 0
545
738
        if count != 0:
546
 
            trace.note('HPSS calls: %d (%d vfs) %s',
547
 
                       count, vfs_count, medium_repr)
 
739
            trace.note(gettext('HPSS calls: {0} ({1} vfs) {2}').format(
 
740
                       count, vfs_count, medium_repr))
548
741
 
549
742
    def flush_all(self):
550
743
        for ref in list(self.counts.keys()):
551
744
            self.done(ref)
552
745
 
 
746
 
553
747
_debug_counter = None
 
748
_vfs_refuser = None
554
749
 
555
750
 
556
751
class SmartClientMedium(SmartMedium):
573
768
            if _debug_counter is None:
574
769
                _debug_counter = _DebugCounter()
575
770
            _debug_counter.track(self)
 
771
        if 'hpss_client_no_vfs' in debug.debug_flags:
 
772
            global _vfs_refuser
 
773
            if _vfs_refuser is None:
 
774
                _vfs_refuser = _VfsRefuser()
576
775
 
577
776
    def _is_remote_before(self, version_tuple):
578
777
        """Is it possible the remote side supports RPCs for a given version?
602
801
        :seealso: _is_remote_before
603
802
        """
604
803
        if (self._remote_version_is_before is not None and
605
 
            version_tuple > self._remote_version_is_before):
 
804
                version_tuple > self._remote_version_is_before):
606
805
            # We have been told that the remote side is older than some version
607
806
            # which is newer than a previously supplied older-than version.
608
807
            # This indicates that some smart verb call is not guarded
609
808
            # appropriately (it should simply not have been tried).
610
809
            trace.mutter(
611
810
                "_remember_remote_is_before(%r) called, but "
612
 
                "_remember_remote_is_before(%r) was called previously."
613
 
                , version_tuple, self._remote_version_is_before)
 
811
                "_remember_remote_is_before(%r) was called previously.", version_tuple, self._remote_version_is_before)
614
812
            if 'hpss' in debug.debug_flags:
615
813
                ui.ui_factory.show_warning(
616
814
                    "_remember_remote_is_before(%r) called, but "
628
826
                medium_request = self.get_request()
629
827
                # Send a 'hello' request in protocol version one, for maximum
630
828
                # backwards compatibility.
631
 
                client_protocol = protocol.SmartClientRequestProtocolOne(medium_request)
 
829
                client_protocol = protocol.SmartClientRequestProtocolOne(
 
830
                    medium_request)
632
831
                client_protocol.query_version()
633
832
                self._done_hello = True
634
 
            except errors.SmartProtocolError, e:
 
833
            except errors.SmartProtocolError as e:
635
834
                # Cache the error, just like we would cache a successful
636
835
                # result.
637
836
                self._protocol_version_error = e
669
868
        """
670
869
        medium_base = urlutils.join(self.base, '/')
671
870
        rel_url = urlutils.relative_url(medium_base, transport.base)
672
 
        return urllib.unquote(rel_url)
 
871
        return urlutils.unquote(rel_url)
673
872
 
674
873
 
675
874
class SmartClientStreamMedium(SmartClientMedium):
710
909
        """
711
910
        return SmartClientStreamMediumRequest(self)
712
911
 
 
912
    def reset(self):
 
913
        """We have been disconnected, reset current state.
 
914
 
 
915
        This resets things like _current_request and connected state.
 
916
        """
 
917
        self.disconnect()
 
918
        self._current_request = None
 
919
 
713
920
 
714
921
class SmartSimplePipesClientMedium(SmartClientStreamMedium):
715
922
    """A client medium using simple pipes.
716
923
 
717
924
    This client does not manage the pipes: it assumes they will always be open.
718
 
 
719
 
    Note that if readable_pipe.read might raise IOError or OSError with errno
720
 
    of EINTR, it must be safe to retry the read.  Plain CPython fileobjects
721
 
    (such as used for sys.stdin) are safe.
722
925
    """
723
926
 
724
927
    def __init__(self, readable_pipe, writeable_pipe, base):
726
929
        self._readable_pipe = readable_pipe
727
930
        self._writeable_pipe = writeable_pipe
728
931
 
729
 
    def _accept_bytes(self, bytes):
 
932
    def _accept_bytes(self, data):
730
933
        """See SmartClientStreamMedium.accept_bytes."""
731
 
        self._writeable_pipe.write(bytes)
732
 
        self._report_activity(len(bytes), 'write')
 
934
        try:
 
935
            self._writeable_pipe.write(data)
 
936
        except IOError as e:
 
937
            if e.errno in (errno.EINVAL, errno.EPIPE):
 
938
                raise errors.ConnectionReset(
 
939
                    "Error trying to write to subprocess", e)
 
940
            raise
 
941
        self._report_activity(len(data), 'write')
733
942
 
734
943
    def _flush(self):
735
944
        """See SmartClientStreamMedium._flush()."""
 
945
        # Note: If flush were to fail, we'd like to raise ConnectionReset, etc.
 
946
        #       However, testing shows that even when the child process is
 
947
        #       gone, this doesn't error.
736
948
        self._writeable_pipe.flush()
737
949
 
738
950
    def _read_bytes(self, count):
739
951
        """See SmartClientStreamMedium._read_bytes."""
740
 
        bytes = osutils.until_no_eintr(self._readable_pipe.read, count)
741
 
        self._report_activity(len(bytes), 'read')
742
 
        return bytes
 
952
        bytes_to_read = min(count, _MAX_READ_SIZE)
 
953
        data = self._readable_pipe.read(bytes_to_read)
 
954
        self._report_activity(len(data), 'read')
 
955
        return data
 
956
 
 
957
 
 
958
class SSHParams(object):
 
959
    """A set of parameters for starting a remote bzr via SSH."""
 
960
 
 
961
    def __init__(self, host, port=None, username=None, password=None,
 
962
                 bzr_remote_path='bzr'):
 
963
        self.host = host
 
964
        self.port = port
 
965
        self.username = username
 
966
        self.password = password
 
967
        self.bzr_remote_path = bzr_remote_path
743
968
 
744
969
 
745
970
class SmartSSHClientMedium(SmartClientStreamMedium):
746
 
    """A client medium using SSH."""
747
 
 
748
 
    def __init__(self, host, port=None, username=None, password=None,
749
 
            base=None, vendor=None, bzr_remote_path=None):
 
971
    """A client medium using SSH.
 
972
 
 
973
    It delegates IO to a SmartSimplePipesClientMedium or
 
974
    SmartClientAlreadyConnectedSocketMedium (depending on platform).
 
975
    """
 
976
 
 
977
    def __init__(self, base, ssh_params, vendor=None):
750
978
        """Creates a client that will connect on the first use.
751
979
 
 
980
        :param ssh_params: A SSHParams instance.
752
981
        :param vendor: An optional override for the ssh vendor to use. See
753
 
            bzrlib.transport.ssh for details on ssh vendors.
 
982
            breezy.transport.ssh for details on ssh vendors.
754
983
        """
755
 
        self._connected = False
756
 
        self._host = host
757
 
        self._password = password
758
 
        self._port = port
759
 
        self._username = username
 
984
        self._real_medium = None
 
985
        self._ssh_params = ssh_params
760
986
        # for the benefit of progress making a short description of this
761
987
        # transport
762
988
        self._scheme = 'bzr+ssh'
764
990
        # _DebugCounter so we have to store all the values used in our repr
765
991
        # method before calling the super init.
766
992
        SmartClientStreamMedium.__init__(self, base)
767
 
        self._read_from = None
 
993
        self._vendor = vendor
768
994
        self._ssh_connection = None
769
 
        self._vendor = vendor
770
 
        self._write_to = None
771
 
        self._bzr_remote_path = bzr_remote_path
772
995
 
773
996
    def __repr__(self):
774
 
        if self._port is None:
 
997
        if self._ssh_params.port is None:
775
998
            maybe_port = ''
776
999
        else:
777
 
            maybe_port = ':%s' % self._port
778
 
        return "%s(%s://%s@%s%s/)" % (
 
1000
            maybe_port = ':%s' % self._ssh_params.port
 
1001
        if self._ssh_params.username is None:
 
1002
            maybe_user = ''
 
1003
        else:
 
1004
            maybe_user = '%s@' % self._ssh_params.username
 
1005
        return "%s(%s://%s%s%s/)" % (
779
1006
            self.__class__.__name__,
780
1007
            self._scheme,
781
 
            self._username,
782
 
            self._host,
 
1008
            maybe_user,
 
1009
            self._ssh_params.host,
783
1010
            maybe_port)
784
1011
 
785
1012
    def _accept_bytes(self, bytes):
786
1013
        """See SmartClientStreamMedium.accept_bytes."""
787
1014
        self._ensure_connection()
788
 
        self._write_to.write(bytes)
789
 
        self._report_activity(len(bytes), 'write')
 
1015
        self._real_medium.accept_bytes(bytes)
790
1016
 
791
1017
    def disconnect(self):
792
1018
        """See SmartClientMedium.disconnect()."""
793
 
        if not self._connected:
794
 
            return
795
 
        self._read_from.close()
796
 
        self._write_to.close()
797
 
        self._ssh_connection.close()
798
 
        self._connected = False
 
1019
        if self._real_medium is not None:
 
1020
            self._real_medium.disconnect()
 
1021
            self._real_medium = None
 
1022
        if self._ssh_connection is not None:
 
1023
            self._ssh_connection.close()
 
1024
            self._ssh_connection = None
799
1025
 
800
1026
    def _ensure_connection(self):
801
1027
        """Connect this medium if not already connected."""
802
 
        if self._connected:
 
1028
        if self._real_medium is not None:
803
1029
            return
804
1030
        if self._vendor is None:
805
1031
            vendor = ssh._get_ssh_vendor()
806
1032
        else:
807
1033
            vendor = self._vendor
808
 
        self._ssh_connection = vendor.connect_ssh(self._username,
809
 
                self._password, self._host, self._port,
810
 
                command=[self._bzr_remote_path, 'serve', '--inet',
811
 
                         '--directory=/', '--allow-writes'])
812
 
        self._read_from, self._write_to = \
813
 
            self._ssh_connection.get_filelike_channels()
814
 
        self._connected = True
 
1034
        self._ssh_connection = vendor.connect_ssh(self._ssh_params.username,
 
1035
                                                  self._ssh_params.password, self._ssh_params.host,
 
1036
                                                  self._ssh_params.port,
 
1037
                                                  command=[self._ssh_params.bzr_remote_path, 'serve', '--inet',
 
1038
                                                           '--directory=/', '--allow-writes'])
 
1039
        io_kind, io_object = self._ssh_connection.get_sock_or_pipes()
 
1040
        if io_kind == 'socket':
 
1041
            self._real_medium = SmartClientAlreadyConnectedSocketMedium(
 
1042
                self.base, io_object)
 
1043
        elif io_kind == 'pipes':
 
1044
            read_from, write_to = io_object
 
1045
            self._real_medium = SmartSimplePipesClientMedium(
 
1046
                read_from, write_to, self.base)
 
1047
        else:
 
1048
            raise AssertionError(
 
1049
                "Unexpected io_kind %r from %r"
 
1050
                % (io_kind, self._ssh_connection))
 
1051
        for hook in transport.Transport.hooks["post_connect"]:
 
1052
            hook(self)
815
1053
 
816
1054
    def _flush(self):
817
1055
        """See SmartClientStreamMedium._flush()."""
818
 
        self._write_to.flush()
 
1056
        self._real_medium._flush()
819
1057
 
820
1058
    def _read_bytes(self, count):
821
1059
        """See SmartClientStreamMedium.read_bytes."""
822
 
        if not self._connected:
 
1060
        if self._real_medium is None:
823
1061
            raise errors.MediumNotConnected(self)
824
 
        bytes_to_read = min(count, _MAX_READ_SIZE)
825
 
        bytes = self._read_from.read(bytes_to_read)
826
 
        self._report_activity(len(bytes), 'read')
827
 
        return bytes
 
1062
        return self._real_medium.read_bytes(count)
828
1063
 
829
1064
 
830
1065
# Port 4155 is the default port for bzr://, registered with IANA.
832
1067
BZR_DEFAULT_PORT = 4155
833
1068
 
834
1069
 
835
 
class SmartTCPClientMedium(SmartClientStreamMedium):
836
 
    """A client medium using TCP."""
837
 
 
838
 
    def __init__(self, host, port, base):
839
 
        """Creates a client that will connect on the first use."""
 
1070
class SmartClientSocketMedium(SmartClientStreamMedium):
 
1071
    """A client medium using a socket.
 
1072
 
 
1073
    This class isn't usable directly.  Use one of its subclasses instead.
 
1074
    """
 
1075
 
 
1076
    def __init__(self, base):
840
1077
        SmartClientStreamMedium.__init__(self, base)
 
1078
        self._socket = None
841
1079
        self._connected = False
842
 
        self._host = host
843
 
        self._port = port
844
 
        self._socket = None
845
1080
 
846
1081
    def _accept_bytes(self, bytes):
847
1082
        """See SmartClientMedium.accept_bytes."""
848
1083
        self._ensure_connection()
849
1084
        osutils.send_all(self._socket, bytes, self._report_activity)
850
1085
 
 
1086
    def _ensure_connection(self):
 
1087
        """Connect this medium if not already connected."""
 
1088
        raise NotImplementedError(self._ensure_connection)
 
1089
 
 
1090
    def _flush(self):
 
1091
        """See SmartClientStreamMedium._flush().
 
1092
 
 
1093
        For sockets we do no flushing. For TCP sockets we may want to turn off
 
1094
        TCP_NODELAY and add a means to do a flush, but that can be done in the
 
1095
        future.
 
1096
        """
 
1097
 
 
1098
    def _read_bytes(self, count):
 
1099
        """See SmartClientMedium.read_bytes."""
 
1100
        if not self._connected:
 
1101
            raise errors.MediumNotConnected(self)
 
1102
        return osutils.read_bytes_from_socket(
 
1103
            self._socket, self._report_activity)
 
1104
 
851
1105
    def disconnect(self):
852
1106
        """See SmartClientMedium.disconnect()."""
853
1107
        if not self._connected:
856
1110
        self._socket = None
857
1111
        self._connected = False
858
1112
 
 
1113
 
 
1114
class SmartTCPClientMedium(SmartClientSocketMedium):
 
1115
    """A client medium that creates a TCP connection."""
 
1116
 
 
1117
    def __init__(self, host, port, base):
 
1118
        """Creates a client that will connect on the first use."""
 
1119
        SmartClientSocketMedium.__init__(self, base)
 
1120
        self._host = host
 
1121
        self._port = port
 
1122
 
859
1123
    def _ensure_connection(self):
860
1124
        """Connect this medium if not already connected."""
861
1125
        if self._connected:
866
1130
            port = int(self._port)
867
1131
        try:
868
1132
            sockaddrs = socket.getaddrinfo(self._host, port, socket.AF_UNSPEC,
869
 
                socket.SOCK_STREAM, 0, 0)
870
 
        except socket.gaierror, (err_num, err_msg):
 
1133
                                           socket.SOCK_STREAM, 0, 0)
 
1134
        except socket.gaierror as xxx_todo_changeme:
 
1135
            (err_num, err_msg) = xxx_todo_changeme.args
871
1136
            raise errors.ConnectionError("failed to lookup %s:%d: %s" %
872
 
                    (self._host, port, err_msg))
 
1137
                                         (self._host, port, err_msg))
873
1138
        # Initialize err in case there are no addresses returned:
874
 
        err = socket.error("no address found for %s" % self._host)
 
1139
        last_err = socket.error("no address found for %s" % self._host)
875
1140
        for (family, socktype, proto, canonname, sockaddr) in sockaddrs:
876
1141
            try:
877
1142
                self._socket = socket.socket(family, socktype, proto)
878
1143
                self._socket.setsockopt(socket.IPPROTO_TCP,
879
1144
                                        socket.TCP_NODELAY, 1)
880
1145
                self._socket.connect(sockaddr)
881
 
            except socket.error, err:
 
1146
            except socket.error as err:
882
1147
                if self._socket is not None:
883
1148
                    self._socket.close()
884
1149
                self._socket = None
 
1150
                last_err = err
885
1151
                continue
886
1152
            break
887
1153
        if self._socket is None:
888
1154
            # socket errors either have a (string) or (errno, string) as their
889
1155
            # args.
890
 
            if type(err.args) is str:
891
 
                err_msg = err.args
 
1156
            if isinstance(last_err.args, str):
 
1157
                err_msg = last_err.args
892
1158
            else:
893
 
                err_msg = err.args[1]
 
1159
                err_msg = last_err.args[1]
894
1160
            raise errors.ConnectionError("failed to connect to %s:%d: %s" %
895
 
                    (self._host, port, err_msg))
896
 
        self._connected = True
897
 
 
898
 
    def _flush(self):
899
 
        """See SmartClientStreamMedium._flush().
900
 
 
901
 
        For TCP we do no flushing. We may want to turn off TCP_NODELAY and
902
 
        add a means to do a flush, but that can be done in the future.
903
 
        """
904
 
 
905
 
    def _read_bytes(self, count):
906
 
        """See SmartClientMedium.read_bytes."""
907
 
        if not self._connected:
908
 
            raise errors.MediumNotConnected(self)
909
 
        return osutils.read_bytes_from_socket(
910
 
            self._socket, self._report_activity)
 
1161
                                         (self._host, port, err_msg))
 
1162
        self._connected = True
 
1163
        for hook in transport.Transport.hooks["post_connect"]:
 
1164
            hook(self)
 
1165
 
 
1166
 
 
1167
class SmartClientAlreadyConnectedSocketMedium(SmartClientSocketMedium):
 
1168
    """A client medium for an already connected socket.
 
1169
 
 
1170
    Note that this class will assume it "owns" the socket, so it will close it
 
1171
    when its disconnect method is called.
 
1172
    """
 
1173
 
 
1174
    def __init__(self, base, sock):
 
1175
        SmartClientSocketMedium.__init__(self, base)
 
1176
        self._socket = sock
 
1177
        self._connected = True
 
1178
 
 
1179
    def _ensure_connection(self):
 
1180
        # Already connected, by definition!  So nothing to do.
 
1181
        pass
911
1182
 
912
1183
 
913
1184
class SmartClientStreamMediumRequest(SmartClientMediumRequest):
948
1219
        This invokes self._medium._flush to ensure all bytes are transmitted.
949
1220
        """
950
1221
        self._medium._flush()
951
 
 
952