/brz/remove-bazaar

To get this branch, use:
bzr branch http://gegoxaren.bato24.eu/bzr/brz/remove-bazaar

« back to all changes in this revision

Viewing changes to bzrlib/smart/medium.py

  • Committer: Robert Collins
  • Date: 2010-05-06 11:08:10 UTC
  • mto: This revision was merged to the branch mainline in revision 5223.
  • Revision ID: robertc@robertcollins.net-20100506110810-h3j07fh5gmw54s25
Cleaner matcher matching revised unlocking protocol.

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006-2011 Canonical Ltd
 
1
# Copyright (C) 2006-2010 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
21
21
 
22
22
Media carry the bytes of the requests somehow (e.g. via TCP, wrapped in HTTP, or
23
23
over SSH), and pass them to and from the protocol logic.  See the overview in
24
 
breezy/transport/smart/__init__.py.
 
24
bzrlib/transport/smart/__init__.py.
25
25
"""
26
26
 
27
 
from __future__ import absolute_import
28
 
 
29
 
import errno
30
 
import io
31
27
import os
32
28
import sys
33
 
import time
 
29
import urllib
34
30
 
35
 
import breezy
36
 
from ...lazy_import import lazy_import
 
31
from bzrlib.lazy_import import lazy_import
37
32
lazy_import(globals(), """
38
 
import select
 
33
import atexit
39
34
import socket
40
35
import thread
41
36
import weakref
42
37
 
43
 
from breezy import (
 
38
from bzrlib import (
44
39
    debug,
 
40
    errors,
 
41
    symbol_versioning,
45
42
    trace,
46
 
    transport,
47
43
    ui,
48
44
    urlutils,
49
45
    )
50
 
from breezy.i18n import gettext
51
 
from breezy.bzr.smart import client, protocol, request, signals, vfs
52
 
from breezy.transport import ssh
 
46
from bzrlib.smart import client, protocol, request, vfs
 
47
from bzrlib.transport import ssh
53
48
""")
54
 
from ... import (
55
 
    errors,
56
 
    osutils,
57
 
    )
 
49
from bzrlib import osutils
58
50
 
59
51
# Throughout this module buffer size parameters are either limited to be at
60
52
# most _MAX_READ_SIZE, or are ignored and _MAX_READ_SIZE is used instead.
62
54
# from non-sockets as well.
63
55
_MAX_READ_SIZE = osutils.MAX_SOCKET_CHUNK
64
56
 
65
 
 
66
 
class HpssVfsRequestNotAllowed(errors.BzrError):
67
 
 
68
 
    _fmt = ("VFS requests over the smart server are not allowed. Encountered: "
69
 
            "%(method)s, %(arguments)s.")
70
 
 
71
 
    def __init__(self, method, arguments):
72
 
        self.method = method
73
 
        self.arguments = arguments
74
 
 
75
 
 
76
57
def _get_protocol_factory_for_bytes(bytes):
77
58
    """Determine the right protocol factory for 'bytes'.
78
59
 
114
95
    :returns: a tuple of two strs: (line, excess)
115
96
    """
116
97
    newline_pos = -1
117
 
    bytes = b''
 
98
    bytes = ''
118
99
    while newline_pos == -1:
119
100
        new_bytes = read_bytes_func(1)
120
101
        bytes += new_bytes
121
 
        if new_bytes == b'':
 
102
        if new_bytes == '':
122
103
            # Ran out of bytes before receiving a complete line.
123
 
            return bytes, b''
124
 
        newline_pos = bytes.find(b'\n')
125
 
    line = bytes[:newline_pos + 1]
126
 
    excess = bytes[newline_pos + 1:]
 
104
            return bytes, ''
 
105
        newline_pos = bytes.find('\n')
 
106
    line = bytes[:newline_pos+1]
 
107
    excess = bytes[newline_pos+1:]
127
108
    return line, excess
128
109
 
129
110
 
133
114
    def __init__(self):
134
115
        self._push_back_buffer = None
135
116
 
136
 
    def _push_back(self, data):
 
117
    def _push_back(self, bytes):
137
118
        """Return unused bytes to the medium, because they belong to the next
138
119
        request(s).
139
120
 
140
121
        This sets the _push_back_buffer to the given bytes.
141
122
        """
142
 
        if not isinstance(data, bytes):
143
 
            raise TypeError(data)
144
123
        if self._push_back_buffer is not None:
145
124
            raise AssertionError(
146
125
                "_push_back called when self._push_back_buffer is %r"
147
126
                % (self._push_back_buffer,))
148
 
        if data == b'':
 
127
        if bytes == '':
149
128
            return
150
 
        self._push_back_buffer = data
 
129
        self._push_back_buffer = bytes
151
130
 
152
131
    def _get_push_back_buffer(self):
153
 
        if self._push_back_buffer == b'':
 
132
        if self._push_back_buffer == '':
154
133
            raise AssertionError(
155
134
                '%s._push_back_buffer should never be the empty string, '
156
135
                'which can be confused with EOF' % (self,))
197
176
        ui.ui_factory.report_transport_activity(self, bytes, direction)
198
177
 
199
178
 
200
 
_bad_file_descriptor = (errno.EBADF,)
201
 
if sys.platform == 'win32':
202
 
    # Given on Windows if you pass a closed socket to select.select. Probably
203
 
    # also given if you pass a file handle to select.
204
 
    WSAENOTSOCK = 10038
205
 
    _bad_file_descriptor += (WSAENOTSOCK,)
206
 
 
207
 
 
208
179
class SmartServerStreamMedium(SmartMedium):
209
180
    """Handles smart commands coming over a stream.
210
181
 
223
194
        the stream.  See also the _push_back method.
224
195
    """
225
196
 
226
 
    _timer = time.time
227
 
 
228
 
    def __init__(self, backing_transport, root_client_path='/', timeout=None):
 
197
    def __init__(self, backing_transport, root_client_path='/'):
229
198
        """Construct new server.
230
199
 
231
200
        :param backing_transport: Transport for the directory served.
234
203
        self.backing_transport = backing_transport
235
204
        self.root_client_path = root_client_path
236
205
        self.finished = False
237
 
        if timeout is None:
238
 
            raise AssertionError('You must supply a timeout.')
239
 
        self._client_timeout = timeout
240
 
        self._client_poll_timeout = min(timeout / 10.0, 1.0)
241
206
        SmartMedium.__init__(self)
242
207
 
243
208
    def serve(self):
249
214
            while not self.finished:
250
215
                server_protocol = self._build_protocol()
251
216
                self._serve_one_request(server_protocol)
252
 
        except errors.ConnectionTimeout as e:
253
 
            trace.note('%s' % (e,))
254
 
            trace.log_exception_quietly()
255
 
            self._disconnect_client()
256
 
            # We reported it, no reason to make a big fuss.
257
 
            return
258
 
        except Exception as e:
 
217
        except Exception, e:
259
218
            stderr.write("%s terminating on exception %s\n" % (self, e))
260
219
            raise
261
 
        self._disconnect_client()
262
 
 
263
 
    def _stop_gracefully(self):
264
 
        """When we finish this message, stop looking for more."""
265
 
        trace.mutter('Stopping %s' % (self,))
266
 
        self.finished = True
267
 
 
268
 
    def _disconnect_client(self):
269
 
        """Close the current connection. We stopped due to a timeout/etc."""
270
 
        # The default implementation is a no-op, because that is all we used to
271
 
        # do when disconnecting from a client. I suppose we never had the
272
 
        # *server* initiate a disconnect, before
273
 
 
274
 
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
275
 
        """Wait for more bytes to be read, but timeout if none available.
276
 
 
277
 
        This allows us to detect idle connections, and stop trying to read from
278
 
        them, without setting the socket itself to non-blocking. This also
279
 
        allows us to specify when we watch for idle timeouts.
280
 
 
281
 
        :return: Did we timeout? (True if we timed out, False if there is data
282
 
            to be read)
283
 
        """
284
 
        raise NotImplementedError(self._wait_for_bytes_with_timeout)
285
220
 
286
221
    def _build_protocol(self):
287
222
        """Identifies the version of the incoming request, and returns an
292
227
 
293
228
        :returns: a SmartServerRequestProtocol.
294
229
        """
295
 
        self._wait_for_bytes_with_timeout(self._client_timeout)
296
 
        if self.finished:
297
 
            # We're stopping, so don't try to do any more work
298
 
            return None
299
230
        bytes = self._get_line()
300
231
        protocol_factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
301
232
        protocol = protocol_factory(
303
234
        protocol.accept_bytes(unused_bytes)
304
235
        return protocol
305
236
 
306
 
    def _wait_on_descriptor(self, fd, timeout_seconds):
307
 
        """select() on a file descriptor, waiting for nonblocking read()
308
 
 
309
 
        This will raise a ConnectionTimeout exception if we do not get a
310
 
        readable handle before timeout_seconds.
311
 
        :return: None
312
 
        """
313
 
        t_end = self._timer() + timeout_seconds
314
 
        poll_timeout = min(timeout_seconds, self._client_poll_timeout)
315
 
        rs = xs = None
316
 
        while not rs and not xs and self._timer() < t_end:
317
 
            if self.finished:
318
 
                return
319
 
            try:
320
 
                rs, _, xs = select.select([fd], [], [fd], poll_timeout)
321
 
            except (select.error, socket.error) as e:
322
 
                err = getattr(e, 'errno', None)
323
 
                if err is None and getattr(e, 'args', None) is not None:
324
 
                    # select.error doesn't have 'errno', it just has args[0]
325
 
                    err = e.args[0]
326
 
                if err in _bad_file_descriptor:
327
 
                    return  # Not a socket indicates read() will fail
328
 
                elif err == errno.EINTR:
329
 
                    # Interrupted, keep looping.
330
 
                    continue
331
 
                raise
332
 
            except ValueError:
333
 
                return  # Socket may already be closed
334
 
        if rs or xs:
335
 
            return
336
 
        raise errors.ConnectionTimeout('disconnecting client after %.1f seconds'
337
 
                                       % (timeout_seconds,))
338
 
 
339
237
    def _serve_one_request(self, protocol):
340
238
        """Read one request from input, process, send back a response.
341
239
 
342
240
        :param protocol: a SmartServerRequestProtocol.
343
241
        """
344
 
        if protocol is None:
345
 
            return
346
242
        try:
347
243
            self._serve_one_request_unguarded(protocol)
348
244
        except KeyboardInterrupt:
349
245
            raise
350
 
        except Exception as e:
 
246
        except Exception, e:
351
247
            self.terminate_due_to_error()
352
248
 
353
249
    def terminate_due_to_error(self):
364
260
 
365
261
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
366
262
 
367
 
    def __init__(self, sock, backing_transport, root_client_path='/',
368
 
                 timeout=None):
 
263
    def __init__(self, sock, backing_transport, root_client_path='/'):
369
264
        """Constructor.
370
265
 
371
266
        :param sock: the socket the server will read from.  It will be put
372
267
            into blocking mode.
373
268
        """
374
269
        SmartServerStreamMedium.__init__(
375
 
            self, backing_transport, root_client_path=root_client_path,
376
 
            timeout=timeout)
 
270
            self, backing_transport, root_client_path=root_client_path)
377
271
        sock.setblocking(True)
378
272
        self.socket = sock
379
 
        # Get the getpeername now, as we might be closed later when we care.
380
 
        try:
381
 
            self._client_info = sock.getpeername()
382
 
        except socket.error:
383
 
            self._client_info = '<unknown>'
384
 
 
385
 
    def __str__(self):
386
 
        return '%s(client=%s)' % (self.__class__.__name__, self._client_info)
387
 
 
388
 
    def __repr__(self):
389
 
        return '%s.%s(client=%s)' % (self.__module__, self.__class__.__name__,
390
 
                                     self._client_info)
391
273
 
392
274
    def _serve_one_request_unguarded(self, protocol):
393
275
        while protocol.next_read_size():
395
277
            # than MAX_SOCKET_CHUNK ready, the socket will just return a
396
278
            # short read immediately rather than block.
397
279
            bytes = self.read_bytes(osutils.MAX_SOCKET_CHUNK)
398
 
            if bytes == b'':
 
280
            if bytes == '':
399
281
                self.finished = True
400
282
                return
401
283
            protocol.accept_bytes(bytes)
402
284
 
403
285
        self._push_back(protocol.unused_data)
404
286
 
405
 
    def _disconnect_client(self):
406
 
        """Close the current connection. We stopped due to a timeout/etc."""
407
 
        self.socket.close()
408
 
 
409
 
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
410
 
        """Wait for more bytes to be read, but timeout if none available.
411
 
 
412
 
        This allows us to detect idle connections, and stop trying to read from
413
 
        them, without setting the socket itself to non-blocking. This also
414
 
        allows us to specify when we watch for idle timeouts.
415
 
 
416
 
        :return: None, this will raise ConnectionTimeout if we time out before
417
 
            data is available.
418
 
        """
419
 
        return self._wait_on_descriptor(self.socket, timeout_seconds)
420
 
 
421
287
    def _read_bytes(self, desired_count):
422
288
        return osutils.read_bytes_from_socket(
423
289
            self.socket, self._report_activity)
440
306
 
441
307
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
442
308
 
443
 
    def __init__(self, in_file, out_file, backing_transport, timeout=None):
 
309
    def __init__(self, in_file, out_file, backing_transport):
444
310
        """Construct new server.
445
311
 
446
312
        :param in_file: Python file from which requests can be read.
447
313
        :param out_file: Python file to write responses.
448
314
        :param backing_transport: Transport for the directory served.
449
315
        """
450
 
        SmartServerStreamMedium.__init__(self, backing_transport,
451
 
                                         timeout=timeout)
 
316
        SmartServerStreamMedium.__init__(self, backing_transport)
452
317
        if sys.platform == 'win32':
453
318
            # force binary mode for files
454
319
            import msvcrt
459
324
        self._in = in_file
460
325
        self._out = out_file
461
326
 
462
 
    def serve(self):
463
 
        """See SmartServerStreamMedium.serve"""
464
 
        # This is the regular serve, except it adds signal trapping for soft
465
 
        # shutdown.
466
 
        stop_gracefully = self._stop_gracefully
467
 
        signals.register_on_hangup(id(self), stop_gracefully)
468
 
        try:
469
 
            return super(SmartServerPipeStreamMedium, self).serve()
470
 
        finally:
471
 
            signals.unregister_on_hangup(id(self))
472
 
 
473
327
    def _serve_one_request_unguarded(self, protocol):
474
328
        while True:
475
329
            # We need to be careful not to read past the end of the current
481
335
                self._out.flush()
482
336
                return
483
337
            bytes = self.read_bytes(bytes_to_read)
484
 
            if bytes == b'':
 
338
            if bytes == '':
485
339
                # Connection has been closed.
486
340
                self.finished = True
487
341
                self._out.flush()
488
342
                return
489
343
            protocol.accept_bytes(bytes)
490
344
 
491
 
    def _disconnect_client(self):
492
 
        self._in.close()
493
 
        self._out.flush()
494
 
        self._out.close()
495
 
 
496
 
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
497
 
        """Wait for more bytes to be read, but timeout if none available.
498
 
 
499
 
        This allows us to detect idle connections, and stop trying to read from
500
 
        them, without setting the socket itself to non-blocking. This also
501
 
        allows us to specify when we watch for idle timeouts.
502
 
 
503
 
        :return: None, this will raise ConnectionTimeout if we time out before
504
 
            data is available.
505
 
        """
506
 
        if (getattr(self._in, 'fileno', None) is None
507
 
                or sys.platform == 'win32'):
508
 
            # You can't select() file descriptors on Windows.
509
 
            return
510
 
        try:
511
 
            return self._wait_on_descriptor(self._in, timeout_seconds)
512
 
        except io.UnsupportedOperation:
513
 
            return
514
 
 
515
345
    def _read_bytes(self, desired_count):
516
346
        return self._in.read(desired_count)
517
347
 
645
475
 
646
476
    def read_line(self):
647
477
        line = self._read_line()
648
 
        if not line.endswith(b'\n'):
 
478
        if not line.endswith('\n'):
649
479
            # end of file encountered reading from server
650
480
            raise errors.ConnectionReset(
651
481
                "Unexpected end of message. Please check connectivity "
661
491
        return self._medium._get_line()
662
492
 
663
493
 
664
 
class _VfsRefuser(object):
665
 
    """An object that refuses all VFS requests.
666
 
 
667
 
    """
668
 
 
669
 
    def __init__(self):
670
 
        client._SmartClient.hooks.install_named_hook(
671
 
            'call', self.check_vfs, 'vfs refuser')
672
 
 
673
 
    def check_vfs(self, params):
674
 
        try:
675
 
            request_method = request.request_handlers.get(params.method)
676
 
        except KeyError:
677
 
            # A method we don't know about doesn't count as a VFS method.
678
 
            return
679
 
        if issubclass(request_method, vfs.VfsRequest):
680
 
            raise HpssVfsRequestNotAllowed(params.method, params.args)
681
 
 
682
 
 
683
494
class _DebugCounter(object):
684
495
    """An object that counts the HPSS calls made to each client medium.
685
496
 
686
 
    When a medium is garbage-collected, or failing that when
687
 
    breezy.global_state exits, the total number of calls made on that medium
688
 
    are reported via trace.note.
 
497
    When a medium is garbage-collected, or failing that when atexit functions
 
498
    are run, the total number of calls made on that medium are reported via
 
499
    trace.note.
689
500
    """
690
501
 
691
502
    def __init__(self):
692
503
        self.counts = weakref.WeakKeyDictionary()
693
504
        client._SmartClient.hooks.install_named_hook(
694
505
            'call', self.increment_call_count, 'hpss call counter')
695
 
        breezy.get_global_state().cleanups.add_cleanup(self.flush_all)
 
506
        atexit.register(self.flush_all)
696
507
 
697
508
    def track(self, medium):
698
509
        """Start tracking calls made to a medium.
732
543
        value['count'] = 0
733
544
        value['vfs_count'] = 0
734
545
        if count != 0:
735
 
            trace.note(gettext('HPSS calls: {0} ({1} vfs) {2}').format(
736
 
                       count, vfs_count, medium_repr))
 
546
            trace.note('HPSS calls: %d (%d vfs) %s',
 
547
                       count, vfs_count, medium_repr)
737
548
 
738
549
    def flush_all(self):
739
550
        for ref in list(self.counts.keys()):
740
551
            self.done(ref)
741
552
 
742
 
 
743
553
_debug_counter = None
744
 
_vfs_refuser = None
745
554
 
746
555
 
747
556
class SmartClientMedium(SmartMedium):
764
573
            if _debug_counter is None:
765
574
                _debug_counter = _DebugCounter()
766
575
            _debug_counter.track(self)
767
 
        if 'hpss_client_no_vfs' in debug.debug_flags:
768
 
            global _vfs_refuser
769
 
            if _vfs_refuser is None:
770
 
                _vfs_refuser = _VfsRefuser()
771
576
 
772
577
    def _is_remote_before(self, version_tuple):
773
578
        """Is it possible the remote side supports RPCs for a given version?
797
602
        :seealso: _is_remote_before
798
603
        """
799
604
        if (self._remote_version_is_before is not None and
800
 
                version_tuple > self._remote_version_is_before):
 
605
            version_tuple > self._remote_version_is_before):
801
606
            # We have been told that the remote side is older than some version
802
607
            # which is newer than a previously supplied older-than version.
803
608
            # This indicates that some smart verb call is not guarded
804
609
            # appropriately (it should simply not have been tried).
805
610
            trace.mutter(
806
611
                "_remember_remote_is_before(%r) called, but "
807
 
                "_remember_remote_is_before(%r) was called previously.", version_tuple, self._remote_version_is_before)
 
612
                "_remember_remote_is_before(%r) was called previously."
 
613
                , version_tuple, self._remote_version_is_before)
808
614
            if 'hpss' in debug.debug_flags:
809
615
                ui.ui_factory.show_warning(
810
616
                    "_remember_remote_is_before(%r) called, but "
822
628
                medium_request = self.get_request()
823
629
                # Send a 'hello' request in protocol version one, for maximum
824
630
                # backwards compatibility.
825
 
                client_protocol = protocol.SmartClientRequestProtocolOne(
826
 
                    medium_request)
 
631
                client_protocol = protocol.SmartClientRequestProtocolOne(medium_request)
827
632
                client_protocol.query_version()
828
633
                self._done_hello = True
829
 
            except errors.SmartProtocolError as e:
 
634
            except errors.SmartProtocolError, e:
830
635
                # Cache the error, just like we would cache a successful
831
636
                # result.
832
637
                self._protocol_version_error = e
864
669
        """
865
670
        medium_base = urlutils.join(self.base, '/')
866
671
        rel_url = urlutils.relative_url(medium_base, transport.base)
867
 
        return urlutils.unquote(rel_url)
 
672
        return urllib.unquote(rel_url)
868
673
 
869
674
 
870
675
class SmartClientStreamMedium(SmartClientMedium):
905
710
        """
906
711
        return SmartClientStreamMediumRequest(self)
907
712
 
908
 
    def reset(self):
909
 
        """We have been disconnected, reset current state.
910
 
 
911
 
        This resets things like _current_request and connected state.
912
 
        """
913
 
        self.disconnect()
914
 
        self._current_request = None
915
 
 
916
713
 
917
714
class SmartSimplePipesClientMedium(SmartClientStreamMedium):
918
715
    """A client medium using simple pipes.
919
716
 
920
717
    This client does not manage the pipes: it assumes they will always be open.
 
718
 
 
719
    Note that if readable_pipe.read might raise IOError or OSError with errno
 
720
    of EINTR, it must be safe to retry the read.  Plain CPython fileobjects
 
721
    (such as used for sys.stdin) are safe.
921
722
    """
922
723
 
923
724
    def __init__(self, readable_pipe, writeable_pipe, base):
925
726
        self._readable_pipe = readable_pipe
926
727
        self._writeable_pipe = writeable_pipe
927
728
 
928
 
    def _accept_bytes(self, data):
 
729
    def _accept_bytes(self, bytes):
929
730
        """See SmartClientStreamMedium.accept_bytes."""
930
 
        try:
931
 
            self._writeable_pipe.write(data)
932
 
        except IOError as e:
933
 
            if e.errno in (errno.EINVAL, errno.EPIPE):
934
 
                raise errors.ConnectionReset(
935
 
                    "Error trying to write to subprocess", e)
936
 
            raise
937
 
        self._report_activity(len(data), 'write')
 
731
        self._writeable_pipe.write(bytes)
 
732
        self._report_activity(len(bytes), 'write')
938
733
 
939
734
    def _flush(self):
940
735
        """See SmartClientStreamMedium._flush()."""
941
 
        # Note: If flush were to fail, we'd like to raise ConnectionReset, etc.
942
 
        #       However, testing shows that even when the child process is
943
 
        #       gone, this doesn't error.
944
736
        self._writeable_pipe.flush()
945
737
 
946
738
    def _read_bytes(self, count):
947
739
        """See SmartClientStreamMedium._read_bytes."""
948
 
        bytes_to_read = min(count, _MAX_READ_SIZE)
949
 
        data = self._readable_pipe.read(bytes_to_read)
950
 
        self._report_activity(len(data), 'read')
951
 
        return data
952
 
 
953
 
 
954
 
class SSHParams(object):
955
 
    """A set of parameters for starting a remote bzr via SSH."""
 
740
        bytes = osutils.until_no_eintr(self._readable_pipe.read, count)
 
741
        self._report_activity(len(bytes), 'read')
 
742
        return bytes
 
743
 
 
744
 
 
745
class SmartSSHClientMedium(SmartClientStreamMedium):
 
746
    """A client medium using SSH."""
956
747
 
957
748
    def __init__(self, host, port=None, username=None, password=None,
958
 
                 bzr_remote_path='bzr'):
959
 
        self.host = host
960
 
        self.port = port
961
 
        self.username = username
962
 
        self.password = password
963
 
        self.bzr_remote_path = bzr_remote_path
964
 
 
965
 
 
966
 
class SmartSSHClientMedium(SmartClientStreamMedium):
967
 
    """A client medium using SSH.
968
 
 
969
 
    It delegates IO to a SmartSimplePipesClientMedium or
970
 
    SmartClientAlreadyConnectedSocketMedium (depending on platform).
971
 
    """
972
 
 
973
 
    def __init__(self, base, ssh_params, vendor=None):
 
749
            base=None, vendor=None, bzr_remote_path=None):
974
750
        """Creates a client that will connect on the first use.
975
751
 
976
 
        :param ssh_params: A SSHParams instance.
977
752
        :param vendor: An optional override for the ssh vendor to use. See
978
 
            breezy.transport.ssh for details on ssh vendors.
 
753
            bzrlib.transport.ssh for details on ssh vendors.
979
754
        """
980
 
        self._real_medium = None
981
 
        self._ssh_params = ssh_params
 
755
        self._connected = False
 
756
        self._host = host
 
757
        self._password = password
 
758
        self._port = port
 
759
        self._username = username
982
760
        # for the benefit of progress making a short description of this
983
761
        # transport
984
762
        self._scheme = 'bzr+ssh'
986
764
        # _DebugCounter so we have to store all the values used in our repr
987
765
        # method before calling the super init.
988
766
        SmartClientStreamMedium.__init__(self, base)
 
767
        self._read_from = None
 
768
        self._ssh_connection = None
989
769
        self._vendor = vendor
990
 
        self._ssh_connection = None
 
770
        self._write_to = None
 
771
        self._bzr_remote_path = bzr_remote_path
991
772
 
992
773
    def __repr__(self):
993
 
        if self._ssh_params.port is None:
 
774
        if self._port is None:
994
775
            maybe_port = ''
995
776
        else:
996
 
            maybe_port = ':%s' % self._ssh_params.port
997
 
        if self._ssh_params.username is None:
998
 
            maybe_user = ''
999
 
        else:
1000
 
            maybe_user = '%s@' % self._ssh_params.username
1001
 
        return "%s(%s://%s%s%s/)" % (
 
777
            maybe_port = ':%s' % self._port
 
778
        return "%s(%s://%s@%s%s/)" % (
1002
779
            self.__class__.__name__,
1003
780
            self._scheme,
1004
 
            maybe_user,
1005
 
            self._ssh_params.host,
 
781
            self._username,
 
782
            self._host,
1006
783
            maybe_port)
1007
784
 
1008
785
    def _accept_bytes(self, bytes):
1009
786
        """See SmartClientStreamMedium.accept_bytes."""
1010
787
        self._ensure_connection()
1011
 
        self._real_medium.accept_bytes(bytes)
 
788
        self._write_to.write(bytes)
 
789
        self._report_activity(len(bytes), 'write')
1012
790
 
1013
791
    def disconnect(self):
1014
792
        """See SmartClientMedium.disconnect()."""
1015
 
        if self._real_medium is not None:
1016
 
            self._real_medium.disconnect()
1017
 
            self._real_medium = None
1018
 
        if self._ssh_connection is not None:
1019
 
            self._ssh_connection.close()
1020
 
            self._ssh_connection = None
 
793
        if not self._connected:
 
794
            return
 
795
        self._read_from.close()
 
796
        self._write_to.close()
 
797
        self._ssh_connection.close()
 
798
        self._connected = False
1021
799
 
1022
800
    def _ensure_connection(self):
1023
801
        """Connect this medium if not already connected."""
1024
 
        if self._real_medium is not None:
 
802
        if self._connected:
1025
803
            return
1026
804
        if self._vendor is None:
1027
805
            vendor = ssh._get_ssh_vendor()
1028
806
        else:
1029
807
            vendor = self._vendor
1030
 
        self._ssh_connection = vendor.connect_ssh(self._ssh_params.username,
1031
 
                                                  self._ssh_params.password, self._ssh_params.host,
1032
 
                                                  self._ssh_params.port,
1033
 
                                                  command=[self._ssh_params.bzr_remote_path, 'serve', '--inet',
1034
 
                                                           '--directory=/', '--allow-writes'])
1035
 
        io_kind, io_object = self._ssh_connection.get_sock_or_pipes()
1036
 
        if io_kind == 'socket':
1037
 
            self._real_medium = SmartClientAlreadyConnectedSocketMedium(
1038
 
                self.base, io_object)
1039
 
        elif io_kind == 'pipes':
1040
 
            read_from, write_to = io_object
1041
 
            self._real_medium = SmartSimplePipesClientMedium(
1042
 
                read_from, write_to, self.base)
1043
 
        else:
1044
 
            raise AssertionError(
1045
 
                "Unexpected io_kind %r from %r"
1046
 
                % (io_kind, self._ssh_connection))
1047
 
        for hook in transport.Transport.hooks["post_connect"]:
1048
 
            hook(self)
 
808
        self._ssh_connection = vendor.connect_ssh(self._username,
 
809
                self._password, self._host, self._port,
 
810
                command=[self._bzr_remote_path, 'serve', '--inet',
 
811
                         '--directory=/', '--allow-writes'])
 
812
        self._read_from, self._write_to = \
 
813
            self._ssh_connection.get_filelike_channels()
 
814
        self._connected = True
1049
815
 
1050
816
    def _flush(self):
1051
817
        """See SmartClientStreamMedium._flush()."""
1052
 
        self._real_medium._flush()
 
818
        self._write_to.flush()
1053
819
 
1054
820
    def _read_bytes(self, count):
1055
821
        """See SmartClientStreamMedium.read_bytes."""
1056
 
        if self._real_medium is None:
 
822
        if not self._connected:
1057
823
            raise errors.MediumNotConnected(self)
1058
 
        return self._real_medium.read_bytes(count)
 
824
        bytes_to_read = min(count, _MAX_READ_SIZE)
 
825
        bytes = self._read_from.read(bytes_to_read)
 
826
        self._report_activity(len(bytes), 'read')
 
827
        return bytes
1059
828
 
1060
829
 
1061
830
# Port 4155 is the default port for bzr://, registered with IANA.
1063
832
BZR_DEFAULT_PORT = 4155
1064
833
 
1065
834
 
1066
 
class SmartClientSocketMedium(SmartClientStreamMedium):
1067
 
    """A client medium using a socket.
1068
 
 
1069
 
    This class isn't usable directly.  Use one of its subclasses instead.
1070
 
    """
1071
 
 
1072
 
    def __init__(self, base):
 
835
class SmartTCPClientMedium(SmartClientStreamMedium):
 
836
    """A client medium using TCP."""
 
837
 
 
838
    def __init__(self, host, port, base):
 
839
        """Creates a client that will connect on the first use."""
1073
840
        SmartClientStreamMedium.__init__(self, base)
 
841
        self._connected = False
 
842
        self._host = host
 
843
        self._port = port
1074
844
        self._socket = None
1075
 
        self._connected = False
1076
845
 
1077
846
    def _accept_bytes(self, bytes):
1078
847
        """See SmartClientMedium.accept_bytes."""
1079
848
        self._ensure_connection()
1080
849
        osutils.send_all(self._socket, bytes, self._report_activity)
1081
850
 
1082
 
    def _ensure_connection(self):
1083
 
        """Connect this medium if not already connected."""
1084
 
        raise NotImplementedError(self._ensure_connection)
1085
 
 
1086
 
    def _flush(self):
1087
 
        """See SmartClientStreamMedium._flush().
1088
 
 
1089
 
        For sockets we do no flushing. For TCP sockets we may want to turn off
1090
 
        TCP_NODELAY and add a means to do a flush, but that can be done in the
1091
 
        future.
1092
 
        """
1093
 
 
1094
 
    def _read_bytes(self, count):
1095
 
        """See SmartClientMedium.read_bytes."""
1096
 
        if not self._connected:
1097
 
            raise errors.MediumNotConnected(self)
1098
 
        return osutils.read_bytes_from_socket(
1099
 
            self._socket, self._report_activity)
1100
 
 
1101
851
    def disconnect(self):
1102
852
        """See SmartClientMedium.disconnect()."""
1103
853
        if not self._connected:
1106
856
        self._socket = None
1107
857
        self._connected = False
1108
858
 
1109
 
 
1110
 
class SmartTCPClientMedium(SmartClientSocketMedium):
1111
 
    """A client medium that creates a TCP connection."""
1112
 
 
1113
 
    def __init__(self, host, port, base):
1114
 
        """Creates a client that will connect on the first use."""
1115
 
        SmartClientSocketMedium.__init__(self, base)
1116
 
        self._host = host
1117
 
        self._port = port
1118
 
 
1119
859
    def _ensure_connection(self):
1120
860
        """Connect this medium if not already connected."""
1121
861
        if self._connected:
1126
866
            port = int(self._port)
1127
867
        try:
1128
868
            sockaddrs = socket.getaddrinfo(self._host, port, socket.AF_UNSPEC,
1129
 
                                           socket.SOCK_STREAM, 0, 0)
1130
 
        except socket.gaierror as xxx_todo_changeme:
1131
 
            (err_num, err_msg) = xxx_todo_changeme.args
 
869
                socket.SOCK_STREAM, 0, 0)
 
870
        except socket.gaierror, (err_num, err_msg):
1132
871
            raise errors.ConnectionError("failed to lookup %s:%d: %s" %
1133
 
                                         (self._host, port, err_msg))
 
872
                    (self._host, port, err_msg))
1134
873
        # Initialize err in case there are no addresses returned:
1135
 
        last_err = socket.error("no address found for %s" % self._host)
 
874
        err = socket.error("no address found for %s" % self._host)
1136
875
        for (family, socktype, proto, canonname, sockaddr) in sockaddrs:
1137
876
            try:
1138
877
                self._socket = socket.socket(family, socktype, proto)
1139
878
                self._socket.setsockopt(socket.IPPROTO_TCP,
1140
879
                                        socket.TCP_NODELAY, 1)
1141
880
                self._socket.connect(sockaddr)
1142
 
            except socket.error as err:
 
881
            except socket.error, err:
1143
882
                if self._socket is not None:
1144
883
                    self._socket.close()
1145
884
                self._socket = None
1146
 
                last_err = err
1147
885
                continue
1148
886
            break
1149
887
        if self._socket is None:
1150
888
            # socket errors either have a (string) or (errno, string) as their
1151
889
            # args.
1152
 
            if isinstance(last_err.args, str):
1153
 
                err_msg = last_err.args
 
890
            if type(err.args) is str:
 
891
                err_msg = err.args
1154
892
            else:
1155
 
                err_msg = last_err.args[1]
 
893
                err_msg = err.args[1]
1156
894
            raise errors.ConnectionError("failed to connect to %s:%d: %s" %
1157
 
                                         (self._host, port, err_msg))
1158
 
        self._connected = True
1159
 
        for hook in transport.Transport.hooks["post_connect"]:
1160
 
            hook(self)
1161
 
 
1162
 
 
1163
 
class SmartClientAlreadyConnectedSocketMedium(SmartClientSocketMedium):
1164
 
    """A client medium for an already connected socket.
1165
 
 
1166
 
    Note that this class will assume it "owns" the socket, so it will close it
1167
 
    when its disconnect method is called.
1168
 
    """
1169
 
 
1170
 
    def __init__(self, base, sock):
1171
 
        SmartClientSocketMedium.__init__(self, base)
1172
 
        self._socket = sock
1173
 
        self._connected = True
1174
 
 
1175
 
    def _ensure_connection(self):
1176
 
        # Already connected, by definition!  So nothing to do.
1177
 
        pass
 
895
                    (self._host, port, err_msg))
 
896
        self._connected = True
 
897
 
 
898
    def _flush(self):
 
899
        """See SmartClientStreamMedium._flush().
 
900
 
 
901
        For TCP we do no flushing. We may want to turn off TCP_NODELAY and
 
902
        add a means to do a flush, but that can be done in the future.
 
903
        """
 
904
 
 
905
    def _read_bytes(self, count):
 
906
        """See SmartClientMedium.read_bytes."""
 
907
        if not self._connected:
 
908
            raise errors.MediumNotConnected(self)
 
909
        return osutils.read_bytes_from_socket(
 
910
            self._socket, self._report_activity)
1178
911
 
1179
912
 
1180
913
class SmartClientStreamMediumRequest(SmartClientMediumRequest):
1215
948
        This invokes self._medium._flush to ensure all bytes are transmitted.
1216
949
        """
1217
950
        self._medium._flush()
 
951
 
 
952